diff --git a/.gitignore b/.gitignore index e8ba8747559abdacd563868545e2a6666be82a3c..3bb6d6cbcf2906a99d3302595285d50145395fe5 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ CMakeCache.txt GPATH GRTAGS GTAGS +/.clangd/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1d9b64f0f08f80b6809fe6860e17b7c24444351b..02c2d48ced6eb3f1e1c106f5d00e010b258fee3c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,5 +1,4 @@ stages: - - build - test - coverage @@ -9,7 +8,6 @@ variables: include: - local: '/.gitlab-ci/clang10-mpi-debug.yml' - local: '/.gitlab-ci/clang10-mpi-release.yml' - - local: '/.gitlab-ci/gcc10-mpi-debug.yml' - local: '/.gitlab-ci/gcc10-mpi-coverage.yml' - local: '/.gitlab-ci/gcc10-mpi-release.yml' - local: '/.gitlab-ci/gcc10-seq-coverage.yml' diff --git a/.gitlab-ci/clang10-mpi-debug.yml b/.gitlab-ci/clang10-mpi-debug.yml index 116a3f1034cd679352569680a09d878b31b83c73..ceafce232dcaf5935f1c6be403c796364a3b8c16 100644 --- a/.gitlab-ci/clang10-mpi-debug.yml +++ b/.gitlab-ci/clang10-mpi-debug.yml @@ -1,27 +1,13 @@ -build:clang10-mpi-debug: - image: localhost:5000/ubuntu_clang10_mpi - stage: build - needs: [] - script: - - mkdir -p build/clang10-debug-mpi - - cd build/clang10-debug-mpi - - CXX=clang++-10 CC=clang-10 cmake ../.. -DCMAKE_BUILD_TYPE=Debug - - make pugs - cache: - key: "${CI_COMMIT_REF_SLUG}-clang10-debug-mpi" - paths: - - build/clang10-debug-mpi - untracked: true - test:clang10-mpi-debug: image: localhost:5000/ubuntu_clang10_mpi stage: test - needs: ["build:clang10-mpi-debug"] + needs: [] script: - mkdir -p build/clang10-debug-mpi - cd build/clang10-debug-mpi - CXX=clang++-10 CC=clang-10 cmake ../.. -DCMAKE_BUILD_TYPE=Debug - - make run_unit_tests + - make + - make check cache: key: "${CI_COMMIT_REF_SLUG}-clang10-debug-mpi" paths: diff --git a/.gitlab-ci/clang10-mpi-release.yml b/.gitlab-ci/clang10-mpi-release.yml index cb1ffbda2ddc3bbed30f03d2caa4f0c8e45629c6..58a2d2d2e3878ffd8672eb71e15b34a5d58062a6 100644 --- a/.gitlab-ci/clang10-mpi-release.yml +++ b/.gitlab-ci/clang10-mpi-release.yml @@ -1,27 +1,13 @@ -build:clang10-mpi-release: - image: localhost:5000/ubuntu_clang10_mpi - stage: build - needs: [] - script: - - mkdir -p build/clang10-release-mpi - - cd build/clang10-release-mpi - - CXX=clang++-10 CC=clang-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release -DCLANG_FORMAT=/usr/bin/clang-format-10 - - make pugs - cache: - key: "${CI_COMMIT_REF_SLUG}-clang10-release-mpi" - paths: - - build/clang10-release-mpi - untracked: true - test:clang10-mpi-release: image: localhost:5000/ubuntu_clang10_mpi stage: test - needs: ["build:clang10-mpi-release"] + needs: [] script: - mkdir -p build/clang10-release-mpi - cd build/clang10-release-mpi - CXX=clang++-10 CC=clang-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release -DCLANG_FORMAT=/usr/bin/clang-format-10 - - make run_unit_tests + - make + - make check cache: key: "${CI_COMMIT_REF_SLUG}-clang10-release-mpi" paths: diff --git a/.gitlab-ci/gcc10-mpi-coverage.yml b/.gitlab-ci/gcc10-mpi-coverage.yml index 156592e9f4c080b545915a5e9a7aecef96ce3d2e..d2cc109d71cdada957ae5ce65c72ede6ce3b21ea 100644 --- a/.gitlab-ci/gcc10-mpi-coverage.yml +++ b/.gitlab-ci/gcc10-mpi-coverage.yml @@ -1,12 +1,12 @@ coverage:gcc10-mpi-coverage: image: localhost:5000/ubuntu_gcc10_mpi stage: coverage + needs: [] script: - mkdir -p build/gcc10-cov-mpi - cd build/gcc10-cov-mpi - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Coverage - - make pugs - - make coverage + - make cache: key: "${CI_COMMIT_REF_SLUG}-gcc10-cov-mpi" paths: diff --git a/.gitlab-ci/gcc10-mpi-debug.yml b/.gitlab-ci/gcc10-mpi-debug.yml deleted file mode 100644 index 6dd2806678651be82e4e2defdbb0c57ea2597f79..0000000000000000000000000000000000000000 --- a/.gitlab-ci/gcc10-mpi-debug.yml +++ /dev/null @@ -1,29 +0,0 @@ -build:gcc10-mpi-debug: - image: localhost:5000/ubuntu_gcc10_mpi - stage: build - needs: [] - script: - - mkdir -p build/gcc10-debug-mpi - - cd build/gcc10-debug-mpi - - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Debug - - make pugs - cache: - key: "${CI_COMMIT_REF_SLUG}-gcc10-debug-mpi" - paths: - - build/gcc10-debug-mpi - untracked: true - -test:gcc10-mpi-debug: - image: localhost:5000/ubuntu_gcc10_mpi - stage: test - needs: ["build:gcc10-mpi-debug"] - script: - - mkdir -p build/gcc10-debug-mpi - - cd build/gcc10-debug-mpi - - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Debug - - make run_unit_tests - cache: - key: "${CI_COMMIT_REF_SLUG}-gcc10-debug-mpi" - paths: - - build/gcc10-debug-mpi - untracked: true diff --git a/.gitlab-ci/gcc10-mpi-release.yml b/.gitlab-ci/gcc10-mpi-release.yml index 3f7d2a6a9aae7368216047a5fb691dc71df2aff9..5e8277ad61a089d8baa72e5cc7e8d2b7219ad775 100644 --- a/.gitlab-ci/gcc10-mpi-release.yml +++ b/.gitlab-ci/gcc10-mpi-release.yml @@ -1,27 +1,13 @@ -build:gcc10-mpi-release: - image: localhost:5000/ubuntu_gcc10_mpi - stage: build - needs: [] - script: - - mkdir -p build/gcc10-release-mpi - - cd build/gcc10-release-mpi - - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release - - make pugs - cache: - key: "${CI_COMMIT_REF_SLUG}-gcc10-release-mpi" - paths: - - build/gcc10-release-mpi - untracked: true - test:gcc10-mpi-release: image: localhost:5000/ubuntu_gcc10_mpi stage: test - needs: ["build:gcc10-mpi-release"] + needs: [] script: - mkdir -p build/gcc10-release-mpi - cd build/gcc10-release-mpi - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release - - make run_unit_tests + - make + - make test cache: key: "${CI_COMMIT_REF_SLUG}-gcc10-release-mpi" paths: diff --git a/.gitlab-ci/gcc10-seq-coverage.yml b/.gitlab-ci/gcc10-seq-coverage.yml index 663798274b0c465892c78084dde56022b151b1ba..5cacacd5d02a0838c36b852076c8cb6ad0573031 100644 --- a/.gitlab-ci/gcc10-seq-coverage.yml +++ b/.gitlab-ci/gcc10-seq-coverage.yml @@ -1,12 +1,12 @@ coverage:gcc10-seq-coverage: image: localhost:5000/ubuntu_gcc10 stage: coverage + needs: [] script: - mkdir -p build/gcc10-cov - cd build/gcc10-cov - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Coverage - - make pugs - - make coverage + - make cache: key: "${CI_COMMIT_REF_SLUG}-gcc10-cov" paths: diff --git a/.gitlab-ci/gcc10-seq-release.yml b/.gitlab-ci/gcc10-seq-release.yml index 525a08a80f295df5a8e766304279b79cd3fef4b0..b4423bbd31f947698a5fa4dc13171e0582fc5d32 100644 --- a/.gitlab-ci/gcc10-seq-release.yml +++ b/.gitlab-ci/gcc10-seq-release.yml @@ -1,27 +1,13 @@ -build:gcc10-seq-release: - image: localhost:5000/ubuntu_gcc10 - stage: build - needs: [] - script: - - mkdir -p build/gcc10-release-seq - - cd build/gcc10-release-seq - - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release - - make pugs - cache: - key: "${CI_COMMIT_REF_SLUG}-gcc10-release-seq" - paths: - - build/gcc10-release-seq - untracked: true - test:gcc10-seq-release: image: localhost:5000/ubuntu_gcc10 stage: test - needs: ["build:gcc10-seq-release"] + needs: [] script: - mkdir -p build/gcc10-release-seq - cd build/gcc10-release-seq - CXX=g++-10 CC=gcc-10 cmake ../.. -DCMAKE_BUILD_TYPE=Release - - make run_unit_tests + - make + - make check cache: key: "${CI_COMMIT_REF_SLUG}-gcc10-release-seq" paths: diff --git a/CMakeLists.txt b/CMakeLists.txt index 93e29a96ff0090554ba3021094a2b61e3a660dc8..696c96578a15c9303d9135ec8a56089062c2cc30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,9 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Forbids in-source builds include(CheckNotInSources) +# use PkgConfig to find packages +find_package(PkgConfig REQUIRED) + #------------------------------------------------------ #----------------- Main configuration ----------------- #------------------------------------------------------ @@ -78,6 +81,12 @@ if(NOT CMAKE_BUILD_TYPE) FORCE) endif() +#------------------------------------------------------ +# default build shared libraries +if (NOT BUILD_SHARED_LIBS) + set(BUILD_SHARED_LIBS ON CACHE STRING "" FORCE) +endif() + #------------------------------------------------------ # Checks if compiler version is compatible with Pugs sources @@ -113,13 +122,12 @@ if (NOT PUGS_ENABLE_MPI MATCHES "^(AUTO|ON|OFF)$") message(FATAL_ERROR "PUGS_ENABLE_MPI='${PUGS_ENABLE_MPI}'. Must be set to one of AUTO, ON or OFF") endif() -# checks for MPI +# Check for MPI if (PUGS_ENABLE_MPI MATCHES "^(AUTO|ON)$") set(MPI_DETERMINE_LIBRARY_VERSION TRUE) find_package(MPI) endif() - #------------------------------------------------------ # Search for ParMETIS @@ -135,6 +143,38 @@ if(${MPI_FOUND}) message(FATAL_ERROR "MPI support requires ParMETIS which cannot be found!") endif() endif() +else() + if(PUGS_ENABLE_MPI MATCHES "^ON$") + message(FATAL_ERROR "Cannot find MPI!") + endif() +endif() + +#------------------------------------------------------ +# Check for PETSc +# defaults use PETSc +set(PUGS_ENABLE_PETSC AUTO CACHE STRING + "Choose one of: AUTO ON OFF") + +if (PUGS_ENABLE_PETSC MATCHES "^(AUTO|ON)$") + if (MPI_FOUND) + # PETSc support is deactivated if MPI is not found + pkg_check_modules(PETSC PETSc) + else() + message(STATUS "PETSc support is deactivated since pugs will not be build with MPI support") + set(PETSC_FOUND FALSE) + unset(PUGS_HAS_PETSC) + endif() + set(PUGS_HAS_PETSC ${PETSC_FOUND}) +else() + unset(PUGS_HAS_PETSC) +endif() + +if (${PETSC_FOUND}) + include_directories(SYSTEM ${PETSC_INCLUDE_DIRS}) +else() + if (PUGS_ENABLE_PETSC MATCHES "^ON$") + message(FATAL_ERROR "Could not find PETSc!") + endif() endif() # ----------------------------------------------------- @@ -279,12 +319,50 @@ include("${CATCH_MODULE_PATH}/contrib/ParseAndAddCatchTests.cmake") add_subdirectory("${CATCH_MODULE_PATH}") add_subdirectory(tests) -enable_testing() -add_custom_target(run_unit_tests - COMMAND ${CMAKE_CTEST_COMMAND} +if(${PUGS_HAS_MPI}) + set(MPIEXEC_OPTION_FLAGS --oversubscribe) + if (NOT "$ENV{GITLAB_CI}" STREQUAL "") + set(MPIEXEC_OPTION_FLAGS ${MPIEXEC_OPTION_FLAGS} --allow-run-as-root) + endif() + set(MPIEXEC_NUMPROC 3) + set(MPIEXEC_PATH_FLAG --path) + set(MPI_LAUNCHER ${MPIEXEC} ${MPIEXEC_NUMPROC_FLAG} ${MPIEXEC_NUMPROC} ${MPIEXEC_PATH_FLAG} ${PUGS_BINARY_DIR} ${MPIEXEC_OPTION_FLAGS}) +endif() + +add_custom_target(all_unit_tests DEPENDS unit_tests mpi_unit_tests - COMMENT "Executing unit tests." +) + +add_custom_target(check + DEPENDS test + ) + +add_custom_target(test + DEPENDS run_all_unit_tests + ) + +add_custom_target(run_all_unit_tests + DEPENDS run_mpi_unit_tests + ) + +if(PUGS_HAS_MPI) + set(RUN_MPI_UNIT_TESTS_COMMENT "Running mpi_unit_tests [using ${MPIEXEC_NUMPROC} procs]") +else() + set(RUN_MPI_UNIT_TESTS_COMMENT "Running mpi_unit_tests [sequentially]") +endif() + +add_custom_target(run_mpi_unit_tests + COMMAND ${MPI_LAUNCHER} "${PUGS_BINARY_DIR}/mpi_unit_tests" + DEPENDS run_unit_tests + COMMENT ${RUN_MPI_UNIT_TESTS_COMMENT} + ) + + +add_custom_target(run_unit_tests + COMMAND "${PUGS_BINARY_DIR}/unit_tests" + DEPENDS all_unit_tests + COMMENT "Running unit_tests" ) # unit tests coverage @@ -330,40 +408,69 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Coverage") find_program(FASTCOV fastcov fastcov.py) + add_custom_target(coverage_unit_tests + ALL # in coverage mode we do coverage! + + COMMAND "${PUGS_BINARY_DIR}/unit_tests" + + COMMENT "Running unit_tests" + DEPENDS coverage_zero_counters + ) + + add_custom_target(coverage_mpi_unit_tests + ALL # in coverage mode we do coverage! + + COMMAND ${MPI_LAUNCHER} "mpi_unit_tests" + + COMMENT "Running mpi_unit_tests" + DEPENDS coverage_unit_tests + ) + + add_custom_target(coverage_run_all_unit_tests + ALL # in coverage mode we do coverage! + + DEPENDS coverage_mpi_unit_tests + ) + if (FASTCOV AND (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "9")) - add_custom_target(coverage + add_custom_target(coverage_zero_counters ALL # in coverage mode we do coverage! # zero all counters - COMMAND ${FASTCOV} -q -z + COMMAND ${FASTCOV} -q -z --gcov "${GCOV_BIN}" + DEPENDS all_unit_tests + ) - # Run tests - COMMAND ${CMAKE_CTEST_COMMAND} + add_custom_target(coverage + ALL # in coverage mode we do coverage! - COMMAND ${FASTCOV} -q --gcov "${GCOV_BIN}" + COMMAND ${FASTCOV} --gcov "${GCOV_BIN}" --include "${PUGS_SOURCE_DIR}/src" - --exclude "${PUGS_SOURCE_DIR}/src/main.cpp" "${PUGS_SOURCE_DIR}/src/utils/BacktraceManager.*" + --exclude "${PUGS_SOURCE_DIR}/src/main.cpp" "${PUGS_SOURCE_DIR}/src/utils/BacktraceManager.*" "${PUGS_SOURCE_DIR}/src/utils/FPEManager.*" "${PUGS_SOURCE_DIR}/src/utils/SignalManager.*" --lcov -o coverage.info -n COMMAND ${LCOV} --gcov "${GCOV_BIN}" --list coverage.info - DEPENDS unit_tests mpi_unit_tests + DEPENDS coverage_run_all_unit_tests COMMENT "Running test coverage." WORKING_DIRECTORY "${PUGS_BINARY_DIR}" ) else() - add_custom_target(coverage + add_custom_target(coverage_zero_counters ALL # in coverage mode we do coverage! # Cleanup previously generated profiling data COMMAND ${LCOV} -q --gcov "${GCOV_BIN}" --base-directory "${PUGS_BINARY_DIR}/src" --directory "${PUGS_BINARY_DIR}" --zerocounters # Initialize profiling data with zero coverage for every instrumented line of the project # This way the percentage of total lines covered will always be correct, even when not all source code files were loaded during the test(s) COMMAND ${LCOV} -q --gcov "${GCOV_BIN}" --base-directory "${PUGS_BINARY_DIR}/src" --directory "${PUGS_BINARY_DIR}" --capture --initial --output-file coverage_base.info - # Run tests - COMMAND ${CMAKE_CTEST_COMMAND} + DEPENDS all_unit_tests + ) + + + add_custom_target(coverage # Collect data from executions COMMAND ${LCOV} --gcov "${GCOV_BIN}" --base-directory "${PUGS_BINARY_DIR}/src" --directory "${PUGS_BINARY_DIR}" --capture --output-file coverage_ctest.info # Combine base and ctest results @@ -378,7 +485,7 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Coverage") COMMAND ${LCOV} --gcov "${GCOV_BIN}" --list coverage.info - DEPENDS unit_tests mpi_unit_tests + DEPENDS coverage_run_all_unit_tests COMMENT "Running test coverage." WORKING_DIRECTORY "${PUGS_BINARY_DIR}" ) @@ -414,15 +521,18 @@ add_executable( target_link_libraries( pugs PugsMesh + PugsAlgebra PugsUtils PugsLanguage PugsLanguageAST PugsLanguageModules PugsLanguageAlgorithms PugsMesh + PugsAlgebra PugsUtils PugsLanguageUtils kokkos + ${PETSC_LIBRARIES} ${PARMETIS_LIBRARIES} ${MPI_CXX_LINK_FLAGS} ${MPI_CXX_LIBRARIES} ${KOKKOS_CXX_FLAGS} @@ -450,37 +560,51 @@ endif() message("") message("====== pugs build options ======") -message(STATUS "version: ${PUGS_VERSION}") -message(STATUS "build type: ${CMAKE_BUILD_TYPE}") -message(STATUS "compiler: ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") -message(STATUS "kokkos devices: ${PUGS_BUILD_KOKKOS_DEVICES}") +message(" version: ${PUGS_VERSION}") +message(" build type: ${CMAKE_BUILD_TYPE}") +message(" compiler: ${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") +message(" kokkos devices: ${PUGS_BUILD_KOKKOS_DEVICES}") if (MPI_FOUND) - message(STATUS "MPI: ${MPI_CXX_LIBRARY_VERSION_STRING}") + message(" MPI: ${MPI_CXX_LIBRARY_VERSION_STRING}") else() if(NOT PARMETIS_LIBRARIES) - message(STATUS "MPI: deactivated: ParMETIS cannot be found!") + message(" MPI: deactivated: ParMETIS cannot be found!") + else() + if (PUGS_ENABLE_MPI MATCHES "^(AUTO|ON)$") + message(" MPI: not found!") + else() + message(" MPI: explicitly deactivated!") + endif() + endif() +endif() + +if (PETSC_FOUND) + message(" PETSc: ${PETSC_VERSION}") +else() + if (PUGS_ENABLE_PETSC MATCHES "^(AUTO|ON)$") + message(" PETSc: not found!") else() - message(STATUS "MPI: not found!") + message(" PETSc: explicitly deactivated!") endif() endif() if(CLANG_FORMAT) - message(STATUS "clang-format: ${CLANG_FORMAT}") + message(" clang-format: ${CLANG_FORMAT}") else() - message(STATUS "clang-format: not found!") + message(" clang-format: not found!") endif() if(CLAZY_STANDALONE) - message(STATUS "clazy-standalone: ${CLAZY_STANDALONE}") + message(" clazy-standalone: ${CLAZY_STANDALONE}") else() - message(STATUS "clazy-standalone: no found!") + message(" clazy-standalone: no found!") endif() if (DOXYGEN_FOUND) - message(STATUS "doxygen: ${DOXYGEN_EXECUTABLE}") + message(" doxygen: ${DOXYGEN_EXECUTABLE}") else() - message(STATUS "doxygen: no found!") + message(" doxygen: no found!") endif() message("================================") diff --git "a/packages/PEGTL/src/test/pegtl/file_\303\244\303\266\303\274\360\235\204\236_data.txt" "b/packages/PEGTL/src/test/pegtl/file_\303\244\303\266\303\274\360\235\204\236_data.txt" deleted file mode 100644 index d1c7bba09c907f77a6eb263e90b8a5d5b7873a7e..0000000000000000000000000000000000000000 --- "a/packages/PEGTL/src/test/pegtl/file_\303\244\303\266\303\274\360\235\204\236_data.txt" +++ /dev/null @@ -1,11 +0,0 @@ -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content -dummy content diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0d6e8cbb996d4f6c18c9098bb9af7051c11b0dbf..1b21f79a7545c442be5448792132acbdf840947a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,7 +10,7 @@ add_subdirectory(utils) add_subdirectory(language) # Pugs algebra -#add_subdirectory(algebra) +add_subdirectory(algebra) # Pugs mesh add_subdirectory(mesh) diff --git a/src/algebra/BiCGStab.hpp b/src/algebra/BiCGStab.hpp index bc822668a7f63c4d5d7cc4c45f4d1b1d70084fc0..ff34b89e148d844215cf7fffc21c2ad78bca5ec9 100644 --- a/src/algebra/BiCGStab.hpp +++ b/src/algebra/BiCGStab.hpp @@ -7,16 +7,20 @@ #include <rang.hpp> -template <bool verbose = true> struct BiCGStab { template <typename VectorType, typename MatrixType> - BiCGStab(const VectorType& b, const MatrixType& A, VectorType& x, const size_t max_iter, const double epsilon = 1e-6) + BiCGStab(const MatrixType& A, + VectorType& x, + const VectorType& b, + const double epsilon, + const size_t maximum_iteration, + const bool verbose) { - if constexpr (verbose) { + if (verbose) { std::cout << "- bi-conjugate gradient stabilized\n"; std::cout << " epsilon = " << epsilon << '\n'; - std::cout << " maximum number of iterations: " << max_iter << '\n'; + std::cout << " maximum number of iterations: " << maximum_iteration << '\n'; } VectorType r_k_1{b.size()}; @@ -38,13 +42,13 @@ struct BiCGStab VectorType r_k{x.size()}; - if constexpr (verbose) { + if (verbose) { std::cout << " initial residu: " << resid0 << '\n'; } - for (size_t i = 1; i <= max_iter; ++i) { - if constexpr (verbose) { - std::cout << " - iteration: " << std::setw(6) << i << "\tresidu: " << residu / resid0 - << "\tabsolute: " << residu << '\n'; + for (size_t i = 1; i <= maximum_iteration; ++i) { + if (verbose) { + std::cout << " - iteration: " << std::setw(6) << i << " residu: " << std::scientific << residu / resid0 + << " absolute: " << std::scientific << residu << '\n'; } Ap_k = A * p_k; @@ -77,8 +81,8 @@ struct BiCGStab << '\n'; ; std::cout << " - epsilon: " << epsilon << '\n'; - std::cout << " - relative residu : " << residu / resid0 << '\n'; - std::cout << " - absolute residu : " << residu << '\n'; + std::cout << " - relative residu : " << std::scientific << residu / resid0 << '\n'; + std::cout << " - absolute residu : " << std::scientific << residu << '\n'; } } } diff --git a/src/algebra/PCG.hpp b/src/algebra/CG.hpp similarity index 60% rename from src/algebra/PCG.hpp rename to src/algebra/CG.hpp index 36a4635b313d2bba13d97e1b7a4ea1882e2a66a2..538c77571dac4d551d7476ea79bf855e8026fb53 100644 --- a/src/algebra/PCG.hpp +++ b/src/algebra/CG.hpp @@ -6,30 +6,30 @@ #include <rang.hpp> -template <bool verbose = true> -struct PCG +struct CG { - template <typename VectorType, typename MatrixType, typename PreconditionerType> - PCG(const VectorType& f, - const MatrixType& A, - [[maybe_unused]] const PreconditionerType& C, - VectorType& x, - const size_t maxiter, - const double epsilon = 1e-6) + template <typename VectorType, typename MatrixType> + CG(const MatrixType& A, + VectorType& x, + const VectorType& f, + const double epsilon, + const size_t maximum_iteration, + const bool verbose = false) { - if constexpr (verbose) { + if (verbose) { std::cout << "- conjugate gradient\n"; std::cout << " epsilon = " << epsilon << '\n'; - std::cout << " maximum number of iterations: " << maxiter << '\n'; + std::cout << " maximum number of iterations: " << maximum_iteration << '\n'; } VectorType h{f.size()}; VectorType b = copy(f); - if constexpr (verbose) { + if (verbose) { h = A * x; h -= f; - std::cout << "- initial *real* residu : " << (h, h) << '\n'; + std::cout << "- initial " << rang::style::bold << "real" << rang::style::reset << " residu : " << (h, h) + << '\n'; } VectorType g{b.size()}; @@ -40,7 +40,7 @@ struct PCG double relativeEpsilon = epsilon; - for (size_t i = 1; i <= maxiter; ++i) { + for (size_t i = 1; i <= maximum_iteration; ++i) { if (i == 1) { h = A * x; @@ -74,13 +74,13 @@ struct PCG if ((i == 1) && (gcg != 0)) { relativeEpsilon = epsilon * gcg; gcg0 = gcg; - if constexpr (verbose) { + if (verbose) { std::cout << " initial residu: " << gcg << '\n'; } } - if constexpr (verbose) { - std::cout << " - iteration " << std::setw(6) << i << "\tresidu: " << gcg / gcg0; - std::cout << "\tabsolute: " << gcg << '\n'; + if (verbose) { + std::cout << " - iteration " << std::setw(6) << i << std::scientific << " residu: " << gcg / gcg0; + std::cout << " absolute: " << std::scientific << gcg << '\n'; } if (gcg < relativeEpsilon) { @@ -96,8 +96,8 @@ struct PCG if (gcg > relativeEpsilon) { std::cout << " conjugate gradient: " << rang::fgB::red << "*NOT CONVERGED*" << rang::style::reset << '\n'; std::cout << " - epsilon: " << epsilon << '\n'; - std::cout << " - relative residu : " << gcg / gcg0 << '\n'; - std::cout << " - absolute residu : " << gcg << '\n'; + std::cout << " - relative residu : " << std::scientific << gcg / gcg0 << '\n'; + std::cout << " - absolute residu : " << std::scientific << gcg << '\n'; } } }; diff --git a/src/algebra/CMakeLists.txt b/src/algebra/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d1f01ec310858be0e775514f934ce0a7c0e9da28 --- /dev/null +++ b/src/algebra/CMakeLists.txt @@ -0,0 +1,7 @@ +# ------------------- Source files -------------------- + +add_library( + PugsAlgebra + LinearSolver.cpp + LinearSolverOptions.cpp + PETScWrapper.cpp) diff --git a/src/algebra/CRSMatrix.hpp b/src/algebra/CRSMatrix.hpp index 97a0a8a8ed44d660f9524eac269abee0324cd136..98297d759675fdb7fdbf8be1e638f19cba4cc562 100644 --- a/src/algebra/CRSMatrix.hpp +++ b/src/algebra/CRSMatrix.hpp @@ -17,7 +17,7 @@ class CRSMatrix using MutableDataType = std::remove_const_t<DataType>; private: - using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>; + using HostMatrix = Kokkos::StaticCrsGraph<const IndexType, Kokkos::HostSpace>; HostMatrix m_host_matrix; Array<const DataType> m_values; @@ -30,8 +30,27 @@ class CRSMatrix return m_host_matrix.numRows(); } + auto + values() const + { + return m_values; + } + + auto + rowIndices() const + { + return encapsulate(m_host_matrix.row_map); + } + + auto + row(size_t i) const + { + return m_host_matrix.rowConst(i); + } + template <typename DataType2> - Vector<MutableDataType> operator*(const Vector<DataType2>& x) const + Vector<MutableDataType> + operator*(const Vector<DataType2>& x) const { static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), "Cannot multiply matrix and vector of different type"); @@ -57,8 +76,17 @@ class CRSMatrix CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M) { - m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector()); - m_values = M.valueArray(); + { + auto host_matrix = + Kokkos::create_staticcrsgraph<Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>>("connectivity_matrix", + M.graphVector()); + + // This is a bit crappy but it is the price to pay to avoid + m_host_matrix.entries = host_matrix.entries; + m_host_matrix.row_map = host_matrix.row_map; + m_host_matrix.row_block_offsets = host_matrix.row_block_offsets; + } + m_values = M.valueArray(); } ~CRSMatrix() = default; }; diff --git a/src/algebra/LinearSolver.cpp b/src/algebra/LinearSolver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b3a4d9816f6bc6634f73a80b70ba304f82ed60b1 --- /dev/null +++ b/src/algebra/LinearSolver.cpp @@ -0,0 +1,356 @@ +#include <algebra/LinearSolver.hpp> +#include <utils/pugs_config.hpp> + +#include <algebra/BiCGStab.hpp> +#include <algebra/CG.hpp> + +#ifdef PUGS_HAS_PETSC +#include <petsc.h> +#endif // PUGS_HAS_PETSC + +struct LinearSolver::Internals +{ + static bool + hasLibrary(const LSLibrary library) + { + switch (library) { + case LSLibrary::builtin: { + return true; + } + case LSLibrary::petsc: { +#ifdef PUGS_HAS_PETSC + return true; +#else + return false; +#endif + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("Linear system library (" + ::name(library) + ") was not set!"); + } + // LCOV_EXCL_STOP + } + } + + static void + checkHasLibrary(const LSLibrary library) + { + if (not hasLibrary(library)) { + // LCOV_EXCL_START + throw NormalError(::name(library) + " is not linked to pugs. Cannot use it!"); + // LCOV_EXCL_STOP + } + } + + static void + checkBuiltinMethod(const LSMethod method) + { + switch (method) { + case LSMethod::cg: + case LSMethod::bicgstab: { + break; + } + default: { + throw NormalError(name(method) + " is not a builtin linear solver!"); + } + } + } + + static void + checkPETScMethod(const LSMethod method) + { + switch (method) { + case LSMethod::cg: + case LSMethod::bicgstab: + case LSMethod::bicgstab2: + case LSMethod::gmres: + case LSMethod::lu: + case LSMethod::choleski: { + break; + } + // LCOV_EXCL_START + default: { + throw NormalError(name(method) + " is not a builtin linear solver!"); + } + // LCOV_EXCL_STOP + } + } + + static void + checkBuiltinPrecond(const LSPrecond precond) + { + switch (precond) { + case LSPrecond::none: { + break; + } + default: { + throw NormalError(name(precond) + " is not a builtin preconditioner!"); + } + } + } + + static void + checkPETScPrecond(const LSPrecond precond) + { + switch (precond) { + case LSPrecond::none: + case LSPrecond::amg: + case LSPrecond::diagonal: + case LSPrecond::incomplete_choleski: + case LSPrecond::incomplete_LU: { + break; + } + // LCOV_EXCL_START + default: { + throw NormalError(name(precond) + " is not a PETSc preconditioner!"); + } + // LCOV_EXCL_STOP + } + } + + static void + checkOptions(const LinearSolverOptions& options) + { + switch (options.library()) { + case LSLibrary::builtin: { + checkBuiltinMethod(options.method()); + checkBuiltinPrecond(options.precond()); + break; + } + case LSLibrary::petsc: { + checkPETScMethod(options.method()); + checkPETScPrecond(options.precond()); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("undefined options compatibility for this library (" + ::name(options.library()) + ")!"); + } + // LCOV_EXCL_STOP + } + } + + static void + builtinSolveLocalSystem(const CRSMatrix<double, size_t>& A, + Vector<double>& x, + const Vector<double>& b, + const LinearSolverOptions& options) + { + if (options.precond() != LSPrecond::none) { + // LCOV_EXCL_START + throw UnexpectedError("builtin linear solver do not allow any preconditioner!"); + // LCOV_EXCL_STOP + } + switch (options.method()) { + case LSMethod::cg: { + CG{A, x, b, options.epsilon(), options.maximumIteration(), options.verbose()}; + break; + } + case LSMethod::bicgstab: { + BiCGStab{A, x, b, options.epsilon(), options.maximumIteration(), options.verbose()}; + break; + } + // LCOV_EXCL_START + default: { + throw NotImplementedError("undefined builtin method: " + name(options.method())); + } + // LCOV_EXCL_STOP + } + } + +#ifdef PUGS_HAS_PETSC + static int + petscMonitor(KSP, int i, double residu, void*) + { + std::cout << " - iteration: " << std::setw(6) << i << " residu: " << std::scientific << residu << '\n'; + return 0; + } + + static void + petscSolveLocalSystem(const CRSMatrix<double, size_t>& A, + Vector<double>& x, + const Vector<double>& b, + const LinearSolverOptions& options) + { + Assert(x.size() == b.size() and x.size() == A.numberOfRows()); + + Vec petscB; + VecCreateMPIWithArray(PETSC_COMM_WORLD, 1, b.size(), b.size(), &b[0], &petscB); + Vec petscX; + VecCreateMPIWithArray(PETSC_COMM_WORLD, 1, x.size(), x.size(), &x[0], &petscX); + + Array<PetscScalar> values = copy(A.values()); + + const auto A_row_indices = A.rowIndices(); + Array<PetscInt> row_indices{A_row_indices.size()}; + for (size_t i = 0; i < row_indices.size(); ++i) { + row_indices[i] = A_row_indices[i]; + } + + Array<PetscInt> column_indices{values.size()}; + size_t l = 0; + for (size_t i = 0; i < A.numberOfRows(); ++i) { + const auto row_i = A.row(i); + for (size_t j = 0; j < row_i.length; ++j) { + column_indices[l++] = row_i.colidx(j); + } + } + + Mat petscMat; + MatCreateSeqAIJWithArrays(PETSC_COMM_WORLD, x.size(), x.size(), &row_indices[0], &column_indices[0], &values[0], + &petscMat); + + MatAssemblyBegin(petscMat, MAT_FINAL_ASSEMBLY); + MatAssemblyEnd(petscMat, MAT_FINAL_ASSEMBLY); + + KSP ksp; + KSPCreate(PETSC_COMM_WORLD, &ksp); + KSPSetTolerances(ksp, options.epsilon(), 1E-100, 1E5, options.maximumIteration()); + + KSPSetOperators(ksp, petscMat, petscMat); + + PC pc; + KSPGetPC(ksp, &pc); + + bool direct_solver = false; + + switch (options.method()) { + case LSMethod::bicgstab: { + KSPSetType(ksp, KSPBCGS); + break; + } + case LSMethod::bicgstab2: { + KSPSetType(ksp, KSPBCGSL); + KSPBCGSLSetEll(ksp, 2); + break; + } + case LSMethod::cg: { + KSPSetType(ksp, KSPCG); + break; + } + case LSMethod::gmres: { + KSPSetType(ksp, KSPGMRES); + + break; + } + case LSMethod::lu: { + KSPSetType(ksp, KSPPREONLY); + PCSetType(pc, PCLU); + PCFactorSetShiftType(pc, MAT_SHIFT_NONZERO); + direct_solver = true; + break; + } + case LSMethod::choleski: { + KSPSetType(ksp, KSPPREONLY); + PCSetType(pc, PCCHOLESKY); + direct_solver = true; + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("unexpected method: " + name(options.method())); + } + // LCOV_EXCL_STOP + } + + if (not direct_solver) { + switch (options.precond()) { + case LSPrecond::amg: { + PCSetType(pc, PCGAMG); + break; + } + case LSPrecond::diagonal: { + PCSetType(pc, PCJACOBI); + break; + } + case LSPrecond::incomplete_LU: { + PCSetType(pc, PCILU); + break; + } + case LSPrecond::incomplete_choleski: { + PCSetType(pc, PCICC); + break; + } + case LSPrecond::none: { + PCSetType(pc, PCNONE); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("unexpected preconditioner: " + name(options.precond())); + } + // LCOV_EXCL_STOP + } + } + if (options.verbose()) { + KSPMonitorSet(ksp, petscMonitor, 0, 0); + } + + KSPSolve(ksp, petscB, petscX); + + // free used memory + MatDestroy(&petscMat); + VecDestroy(&petscB); + VecDestroy(&petscX); + KSPDestroy(&ksp); + } + +#else // PUGS_HAS_PETSC + + // LCOV_EXCL_START + static void + petscSolveLocalSystem(const CRSMatrix<double, size_t>&, + Vector<double>&, + const Vector<double>&, + const LinearSolverOptions&) + { + checkHasLibrary(LSLibrary::petsc); + throw UnexpectedError("unexpected situation should not reach this point!"); + } + // LCOV_EXCL_STOP + +#endif // PUGS_HAS_PETSC +}; + +bool +LinearSolver::hasLibrary(LSLibrary library) const +{ + return Internals::hasLibrary(library); +} + +void +LinearSolver::checkOptions(const LinearSolverOptions& options) const +{ + Internals::checkOptions(options); +} + +void +LinearSolver::solveLocalSystem(const CRSMatrix<double, size_t>& A, Vector<double>& x, const Vector<double>& b) +{ + switch (m_options.library()) { + case LSLibrary::builtin: { + Internals::builtinSolveLocalSystem(A, x, b, m_options); + break; + } + // LCOV_EXCL_START + case LSLibrary::petsc: { + // not covered since if PETSc is not linked this point is + // unreachable: LinearSolver throws an exception at construction + // in this case. + Internals::petscSolveLocalSystem(A, x, b, m_options); + break; + } + default: { + throw UnexpectedError(::name(m_options.library()) + " cannot solve local systems for sparse matrices"); + } + // LCOV_EXCL_STOP + } +} + +LinearSolver::LinearSolver(const LinearSolverOptions& options) : m_options{options} +{ + Internals::checkHasLibrary(m_options.library()); + Internals::checkOptions(options); +} + +LinearSolver::LinearSolver() : LinearSolver{LinearSolverOptions::default_options} {} diff --git a/src/algebra/LinearSolver.hpp b/src/algebra/LinearSolver.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ed5660bd9ce51c1e5ba5de3fe00c595aa96f1e7 --- /dev/null +++ b/src/algebra/LinearSolver.hpp @@ -0,0 +1,37 @@ +#ifndef LINEAR_SOLVER_HPP +#define LINEAR_SOLVER_HPP + +#include <algebra/CRSMatrix.hpp> +#include <algebra/LinearSolverOptions.hpp> +#include <algebra/TinyMatrix.hpp> +#include <algebra/TinyVector.hpp> +#include <algebra/Vector.hpp> +#include <utils/Exceptions.hpp> + +class LinearSolver +{ + private: + struct Internals; + + const LinearSolverOptions m_options; + + void _solveLocalDense(size_t N, const double* A, double* x, const double* b); + + public: + bool hasLibrary(LSLibrary library) const; + void checkOptions(const LinearSolverOptions& options) const; + + void solveLocalSystem(const CRSMatrix<double, size_t>& A, Vector<double>& x, const Vector<double>& b); + + template <size_t N> + void + solveLocalSystem(const TinyMatrix<N>& A, TinyVector<N>& x, const TinyVector<N>& b) + { + this->_solveLocalDense(N, &A(0, 0), &x[0], &b[0]); + } + + LinearSolver(); + LinearSolver(const LinearSolverOptions& options); +}; + +#endif // LINEAR_SOLVER_HPP diff --git a/src/algebra/LinearSolverOptions.cpp b/src/algebra/LinearSolverOptions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b40c5fa92f3af73015afa133eaae99b91feb6fda --- /dev/null +++ b/src/algebra/LinearSolverOptions.cpp @@ -0,0 +1,15 @@ +#include <algebra/LinearSolverOptions.hpp> + +#include <rang.hpp> + +std::ostream& +operator<<(std::ostream& os, const LinearSolverOptions& options) +{ + os << " library: " << rang::style::bold << name(options.library()) << rang::style::reset << '\n'; + os << " method : " << rang::style::bold << name(options.method()) << rang::style::reset << '\n'; + os << " precond: " << rang::style::bold << name(options.precond()) << rang::style::reset << '\n'; + os << " epsilon: " << rang::style::bold << options.epsilon() << rang::style::reset << '\n'; + os << " maxiter: " << rang::style::bold << options.maximumIteration() << rang::style::reset << '\n'; + os << " verbose: " << rang::style::bold << std::boolalpha << options.verbose() << rang::style::reset << '\n'; + return os; +} diff --git a/src/algebra/LinearSolverOptions.hpp b/src/algebra/LinearSolverOptions.hpp new file mode 100644 index 0000000000000000000000000000000000000000..014ee86d64bee4d2a128e820ab8876450a195369 --- /dev/null +++ b/src/algebra/LinearSolverOptions.hpp @@ -0,0 +1,237 @@ +#ifndef LINEAR_SOLVER_OPTIONS_HPP +#define LINEAR_SOLVER_OPTIONS_HPP + +#include <utils/Exceptions.hpp> + +#include <iostream> + +enum class LSLibrary : int8_t +{ + LS__begin = 0, + // + builtin = LS__begin, + petsc, + // + LS__end +}; + +enum class LSMethod : int8_t +{ + LS__begin = 0, + // + cg = LS__begin, + bicgstab, + bicgstab2, + gmres, + lu, + choleski, + // + LS__end +}; + +enum class LSPrecond : int8_t +{ + LS__begin = 0, + // + none = LS__begin, + diagonal, + incomplete_choleski, + incomplete_LU, + amg, + // + LS__end +}; + +inline std::string +name(const LSLibrary library) +{ + switch (library) { + case LSLibrary::builtin: { + return "builtin"; + } + case LSLibrary::petsc: { + return "PETSc"; + } + case LSLibrary::LS__end: { + } + } + throw UnexpectedError("Linear system library name is not defined!"); +} + +inline std::string +name(const LSMethod method) +{ + switch (method) { + case LSMethod::cg: { + return "CG"; + } + case LSMethod::bicgstab: { + return "BICGStab"; + } + case LSMethod::bicgstab2: { + return "BICGStab2"; + } + case LSMethod::gmres: { + return "GMRES"; + } + case LSMethod::lu: { + return "LU"; + } + case LSMethod::choleski: { + return "Choleski"; + } + case LSMethod::LS__end: { + } + } + throw UnexpectedError("Linear system method name is not defined!"); +} + +inline std::string +name(const LSPrecond precond) +{ + switch (precond) { + case LSPrecond::none: { + return "none"; + } + case LSPrecond::diagonal: { + return "diagonal"; + } + case LSPrecond::incomplete_choleski: { + return "ICholeski"; + } + case LSPrecond::incomplete_LU: { + return "ILU"; + } + case LSPrecond::amg: { + return "AMG"; + } + case LSPrecond::LS__end: { + } + } + throw UnexpectedError("Linear system preconditioner name is not defined!"); +} + +template <typename LSEnumType> +inline LSEnumType +getLSEnumFromName(const std::string& enum_name) +{ + using BaseT = std::underlying_type_t<LSEnumType>; + for (BaseT enum_value = static_cast<BaseT>(LSEnumType::LS__begin); + enum_value < static_cast<BaseT>(LSEnumType::LS__end); ++enum_value) { + if (name(LSEnumType{enum_value}) == enum_name) { + return LSEnumType{enum_value}; + } + } + throw NormalError(std::string{"could not find '"} + enum_name + "' associate type!"); +} + +template <typename LSEnumType> +inline void +printLSEnumListNames(std::ostream& os) +{ + using BaseT = std::underlying_type_t<LSEnumType>; + for (BaseT enum_value = static_cast<BaseT>(LSEnumType::LS__begin); + enum_value < static_cast<BaseT>(LSEnumType::LS__end); ++enum_value) { + os << " - " << name(LSEnumType{enum_value}) << '\n'; + } +} + +class LinearSolverOptions +{ + private: + LSLibrary m_library = LSLibrary::builtin; + LSMethod m_method = LSMethod::bicgstab; + LSPrecond m_precond = LSPrecond::none; + + double m_epsilon = 1E-6; + size_t m_maximum_iteration = 200; + + bool m_verbose = false; + + public: + static LinearSolverOptions default_options; + + friend std::ostream& operator<<(std::ostream& os, const LinearSolverOptions& options); + + LSLibrary& + library() + { + return m_library; + } + + LSLibrary + library() const + { + return m_library; + } + + LSMethod + method() const + { + return m_method; + } + + LSMethod& + method() + { + return m_method; + } + + LSPrecond + precond() const + { + return m_precond; + } + + LSPrecond& + precond() + { + return m_precond; + } + + double + epsilon() const + { + return m_epsilon; + } + + double& + epsilon() + { + return m_epsilon; + } + + size_t& + maximumIteration() + { + return m_maximum_iteration; + } + + size_t + maximumIteration() const + { + return m_maximum_iteration; + } + + bool& + verbose() + { + return m_verbose; + }; + + bool + verbose() const + { + return m_verbose; + }; + + LinearSolverOptions(const LinearSolverOptions&) = default; + LinearSolverOptions(LinearSolverOptions&&) = default; + + LinearSolverOptions() = default; + ~LinearSolverOptions() = default; +}; + +inline LinearSolverOptions LinearSolverOptions::default_options; + +#endif // LINEAR_SOLVER_OPTIONS_HPP diff --git a/src/algebra/PETScWrapper.cpp b/src/algebra/PETScWrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd11dc9772f11ad29ca4970c5b88656c623a3508 --- /dev/null +++ b/src/algebra/PETScWrapper.cpp @@ -0,0 +1,27 @@ +#include <algebra/PETScWrapper.hpp> + +#include <utils/pugs_config.hpp> + +#ifdef PUGS_HAS_PETSC +#include <petsc.h> +#endif // PUGS_HAS_PETSC + +namespace PETScWrapper +{ +void +initialize([[maybe_unused]] int& argc, [[maybe_unused]] char* argv[]) +{ +#ifdef PUGS_HAS_PETSC + PetscOptionsSetValue(NULL, "-no_signal_handler", "true"); + PetscInitialize(&argc, &argv, 0, 0); +#endif // PUGS_HAS_PETSC +} + +void +finalize() +{ +#ifdef PUGS_HAS_PETSC + PetscFinalize(); +#endif // PUGS_HAS_PETSC +} +} // namespace PETScWrapper diff --git a/src/algebra/PETScWrapper.hpp b/src/algebra/PETScWrapper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2f40b37f8c4562b7abeccffeb6e8cfb27688c7c4 --- /dev/null +++ b/src/algebra/PETScWrapper.hpp @@ -0,0 +1,10 @@ +#ifndef PETSC_WRAPPER_HPP +#define PETSC_WRAPPER_HPP + +namespace PETScWrapper +{ +void initialize(int& argc, char* argv[]); +void finalize(); +} // namespace PETScWrapper + +#endif // PETSC_WRAPPER_HPP diff --git a/src/algebra/SparseMatrixDescriptor.hpp b/src/algebra/SparseMatrixDescriptor.hpp index 100000366d3bd0e5a072084864e82cdb4b73ba29..f63a17576d619bf692f36a00e5c0c4e373effdbf 100644 --- a/src/algebra/SparseMatrixDescriptor.hpp +++ b/src/algebra/SparseMatrixDescriptor.hpp @@ -145,7 +145,12 @@ class SparseMatrixDescriptor return values; } - SparseMatrixDescriptor(size_t nb_row) : m_row_array{nb_row} {} + SparseMatrixDescriptor(size_t nb_row) : m_row_array{nb_row} + { + for (size_t i = 0; i < nb_row; ++i) { + m_row_array[i][i] = 0; + } + } ~SparseMatrixDescriptor() = default; }; diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp index c6cf4f8165e2f7a4eb812b85d0ee4097b197169c..701bbb2349361ae2d10532a43aa373f0fc98fb48 100644 --- a/src/algebra/TinyMatrix.hpp +++ b/src/algebra/TinyMatrix.hpp @@ -11,7 +11,7 @@ #include <iostream> template <size_t N, typename T = double> -class TinyMatrix +class [[nodiscard]] TinyMatrix { public: using data_type = T; @@ -21,15 +21,13 @@ class TinyMatrix static_assert((N > 0), "TinyMatrix size must be strictly positive"); PUGS_FORCEINLINE - constexpr size_t - _index(size_t i, size_t j) const noexcept // LCOV_EXCL_LINE (due to forced inline) + constexpr size_t _index(size_t i, size_t j) const noexcept // LCOV_EXCL_LINE (due to forced inline) { return i * N + j; } template <typename... Args> - PUGS_FORCEINLINE constexpr void - _unpackVariadicInput(const T& t, Args&&... args) noexcept + PUGS_FORCEINLINE constexpr void _unpackVariadicInput(const T& t, Args&&... args) noexcept { m_values[N * N - 1 - sizeof...(args)] = t; if constexpr (sizeof...(args) > 0) { @@ -39,8 +37,25 @@ class TinyMatrix public: PUGS_INLINE - constexpr TinyMatrix - operator-() const + constexpr size_t dimension() const + { + return N * N; + } + + PUGS_INLINE + constexpr size_t nbRows() const + { + return N; + } + + PUGS_INLINE + constexpr size_t nbColumns() const + { + return N; + } + + PUGS_INLINE + constexpr TinyMatrix operator-() const { TinyMatrix opposed; for (size_t i = 0; i < N * N; ++i) { @@ -63,8 +78,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix& - operator*=(const T& t) + constexpr TinyMatrix& operator*=(const T& t) { for (size_t i = 0; i < N * N; ++i) { m_values[i] *= t; @@ -105,28 +119,23 @@ class TinyMatrix } PUGS_INLINE - constexpr friend std::ostream& - operator<<(std::ostream& os, const TinyMatrix& A) + constexpr friend std::ostream& operator<<(std::ostream& os, const TinyMatrix& A) { - if constexpr (N == 1) { - os << A(0, 0); - } else { - os << '['; - for (size_t i = 0; i < N; ++i) { - os << '(' << A(i, 0); - for (size_t j = 1; j < N; ++j) { - os << ',' << A(i, j); - } - os << ')'; + os << '['; + for (size_t i = 0; i < N; ++i) { + os << '(' << A(i, 0); + for (size_t j = 1; j < N; ++j) { + os << ',' << A(i, j); } - os << ']'; + os << ')'; } + os << ']'; + return os; } PUGS_INLINE - constexpr bool - operator==(const TinyMatrix& A) const + constexpr bool operator==(const TinyMatrix& A) const { for (size_t i = 0; i < N * N; ++i) { if (m_values[i] != A.m_values[i]) @@ -136,15 +145,13 @@ class TinyMatrix } PUGS_INLINE - constexpr bool - operator!=(const TinyMatrix& A) const + constexpr bool operator!=(const TinyMatrix& A) const { return not this->operator==(A); } PUGS_INLINE - constexpr TinyMatrix - operator+(const TinyMatrix& A) const + constexpr TinyMatrix operator+(const TinyMatrix& A) const { TinyMatrix sum; for (size_t i = 0; i < N * N; ++i) { @@ -154,8 +161,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix - operator+(TinyMatrix&& A) const + constexpr TinyMatrix operator+(TinyMatrix&& A) const { for (size_t i = 0; i < N * N; ++i) { A.m_values[i] += m_values[i]; @@ -164,8 +170,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix - operator-(const TinyMatrix& A) const + constexpr TinyMatrix operator-(const TinyMatrix& A) const { TinyMatrix difference; for (size_t i = 0; i < N * N; ++i) { @@ -175,8 +180,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix - operator-(TinyMatrix&& A) const + constexpr TinyMatrix operator-(TinyMatrix&& A) const { for (size_t i = 0; i < N * N; ++i) { A.m_values[i] = m_values[i] - A.m_values[i]; @@ -185,8 +189,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix& - operator+=(const TinyMatrix& A) + constexpr TinyMatrix& operator+=(const TinyMatrix& A) { for (size_t i = 0; i < N * N; ++i) { m_values[i] += A.m_values[i]; @@ -195,8 +198,7 @@ class TinyMatrix } PUGS_INLINE - constexpr void - operator+=(const volatile TinyMatrix& A) volatile + constexpr void operator+=(const volatile TinyMatrix& A) volatile { for (size_t i = 0; i < N * N; ++i) { m_values[i] += A.m_values[i]; @@ -204,8 +206,7 @@ class TinyMatrix } PUGS_INLINE - constexpr TinyMatrix& - operator-=(const TinyMatrix& A) + constexpr TinyMatrix& operator-=(const TinyMatrix& A) { for (size_t i = 0; i < N * N; ++i) { m_values[i] -= A.m_values[i]; @@ -214,16 +215,14 @@ class TinyMatrix } PUGS_INLINE - constexpr T& - operator()(size_t i, size_t j) noexcept(NO_ASSERT) + constexpr T& operator()(size_t i, size_t j) noexcept(NO_ASSERT) { Assert((i < N) and (j < N)); return m_values[_index(i, j)]; } PUGS_INLINE - constexpr const T& - operator()(size_t i, size_t j) const noexcept(NO_ASSERT) + constexpr const T& operator()(size_t i, size_t j) const noexcept(NO_ASSERT) { Assert((i < N) and (j < N)); return m_values[_index(i, j)]; @@ -295,7 +294,7 @@ class TinyMatrix constexpr TinyMatrix(const TinyMatrix&) noexcept = default; PUGS_INLINE - TinyMatrix(TinyMatrix&& A) noexcept = default; + TinyMatrix(TinyMatrix && A) noexcept = default; PUGS_INLINE ~TinyMatrix() = default; @@ -342,14 +341,19 @@ det(const TinyMatrix<N, T>& A) determinent *= -1; } } - const size_t I = index[i]; - const T inv_Mii = 1. / M(I, i); - for (size_t k = i + 1; k < N; ++k) { - const size_t K = index[k]; - const T factor = M(K, i) * inv_Mii; - for (size_t l = i + 1; l < N; ++l) { - M(K, l) -= factor * M(I, l); + const size_t I = index[i]; + const T Mii = M(I, i); + if (Mii != 0) { + const T inv_Mii = 1. / M(I, i); + for (size_t k = i + 1; k < N; ++k) { + const size_t K = index[k]; + const T factor = M(K, i) * inv_Mii; + for (size_t l = i + 1; l < N; ++l) { + M(K, l) -= factor * M(I, l); + } } + } else { + return 0; } } diff --git a/src/algebra/TinyVector.hpp b/src/algebra/TinyVector.hpp index 07690a826ef13040e42baca75daac425fa199333..3f3c10920f5ae75c6b37a0761393cad19a77786b 100644 --- a/src/algebra/TinyVector.hpp +++ b/src/algebra/TinyVector.hpp @@ -11,7 +11,7 @@ #include <cmath> template <size_t N, typename T = double> -class TinyVector +class [[nodiscard]] TinyVector { public: inline static constexpr size_t Dimension = N; @@ -22,8 +22,7 @@ class TinyVector static_assert((N > 0), "TinyVector size must be strictly positive"); template <typename... Args> - PUGS_FORCEINLINE constexpr void - _unpackVariadicInput(const T& t, Args&&... args) noexcept + PUGS_FORCEINLINE constexpr void _unpackVariadicInput(const T& t, Args&&... args) noexcept { m_values[N - 1 - sizeof...(args)] = t; if constexpr (sizeof...(args) > 0) { @@ -33,8 +32,7 @@ class TinyVector public: PUGS_INLINE - constexpr TinyVector - operator-() const + constexpr TinyVector operator-() const { TinyVector opposed; for (size_t i = 0; i < N; ++i) { @@ -44,15 +42,13 @@ class TinyVector } PUGS_INLINE - constexpr size_t - dimension() const + constexpr size_t dimension() const { return N; } PUGS_INLINE - constexpr bool - operator==(const TinyVector& v) const + constexpr bool operator==(const TinyVector& v) const { for (size_t i = 0; i < N; ++i) { if (m_values[i] != v.m_values[i]) @@ -62,8 +58,7 @@ class TinyVector } PUGS_INLINE - constexpr bool - operator!=(const TinyVector& v) const + constexpr bool operator!=(const TinyVector& v) const { return not this->operator==(v); } @@ -79,8 +74,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector& - operator*=(const T& t) + constexpr TinyVector& operator*=(const T& t) { for (size_t i = 0; i < N; ++i) { m_values[i] *= t; @@ -103,8 +97,7 @@ class TinyVector } PUGS_INLINE - constexpr friend std::ostream& - operator<<(std::ostream& os, const TinyVector& v) + constexpr friend std::ostream& operator<<(std::ostream& os, const TinyVector& v) { os << '(' << v.m_values[0]; for (size_t i = 1; i < N; ++i) { @@ -115,8 +108,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector - operator+(const TinyVector& v) const + constexpr TinyVector operator+(const TinyVector& v) const { TinyVector sum; for (size_t i = 0; i < N; ++i) { @@ -126,8 +118,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector - operator+(TinyVector&& v) const + constexpr TinyVector operator+(TinyVector&& v) const { for (size_t i = 0; i < N; ++i) { v.m_values[i] += m_values[i]; @@ -136,8 +127,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector - operator-(const TinyVector& v) const + constexpr TinyVector operator-(const TinyVector& v) const { TinyVector difference; for (size_t i = 0; i < N; ++i) { @@ -147,8 +137,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector - operator-(TinyVector&& v) const + constexpr TinyVector operator-(TinyVector&& v) const { for (size_t i = 0; i < N; ++i) { v.m_values[i] = m_values[i] - v.m_values[i]; @@ -157,8 +146,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector& - operator+=(const TinyVector& v) + constexpr TinyVector& operator+=(const TinyVector& v) { for (size_t i = 0; i < N; ++i) { m_values[i] += v.m_values[i]; @@ -167,8 +155,7 @@ class TinyVector } PUGS_INLINE - constexpr void - operator+=(const volatile TinyVector& v) volatile + constexpr void operator+=(const volatile TinyVector& v) volatile { for (size_t i = 0; i < N; ++i) { m_values[i] += v.m_values[i]; @@ -176,8 +163,7 @@ class TinyVector } PUGS_INLINE - constexpr TinyVector& - operator-=(const TinyVector& v) + constexpr TinyVector& operator-=(const TinyVector& v) { for (size_t i = 0; i < N; ++i) { m_values[i] -= v.m_values[i]; @@ -241,7 +227,7 @@ class TinyVector constexpr TinyVector(const TinyVector&) noexcept = default; PUGS_INLINE - constexpr TinyVector(TinyVector&& v) noexcept = default; + constexpr TinyVector(TinyVector && v) noexcept = default; PUGS_INLINE ~TinyVector() noexcept = default; diff --git a/src/language/PEGGrammar.hpp b/src/language/PEGGrammar.hpp index 61ba912a567c4a6d868d22cc7915e7557f484518..beae437f0e5d0f9231b51afc0671e29e7ecfef30 100644 --- a/src/language/PEGGrammar.hpp +++ b/src/language/PEGGrammar.hpp @@ -55,11 +55,13 @@ struct character : if_must_else< one< '\\' >, escaped_c, ascii::any> {}; struct open_parent : seq< one< '(' >, ignored > {}; struct close_parent : seq< one< ')' >, ignored > {}; -struct literal : if_must< one< '"' >, until< one< '"' >, character > > {}; +struct literal : star< minus<character, one < '"' > > >{}; + +struct quoted_literal : if_must< one< '"' >, seq< literal, one< '"' > > >{}; struct import_kw : TAO_PEGTL_KEYWORD("import") {}; -struct LITERAL : seq< literal, ignored >{}; +struct LITERAL : seq< quoted_literal, ignored >{}; struct REAL : seq< real, ignored >{}; @@ -73,11 +75,12 @@ struct string_type : TAO_PEGTL_KEYWORD("string") {}; struct scalar_type : sor< B_set, R_set, Z_set, N_set >{}; struct vector_type : seq< R_set, ignored, one< '^' >, ignored, integer >{}; +struct matrix_type : seq< R_set, ignored, one< '^' >, ignored, integer, ignored, one< 'x' >, ignored, integer >{}; struct basic_type : sor< scalar_type, string_type >{}; struct type_name_id; -struct simple_type_specifier : sor< vector_type, basic_type, type_name_id >{}; +struct simple_type_specifier : sor< matrix_type, vector_type, basic_type, type_name_id >{}; struct tuple_type_specifier : sor<try_catch< open_parent, simple_type_specifier, ignored, close_parent >, // non matching braces management @@ -178,7 +181,7 @@ struct postfix_operator : seq< sor< post_plusplus, post_minusminus>, ignored > { struct open_bracket : seq< one< '[' >, ignored > {}; struct close_bracket : seq< one< ']' >, ignored > {}; -struct subscript_expression : if_must< open_bracket, expression, close_bracket >{}; +struct subscript_expression : if_must< open_bracket, list_must<expression, COMMA>, close_bracket >{}; struct postfix_expression : seq< primary_expression, star< sor< subscript_expression , postfix_operator> > >{}; @@ -229,7 +232,10 @@ struct expression : logical_or {}; struct tuple_expression : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{}; -struct expression_list : seq< open_parent, sor< tuple_expression, expression >, plus< if_must< COMMA, sor< tuple_expression, expression > > >, close_parent >{}; +struct expression_list : seq< open_parent, sor< seq< tuple_expression, + star< if_must< COMMA, sor< tuple_expression, expression > > > >, + seq< expression, + plus< if_must< COMMA, sor< tuple_expression, expression > > > > >, close_parent >{}; struct affect_op : sor< eq_op, multiplyeq_op, divideeq_op, pluseq_op, minuseq_op > {}; diff --git a/src/language/PugsParser.cpp b/src/language/PugsParser.cpp index 75842ad029795a579a213b2d50bed15bcfb68b30..faab47780f0736c43285b3245a8ee07cc966f036 100644 --- a/src/language/PugsParser.cpp +++ b/src/language/PugsParser.cpp @@ -14,7 +14,9 @@ #include <language/ast/ASTSymbolInitializationChecker.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> #include <language/utils/ASTDotPrinter.hpp> +#include <language/utils/ASTExecutionInfo.hpp> #include <language/utils/ASTPrinter.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/SymbolTable.hpp> #include <utils/PugsAssert.hpp> #include <utils/PugsUtils.hpp> @@ -37,15 +39,21 @@ parser(const std::string& filename) { const size_t grammar_issues = analyze<language::grammar>(); - std::cout << rang::fgB::yellow << "grammar_issues=" << rang::fg::reset << grammar_issues << '\n'; + if (grammar_issues != 0) { + std::ostringstream os; + os << "invalid grammar: " << rang::fgB::yellow << grammar_issues << rang::fg::reset << " were detected!"; + throw UnexpectedError(os.str()); + } std::cout << rang::style::bold << "Parsing file " << rang::style::reset << rang::style::underline << filename << rang::style::reset << " ...\n"; auto parse_and_execute = [](auto& input) { + OperatorRepository::create(); + std::unique_ptr<ASTNode> root_node = ASTBuilder::build(input); - ASTModulesImporter{*root_node}; + ASTModulesImporter module_importer{*root_node}; ASTNodeTypeCleaner<language::import_instruction>{*root_node}; ASTSymbolTableBuilder{*root_node}; @@ -64,25 +72,20 @@ parser(const std::string& filename) ASTNodeTypeCleaner<language::var_declaration>{*root_node}; ASTNodeTypeCleaner<language::fct_declaration>{*root_node}; - { - std::string dot_filename{"parse_tree.dot"}; - std::ofstream fout(dot_filename); - ASTDotPrinter dot_printer{*root_node}; - fout << dot_printer; - std::cout << " AST dot file: " << dot_filename << '\n'; - } - ASTNodeEmptyBlockCleaner{*root_node}; ASTNodeExpressionBuilder{*root_node}; + std::cout << "-------------------------------------------------------\n"; + std::cout << rang::style::bold << "Executing AST..." << rang::style::reset << '\n'; - std::cout << ASTPrinter{*root_node} << '\n'; + ASTExecutionInfo execution_info{*root_node, module_importer.moduleRepository()}; ExecutionPolicy exec_all; root_node->execute(exec_all); - std::cout << *(root_node->m_symbol_table) << '\n'; root_node->m_symbol_table->clearValues(); + + OperatorRepository::destroy(); }; if (not SignalManager::pauseOnError()) { @@ -97,7 +100,7 @@ parser(const std::string& filename) << rang::fgB::red << "error: " << rang::fg::reset << rang::style::bold << e.what() << rang::style::reset << '\n' << input.line_at(p) << '\n' - << std::string(p.column, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; + << std::string(p.column - 1, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; finalize(); std::exit(1); } @@ -105,5 +108,4 @@ parser(const std::string& filename) read_input input(filename); parse_and_execute(input); } - std::cout << "Executed successfuly: " << filename << '\n'; } diff --git a/src/language/ast/ASTBuilder.cpp b/src/language/ast/ASTBuilder.cpp index cef779ffce63b9c2ad808ca94b993e7eff59d4ff..ccf0f42db5c4144ef91475d3fa9c3c24520d1e35 100644 --- a/src/language/ast/ASTBuilder.cpp +++ b/src/language/ast/ASTBuilder.cpp @@ -108,21 +108,19 @@ struct ASTBuilder::simplify_unary : parse_tree::apply<ASTBuilder::simplify_unary } if (n->is_type<language::unary_expression>() or n->is_type<language::name_subscript_expression>()) { - const size_t child_nb = n->children.size(); - if (child_nb > 1) { + if (n->children.size() > 1) { if (n->children[1]->is_type<language::subscript_expression>()) { - auto expression = std::move(n->children[0]); - for (size_t i = 0; i < child_nb - 1; ++i) { - n->children[i] = std::move(n->children[i + 1]); - } + std::swap(n->children[0], n->children[1]); - auto& array_subscript_expression = n->children[0]; + n->children[0]->emplace_back(std::move(n->children[1])); n->children.pop_back(); - array_subscript_expression->children.emplace_back(std::move(expression)); - std::swap(array_subscript_expression->children[0], array_subscript_expression->children[1]); - - array_subscript_expression->m_begin = array_subscript_expression->children[0]->m_begin; + auto& array_subscript_expression = n->children[0]; + const size_t child_nb = array_subscript_expression->children.size(); + for (size_t i = 1; i < array_subscript_expression->children.size(); ++i) { + std::swap(array_subscript_expression->children[child_nb - i], + array_subscript_expression->children[child_nb - i - 1]); + } transform(n, st...); } @@ -138,7 +136,7 @@ struct ASTBuilder::simplify_node_list : parse_tree::apply<ASTBuilder::simplify_n transform(std::unique_ptr<ASTNode>& n, States&&... st) { if (n->is_type<language::name_list>() or n->is_type<language::lvalue_list>() or - n->is_type<language::function_argument_list>() or n->is_type<language::expression_list>()) { + n->is_type<language::function_argument_list>()) { if (n->children.size() == 1) { n = std::move(n->children.back()); transform(n, st...); @@ -247,6 +245,7 @@ using selector = parse_tree::selector< language::type_name_id, language::tuple_expression, language::vector_type, + language::matrix_type, language::string_type, language::cout_kw, language::cerr_kw, @@ -255,6 +254,7 @@ using selector = parse_tree::selector< language::fct_declaration, language::type_mapping, language::function_definition, + language::expression_list, language::if_statement, language::do_while_statement, language::while_statement, @@ -303,8 +303,7 @@ using selector = parse_tree::selector< language::post_plusplus>, ASTBuilder::simplify_for_statement_block::on<language::for_statement_block>, parse_tree::discard_empty::on<language::ignored, language::semicol, language::block>, - ASTBuilder::simplify_node_list:: - on<language::name_list, language::lvalue_list, language::function_argument_list, language::expression_list>, + ASTBuilder::simplify_node_list::on<language::name_list, language::lvalue_list, language::function_argument_list>, ASTBuilder::simplify_statement_block::on<language::statement_block>, ASTBuilder::simplify_for_init::on<language::for_init>, ASTBuilder::simplify_for_test::on<language::for_test>, diff --git a/src/language/ast/ASTModulesImporter.cpp b/src/language/ast/ASTModulesImporter.cpp index 197f9dcdd33a7e268a550cea12464567ff0f93b5..d52b964a8a670ca3af3d0bcd0f3223798e2502c3 100644 --- a/src/language/ast/ASTModulesImporter.cpp +++ b/src/language/ast/ASTModulesImporter.cpp @@ -37,6 +37,8 @@ ASTModulesImporter::_importAllModules(ASTNode& node) ASTModulesImporter::ASTModulesImporter(ASTNode& root_node) : m_symbol_table{*root_node.m_symbol_table} { Assert(root_node.is_root()); + m_module_repository.populateMandatorySymbolTable(root_node, m_symbol_table); + this->_importAllModules(root_node); std::cout << " - loaded modules\n"; } diff --git a/src/language/ast/ASTModulesImporter.hpp b/src/language/ast/ASTModulesImporter.hpp index 2b8c081fdee5ef48dd9594a79ceeca13e52fdedc..566155c4a19ecd00f0b049a29030c0036f212d8a 100644 --- a/src/language/ast/ASTModulesImporter.hpp +++ b/src/language/ast/ASTModulesImporter.hpp @@ -20,6 +20,12 @@ class ASTModulesImporter void _importAllModules(ASTNode& node); public: + const ModuleRepository& + moduleRepository() const + { + return m_module_repository; + } + ASTModulesImporter(ASTNode& root_node); ASTModulesImporter(const ASTModulesImporter&) = delete; diff --git a/src/language/ast/ASTNode.hpp b/src/language/ast/ASTNode.hpp index b3dd41ca65faae6324d040aa7f4b66c078f45f6c..d9568c612dd3e3a6223994cff6273b43d900b3c0 100644 --- a/src/language/ast/ASTNode.hpp +++ b/src/language/ast/ASTNode.hpp @@ -1,9 +1,9 @@ #ifndef AST_NODE_HPP #define AST_NODE_HPP -#include <language/ast/ASTNodeDataType.hpp> #include <language/node_processor/ExecutionPolicy.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/DataVariant.hpp> #include <utils/PugsAssert.hpp> #include <utils/PugsMacros.hpp> @@ -45,7 +45,7 @@ class ASTNode : public parse_tree::basic_node<ASTNode> std::shared_ptr<SymbolTable> m_symbol_table; std::unique_ptr<INodeProcessor> m_node_processor; - ASTNodeDataType m_data_type{ASTNodeDataType::undefined_t}; + ASTNodeDataType m_data_type; [[nodiscard]] PUGS_INLINE std::string string() const diff --git a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp index e27bd531d743b410ecb381e665bb4af145e4731f..884711c081ded349d32663f80180c06c4911bfba 100644 --- a/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeAffectationExpressionBuilder.cpp @@ -2,434 +2,44 @@ #include <algebra/TinyVector.hpp> #include <language/PEGGrammar.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> -#include <language/node_processor/AffectationProcessor.hpp> +#include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/AffectationMangler.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> - #include <utils/Exceptions.hpp> ASTNodeAffectationExpressionBuilder::ASTNodeAffectationExpressionBuilder(ASTNode& n) { - auto set_affectation_processor = [](ASTNode& n, const auto& operator_v) { - auto set_affectation_processor_for_data = [&](const auto& value, const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - using ValueT = std::decay_t<decltype(value)>; - - switch (data_type) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, double>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: undefined operand type for affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - }; - - auto set_affectation_processor_for_vector_data = [&](const auto& value, const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - using ValueT = std::decay_t<decltype(value)>; - - if constexpr (std::is_same_v<OperatorT, language::eq_op>) { - switch (data_type) { - case ASTNodeDataType::vector_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, ValueT>>(n); - break; - } - case ASTNodeDataType::list_t: { - n.m_node_processor = std::make_unique<AffectationToTinyVectorFromListProcessor<OperatorT, ValueT>>(n); - break; - } - case ASTNodeDataType::bool_t: { - if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, bool>>(n); - break; - } - } - case ASTNodeDataType::unsigned_int_t: { - if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, uint64_t>>(n); - break; - } - } - case ASTNodeDataType::int_t: { - if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { - if (n.children[1]->is_type<language::integer>()) { - if (std::stoi(n.children[1]->string()) == 0) { - n.m_node_processor = std::make_unique<AffectationFromZeroProcessor<ValueT>>(n); - break; - } - } - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, int64_t>>(n); - break; - } else if (n.children[1]->is_type<language::integer>()) { - if (std::stoi(n.children[1]->string()) == 0) { - n.m_node_processor = std::make_unique<AffectationFromZeroProcessor<ValueT>>(n); - break; - } - } - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid integral value", std::vector{n.children[1]->begin()}); - // LCOV_EXCL_STOP - } - case ASTNodeDataType::double_t: { - if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, double>>(n); - break; - } - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid operand type", std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - } else if constexpr (std::is_same_v<OperatorT, language::pluseq_op> or - std::is_same_v<OperatorT, language::minuseq_op>) { - switch (data_type) { - case ASTNodeDataType::vector_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, ValueT>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid operand type", std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - } else if constexpr (std::is_same_v<OperatorT, language::multiplyeq_op>) { - switch (data_type) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, ValueT, double>>(n); - break; - } - default: { - throw ParseError("expecting scalar operand type", std::vector{n.children[1]->begin()}); - } - } - } else { - throw ParseError("invalid affectation operator for " + dataTypeName(n.m_data_type), std::vector{n.begin()}); - } - }; - - auto set_affectation_processor_for_string_data = [&](const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - - if constexpr (std::is_same_v<OperatorT, language::eq_op> or std::is_same_v<OperatorT, language::pluseq_op>) { - switch (data_type) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, double>>(n); - break; - } - case ASTNodeDataType::string_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, std::string>>(n); - break; - } - case ASTNodeDataType::vector_t: { - switch (data_type.dimension()) { - case 1: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, TinyVector<1>>>(n); - break; - } - case 2: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, TinyVector<2>>>(n); - break; - } - case 3: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, std::string, TinyVector<3>>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension for string affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: undefined operand type for string affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - } else { - throw ParseError("invalid affectation operator for string", std::vector{n.begin()}); - } - }; - - auto set_affectation_processor_for_embedded_data = [&](const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - - if constexpr (std::is_same_v<OperatorT, language::eq_op>) { - switch (data_type) { - case ASTNodeDataType::type_id_t: { - n.m_node_processor = std::make_unique<AffectationProcessor<OperatorT, EmbeddedData, EmbeddedData>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: undefined operand type for embedded data affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - } else { - throw ParseError("invalid affectation operator for '" + dataTypeName(n.children[0]->m_data_type) + "'", - std::vector{n.begin()}); - } - }; - - auto set_affectation_processor_for_tuple_data = [&](const ASTNodeDataType& content_data_type, - const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - if constexpr (std::is_same_v<OperatorT, language::eq_op>) { - if ((data_type == ASTNodeDataType::list_t) or (data_type == ASTNodeDataType::tuple_t)) { - switch (content_data_type) { - case ASTNodeDataType::type_id_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, EmbeddedData>>(n); - break; - } - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, double>>(n); - break; - } - case ASTNodeDataType::string_t: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, std::string>>(n); - break; - } - case ASTNodeDataType::vector_t: { - switch (content_data_type.dimension()) { - case 1: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, TinyVector<1>>>(n); - break; - } - case 2: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, TinyVector<2>>>(n); - break; - } - case 3: { - n.m_node_processor = std::make_unique<AffectationToTupleFromListProcessor<OperatorT, TinyVector<3>>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension for tuple affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - // LCOV_EXCL_START - default: { - throw UnexpectedError("invalid tuple content " + dataTypeName(content_data_type)); - } - // LCOV_EXCL_STOP - } - } else { - switch (content_data_type) { - case ASTNodeDataType::type_id_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, EmbeddedData>>(n); - break; - } - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, double>>(n); - break; - } - case ASTNodeDataType::string_t: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, std::string>>(n); - break; - } - case ASTNodeDataType::vector_t: { - switch (content_data_type.dimension()) { - case 1: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, TinyVector<1>>>(n); - break; - } - case 2: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, TinyVector<2>>>(n); - break; - } - case 3: { - n.m_node_processor = std::make_unique<AffectationToTupleProcessor<OperatorT, TinyVector<3>>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension for tuple affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: undefined operand type for tuple affectation", - std::vector{n.children[1]->begin()}); - } - // LCOV_EXCL_STOP - } - } - } else { - throw ParseError("invalid affectation operator for '" + dataTypeName(n.children[0]->m_data_type) + "'", - std::vector{n.begin()}); - } - }; - - auto set_affectation_processor_for_value = [&](const ASTNodeDataType& value_type) { - const ASTNodeDataType data_type = n.children[1]->m_data_type; - - switch (value_type) { - case ASTNodeDataType::bool_t: { - set_affectation_processor_for_data(bool{}, data_type); - break; - } - case ASTNodeDataType::unsigned_int_t: { - set_affectation_processor_for_data(uint64_t{}, data_type); - break; - } - case ASTNodeDataType::int_t: { - set_affectation_processor_for_data(int64_t{}, data_type); - break; - } - case ASTNodeDataType::double_t: { - set_affectation_processor_for_data(double{}, data_type); - break; - } - case ASTNodeDataType::vector_t: { - switch (value_type.dimension()) { - case 1: { - set_affectation_processor_for_vector_data(TinyVector<1>{}, data_type); - break; - } - case 2: { - set_affectation_processor_for_vector_data(TinyVector<2>{}, data_type); - break; - } - case 3: { - set_affectation_processor_for_vector_data(TinyVector<3>{}, data_type); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: unexpected vector dimension", std::vector{n.begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - case ASTNodeDataType::string_t: { - set_affectation_processor_for_string_data(data_type); - break; - } - case ASTNodeDataType::type_id_t: { - set_affectation_processor_for_embedded_data(data_type); - break; - } - case ASTNodeDataType::tuple_t: { - const ASTNodeDataType& content_type = value_type.contentType(); - set_affectation_processor_for_tuple_data(content_type, data_type); - break; - } - default: { - throw ParseError("unexpected error: undefined value type for affectation", std::vector{n.children[0]->begin()}); - } - } - }; - - using OperatorT = std::decay_t<decltype(operator_v)>; - // Special treatment dedicated to R^1 to be able to initialize them - if (((n.m_data_type != n.children[1]->m_data_type) and (n.m_data_type == ASTNodeDataType::vector_t) and - (n.m_data_type.dimension() == 1)) or - // Special treatment for R^d vectors and operator *= - ((n.m_data_type == ASTNodeDataType::vector_t) and (n.children[1]->m_data_type != ASTNodeDataType::vector_t) and - std::is_same_v<OperatorT, language::multiplyeq_op>)) { - ASTNodeNaturalConversionChecker{*n.children[1], ASTNodeDataType::double_t}; + const ASTNodeDataType& target_data_type = n.children[0]->m_data_type; + const ASTNodeDataType& source_data_type = n.children[1]->m_data_type; + + const std::string affectation_name = [&] { + if (n.is_type<language::eq_op>()) { + return affectationMangler<language::eq_op>(target_data_type, source_data_type); + } else if (n.is_type<language::multiplyeq_op>()) { + return affectationMangler<language::multiplyeq_op>(target_data_type, source_data_type); + } else if (n.is_type<language::divideeq_op>()) { + return affectationMangler<language::divideeq_op>(target_data_type, source_data_type); + } else if (n.is_type<language::pluseq_op>()) { + return affectationMangler<language::pluseq_op>(target_data_type, source_data_type); + } else if (n.is_type<language::minuseq_op>()) { + return affectationMangler<language::minuseq_op>(target_data_type, source_data_type); } else { - ASTNodeNaturalConversionChecker{*n.children[1], n.m_data_type}; + throw ParseError("unexpected error: undefined affectation operator", std::vector{n.begin()}); } + }(); - set_affectation_processor_for_value(n.m_data_type); - }; + const auto& optional_processor_builder = + OperatorRepository::instance().getAffectationProcessorBuilder(affectation_name); - if (n.is_type<language::eq_op>()) { - set_affectation_processor(n, language::eq_op{}); - } else if (n.is_type<language::multiplyeq_op>()) { - set_affectation_processor(n, language::multiplyeq_op{}); - } else if (n.is_type<language::divideeq_op>()) { - set_affectation_processor(n, language::divideeq_op{}); - } else if (n.is_type<language::pluseq_op>()) { - set_affectation_processor(n, language::pluseq_op{}); - } else if (n.is_type<language::minuseq_op>()) { - set_affectation_processor(n, language::minuseq_op{}); + if (optional_processor_builder.has_value()) { + n.m_node_processor = optional_processor_builder.value()->getNodeProcessor(n); } else { - throw ParseError("unexpected error: undefined affectation operator", std::vector{n.begin()}); + std::ostringstream error_message; + error_message << "undefined affectation type: "; + error_message << rang::fgB::red << affectation_name << rang::fg::reset; + + throw ParseError(error_message.str(), std::vector{n.children[0]->begin()}); } } diff --git a/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp b/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp index 350a4070e9a6bc65906627d54b2e2ed87c403cc6..e717044d1dea5fe6b85aec6c41941699c1e5ef65 100644 --- a/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeArraySubscriptExpressionBuilder.cpp @@ -1,5 +1,6 @@ #include <language/ast/ASTNodeArraySubscriptExpressionBuilder.hpp> +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> #include <language/node_processor/ArraySubscriptProcessor.hpp> #include <language/utils/ParseError.hpp> @@ -27,6 +28,27 @@ ASTNodeArraySubscriptExpressionBuilder::ASTNodeArraySubscriptExpressionBuilder(A break; } } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<1>>>(node); + break; + } + case 2: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<2>>>(node); + break; + } + case 3: { + node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyMatrix<3>>>(node); + break; + } + default: { + throw ParseError("unexpected error: invalid array dimension", array_expression.begin()); + break; + } + } } else { throw ParseError("unexpected error: invalid array type", array_expression.begin()); } diff --git a/src/language/ast/ASTNodeBinaryOperatorExpressionBuilder.cpp b/src/language/ast/ASTNodeBinaryOperatorExpressionBuilder.cpp index 64ff14deb191921c81e28ef1b3f36a0651ffc0e2..b6904220a9c42f2fcbfd49e8ce5881fbd9b7cf3a 100644 --- a/src/language/ast/ASTNodeBinaryOperatorExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBinaryOperatorExpressionBuilder.cpp @@ -3,205 +3,59 @@ #include <language/PEGGrammar.hpp> #include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/node_processor/ConcatExpressionProcessor.hpp> +#include <language/utils/BinaryOperatorMangler.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> ASTNodeBinaryOperatorExpressionBuilder::ASTNodeBinaryOperatorExpressionBuilder(ASTNode& n) { - auto set_binary_operator_processor = [](ASTNode& n, const auto& operator_v) { - auto set_binary_operator_processor_for_data_b = [&](const auto data_a, const ASTNodeDataType& data_type_b) { - using OperatorT = std::decay_t<decltype(operator_v)>; - using DataTA = std::decay_t<decltype(data_a)>; - - if constexpr (std::is_same_v<DataTA, std::string>) { - if constexpr (std::is_same_v<OperatorT, language::plus_op>) { - switch (data_type_b) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<ConcatExpressionProcessor<bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<ConcatExpressionProcessor<uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<ConcatExpressionProcessor<int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<ConcatExpressionProcessor<double>>(n); - break; - } - case ASTNodeDataType::string_t: { - n.m_node_processor = std::make_unique<ConcatExpressionProcessor<std::string>>(n); - break; - } - default: { - throw ParseError("undefined operand type for binary operator", std::vector{n.children[1]->begin()}); - } - } - - } else if constexpr ((std::is_same_v<OperatorT, language::eqeq_op>) or - (std::is_same_v<OperatorT, language::not_eq_op>)) { - if (data_type_b == ASTNodeDataType::string_t) { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, std::string>>(n); - } else { - throw ParseError("undefined operand type for binary operator", std::vector{n.begin()}); - } - } else { - throw ParseError("undefined operand type for binary operator", std::vector{n.begin()}); - } - } else if constexpr (std::is_same_v<DataTA, TinyVector<1>> or std::is_same_v<DataTA, TinyVector<2>> or - std::is_same_v<DataTA, TinyVector<3>>) { - if ((data_type_b == ASTNodeDataType::vector_t)) { - if constexpr (std::is_same_v<OperatorT, language::plus_op> or std::is_same_v<OperatorT, language::minus_op> or - std::is_same_v<OperatorT, language::eqeq_op> or - std::is_same_v<OperatorT, language::not_eq_op>) { - if (data_a.dimension() == data_type_b.dimension()) { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, DataTA>>(n); - } else { - throw ParseError("incompatible dimensions of operands", std::vector{n.begin()}); - } - } else { - throw ParseError("invalid binary operator", std::vector{n.begin()}); - } - } else { - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid operand type for binary operator", - std::vector{n.children[1]->begin()}); - // LCOV_EXCL_STOP - } - } else { - switch (data_type_b) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, double>>(n); - break; - } - case ASTNodeDataType::vector_t: { - if constexpr (std::is_same_v<OperatorT, language::multiply_op>) { - switch (data_type_b.dimension()) { - case 1: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, TinyVector<1>>>(n); - break; - } - case 2: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, TinyVector<2>>>(n); - break; - } - case 3: { - n.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, DataTA, TinyVector<3>>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid dimension", std::vector{n.children[0]->begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - } - default: { - throw ParseError("undefined operand type for binary operator", std::vector{n.children[1]->begin()}); - } - } - } - }; - - auto set_binary_operator_processor_for_data_a = [&](const ASTNodeDataType& data_type_a) { - const ASTNodeDataType data_type_b = n.children[1]->m_data_type; - switch (data_type_a) { - case ASTNodeDataType::bool_t: { - set_binary_operator_processor_for_data_b(bool{}, data_type_b); - break; - } - case ASTNodeDataType::unsigned_int_t: { - set_binary_operator_processor_for_data_b(uint64_t{}, data_type_b); - break; - } - case ASTNodeDataType::int_t: { - set_binary_operator_processor_for_data_b(int64_t{}, data_type_b); - break; - } - case ASTNodeDataType::double_t: { - set_binary_operator_processor_for_data_b(double{}, data_type_b); - break; - } - case ASTNodeDataType::string_t: { - set_binary_operator_processor_for_data_b(std::string{}, data_type_b); - break; - } - case ASTNodeDataType::vector_t: { - switch (data_type_a.dimension()) { - case 1: { - set_binary_operator_processor_for_data_b(TinyVector<1>{}, data_type_b); - break; - } - case 2: { - set_binary_operator_processor_for_data_b(TinyVector<2>{}, data_type_b); - break; - } - case 3: { - set_binary_operator_processor_for_data_b(TinyVector<3>{}, data_type_b); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid dimension", std::vector{n.children[0]->begin()}); - } - // LCOV_EXCL_STOP - } - break; - } - default: { - throw ParseError("undefined operand type for binary operator", std::vector{n.children[0]->begin()}); - } - } - }; - - set_binary_operator_processor_for_data_a(n.children[0]->m_data_type); - }; - - if (n.is_type<language::multiply_op>()) { - set_binary_operator_processor(n, language::multiply_op{}); - } else if (n.is_type<language::divide_op>()) { - set_binary_operator_processor(n, language::divide_op{}); - } else if (n.is_type<language::plus_op>()) { - set_binary_operator_processor(n, language::plus_op{}); - } else if (n.is_type<language::minus_op>()) { - set_binary_operator_processor(n, language::minus_op{}); - - } else if (n.is_type<language::or_op>()) { - set_binary_operator_processor(n, language::or_op{}); - } else if (n.is_type<language::and_op>()) { - set_binary_operator_processor(n, language::and_op{}); - } else if (n.is_type<language::xor_op>()) { - set_binary_operator_processor(n, language::xor_op{}); - - } else if (n.is_type<language::greater_op>()) { - set_binary_operator_processor(n, language::greater_op{}); - } else if (n.is_type<language::greater_or_eq_op>()) { - set_binary_operator_processor(n, language::greater_or_eq_op{}); - } else if (n.is_type<language::lesser_op>()) { - set_binary_operator_processor(n, language::lesser_op{}); - } else if (n.is_type<language::lesser_or_eq_op>()) { - set_binary_operator_processor(n, language::lesser_or_eq_op{}); - } else if (n.is_type<language::eqeq_op>()) { - set_binary_operator_processor(n, language::eqeq_op{}); - } else if (n.is_type<language::not_eq_op>()) { - set_binary_operator_processor(n, language::not_eq_op{}); + const ASTNodeDataType& lhs_data_type = n.children[0]->m_data_type; + const ASTNodeDataType& rhs_data_type = n.children[1]->m_data_type; + + const std::string binary_operator_name = [&]() -> std::string { + if (n.is_type<language::multiply_op>()) { + return binaryOperatorMangler<language::multiply_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::divide_op>()) { + return binaryOperatorMangler<language::divide_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::plus_op>()) { + return binaryOperatorMangler<language::plus_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::minus_op>()) { + return binaryOperatorMangler<language::minus_op>(lhs_data_type, rhs_data_type); + + } else if (n.is_type<language::or_op>()) { + return binaryOperatorMangler<language::or_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::and_op>()) { + return binaryOperatorMangler<language::and_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::xor_op>()) { + return binaryOperatorMangler<language::xor_op>(lhs_data_type, rhs_data_type); + + } else if (n.is_type<language::greater_op>()) { + return binaryOperatorMangler<language::greater_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::greater_or_eq_op>()) { + return binaryOperatorMangler<language::greater_or_eq_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::lesser_op>()) { + return binaryOperatorMangler<language::lesser_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::lesser_or_eq_op>()) { + return binaryOperatorMangler<language::lesser_or_eq_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::eqeq_op>()) { + return binaryOperatorMangler<language::eqeq_op>(lhs_data_type, rhs_data_type); + } else if (n.is_type<language::not_eq_op>()) { + return binaryOperatorMangler<language::not_eq_op>(lhs_data_type, rhs_data_type); + } else { + throw ParseError("unexpected error: undefined binary operator", std::vector{n.begin()}); + } + }(); + + const auto& optional_processor_builder = + OperatorRepository::instance().getBinaryProcessorBuilder(binary_operator_name); + + if (optional_processor_builder.has_value()) { + n.m_node_processor = optional_processor_builder.value()->getNodeProcessor(n); } else { - throw ParseError("unexpected error: undefined binary operator", std::vector{n.begin()}); + std::ostringstream error_message; + error_message << "undefined binary operator type: "; + error_message << rang::fgB::red << binary_operator_name << rang::fg::reset; + + throw ParseError(error_message.str(), std::vector{n.children[0]->begin()}); } } diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 6dc726f555219b855e2a125b0ab22b586dec2dad..248511e6c91c2eec08a94e87faa8179c06a0bf80 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -2,8 +2,8 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/BuiltinFunctionProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> #include <language/utils/SymbolTable.hpp> @@ -112,6 +112,84 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } }; + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + + if constexpr (std::is_same_v<ParameterT, TinyMatrix<1>>) { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == 1) and + (argument_node_sub_data_type.m_data_type.nbColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::bool_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, bool>>(argument_number); + } + case ASTNodeDataType::int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, int64_t>>(argument_number); + } + case ASTNodeDataType::unsigned_int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, uint64_t>>(argument_number); + } + case ASTNodeDataType::double_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, double>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } else { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (argument_node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (argument_node_sub_data_type.m_parent_node.children.size() == + (parameter_v.nbRows() * parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (argument_node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(argument_node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(argument_number); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } + }; + auto get_function_argument_to_string_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { return std::make_unique<FunctionArgumentToStringConverter>(argument_number); }; @@ -195,6 +273,25 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } // LCOV_EXCL_STOP } + } + case ASTNodeDataType::matrix_t: { + Assert(arg_data_type.nbRows() == arg_data_type.nbColumns()); + switch (arg_data_type.nbRows()) { + case 1: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<1>>>(argument_number); + } + case 2: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<2>>>(argument_number); + } + case 3: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<3>>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector"); + } + // LCOV_EXCL_STOP + } } // LCOV_EXCL_START default: { @@ -254,6 +351,26 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.nbRows() == parameter_type.nbColumns()); + switch (parameter_type.nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined parameter type for function", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_argument_to_string_converter(); } @@ -281,10 +398,37 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData return get_function_argument_to_tuple_converter(double{}); } case ASTNodeDataType::vector_t: { - switch (parameter_type.dimension()) { + switch (parameter_type.contentType().dimension()) { case 1: { return get_function_argument_to_tuple_converter(TinyVector<1>{}); } + case 2: { + return get_function_argument_to_tuple_converter(TinyVector<2>{}); + } + case 3: { + return get_function_argument_to_tuple_converter(TinyVector<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: unexpected tuple content for function: '" + dataTypeName(parameter_type) + + "'", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.contentType().nbRows() == parameter_type.contentType().nbColumns()); + switch (parameter_type.contentType().nbRows()) { + case 1: { + return get_function_argument_to_tuple_converter(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_to_tuple_converter(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_to_tuple_converter(TinyMatrix<3>{}); + } // LCOV_EXCL_START default: { throw ParseError("unexpected error: unexpected tuple content for function: '" + dataTypeName(parameter_type) + @@ -314,13 +458,7 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } }; - if (parameter_type == ASTNodeDataType::vector_t and parameter_type.dimension() == 1) { - if (not isNaturalConversion(argument_node_sub_data_type.m_data_type, parameter_type)) { - ASTNodeNaturalConversionChecker{argument_node_sub_data_type, ASTNodeDataType::double_t}; - } - } else { - ASTNodeNaturalConversionChecker{argument_node_sub_data_type, parameter_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{argument_node_sub_data_type, parameter_type}; return get_function_argument_converter_for_argument_type(); } diff --git a/src/language/ast/ASTNodeDataType.hpp b/src/language/ast/ASTNodeDataType.hpp deleted file mode 100644 index 863a87dad0a3df9d3659b9c661336fbf5e384aca..0000000000000000000000000000000000000000 --- a/src/language/ast/ASTNodeDataType.hpp +++ /dev/null @@ -1,106 +0,0 @@ -#ifndef AST_NODE_DATA_TYPE_HPP -#define AST_NODE_DATA_TYPE_HPP - -#include <utils/PugsAssert.hpp> - -#include <limits> -#include <memory> -#include <string> - -class ASTNode; - -class ASTNodeDataType -{ - public: - enum DataType : int32_t - { - undefined_t = -1, - bool_t = 0, - int_t = 1, - unsigned_int_t = 2, - double_t = 3, - vector_t = 4, - tuple_t = 5, - list_t = 6, - string_t = 7, - typename_t = 10, - type_name_id_t = 11, - type_id_t = 21, - function_t = 22, - builtin_function_t = 23, - void_t = std::numeric_limits<int32_t>::max() - }; - - private: - DataType m_data_type; - std::shared_ptr<ASTNodeDataType> m_content_type; - size_t m_dimension; - std::string m_name_of_type_id; - - public: - PUGS_INLINE - size_t - dimension() const - { - return m_dimension; - } - - PUGS_INLINE - const ASTNodeDataType& - contentType() const - { - Assert(m_content_type); - return *m_content_type; - } - - PUGS_INLINE - const std::string& - nameOfTypeId() const - { - return m_name_of_type_id; - } - - PUGS_INLINE - operator const DataType&() const - { - return m_data_type; - } - - ASTNodeDataType& operator=(const ASTNodeDataType&) = default; - ASTNodeDataType& operator=(ASTNodeDataType&&) = default; - - ASTNodeDataType(DataType data_type) - : m_data_type{data_type}, m_content_type{nullptr}, m_dimension{1}, m_name_of_type_id{"unknown"} - {} - - ASTNodeDataType(DataType data_type, const ASTNodeDataType& content_type) - : m_data_type{data_type}, - m_content_type{std::make_shared<ASTNodeDataType>(content_type)}, - m_dimension{1}, - m_name_of_type_id{"unknown"} - {} - - ASTNodeDataType(DataType data_type, size_t dimension) - : m_data_type{data_type}, m_content_type{nullptr}, m_dimension{dimension}, m_name_of_type_id{"unknown"} - {} - - ASTNodeDataType(DataType data_type, const std::string& type_name) - : m_data_type{data_type}, m_content_type{nullptr}, m_dimension{1}, m_name_of_type_id{type_name} - {} - - ASTNodeDataType(const ASTNodeDataType&) = default; - - ASTNodeDataType(ASTNodeDataType&&) = default; - - ~ASTNodeDataType() = default; -}; - -ASTNodeDataType getVectorDataType(const ASTNode& type_node); - -std::string dataTypeName(const ASTNodeDataType& data_type); - -ASTNodeDataType dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& data_type_2); - -bool isNaturalConversion(const ASTNodeDataType& data_type, const ASTNodeDataType& target_data_type); - -#endif // AST_NODE_DATA_TYPE_HPP diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp index 7da94cc5f7d00132b226f8ba0e7e5d95d4097b35..018bc94f74160bfcd680df08ac9b2a08c54ba492 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp @@ -1,43 +1,48 @@ #include <language/ast/ASTNodeDataTypeBuilder.hpp> #include <language/PEGGrammar.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> #include <language/utils/SymbolTable.hpp> #include <utils/PugsAssert.hpp> -ASTNodeDataType +void ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNode& name_node) const { - ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; + ASTNodeDataType data_type; if (type_node.is_type<language::type_expression>()) { if (type_node.children.size() != name_node.children.size()) { std::ostringstream message; message << "number of product spaces (" << type_node.children.size() << ") " << rang::fgB::yellow << type_node.string() << rang::style::reset << rang::style::bold << " differs from number of variables (" - << name_node.children.size() << ") " << rang::fgB::yellow << name_node.string() << rang::style::reset - << std::ends; + << name_node.children.size() << ") " << rang::fgB::yellow << name_node.string() << rang::style::reset; throw ParseError(message.str(), name_node.begin()); } + std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; + sub_data_type_list.reserve(type_node.children.size()); for (size_t i = 0; i < type_node.children.size(); ++i) { auto& sub_type_node = *type_node.children[i]; auto& sub_name_node = *name_node.children[i]; _buildDeclarationNodeDataTypes(sub_type_node, sub_name_node); + sub_data_type_list.push_back(std::make_shared<const ASTNodeDataType>(sub_type_node.m_data_type)); } - data_type = ASTNodeDataType::typename_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list); } else { if (type_node.is_type<language::B_set>()) { - data_type = ASTNodeDataType::bool_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); } else if (type_node.is_type<language::Z_set>()) { - data_type = ASTNodeDataType::int_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); } else if (type_node.is_type<language::N_set>()) { - data_type = ASTNodeDataType::unsigned_int_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); } else if (type_node.is_type<language::R_set>()) { - data_type = ASTNodeDataType::double_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (type_node.is_type<language::vector_type>()) { data_type = getVectorDataType(type_node); + } else if (type_node.is_type<language::matrix_type>()) { + data_type = getMatrixDataType(type_node); } else if (type_node.is_type<language::tuple_type_specifier>()) { const auto& content_node = type_node.children[0]; @@ -52,32 +57,34 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) { std::ostringstream os; os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '" - << dataTypeName(i_type_symbol->attributes().dataType()) << "'" << std::ends; + << dataTypeName(i_type_symbol->attributes().dataType()) << '\''; throw ParseError(os.str(), std::vector{content_node->begin()}); } - content_node->m_data_type = ASTNodeDataType{ASTNodeDataType::type_id_t, type_name_id}; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(type_name_id); } else if (content_node->is_type<language::B_set>()) { - content_node->m_data_type = ASTNodeDataType::bool_t; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); } else if (content_node->is_type<language::Z_set>()) { - content_node->m_data_type = ASTNodeDataType::int_t; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); } else if (content_node->is_type<language::N_set>()) { - content_node->m_data_type = ASTNodeDataType::unsigned_int_t; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); } else if (content_node->is_type<language::R_set>()) { - content_node->m_data_type = ASTNodeDataType::double_t; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (content_node->is_type<language::vector_type>()) { content_node->m_data_type = getVectorDataType(*type_node.children[0]); + } else if (content_node->is_type<language::matrix_type>()) { + content_node->m_data_type = getMatrixDataType(*type_node.children[0]); } else if (content_node->is_type<language::string_type>()) { - content_node->m_data_type = ASTNodeDataType::string_t; + content_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else { // LCOV_EXCL_START throw UnexpectedError("unexpected content type in tuple"); // LCOV_EXCL_STOP } - data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, content_node->m_data_type}; + data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(content_node->m_data_type); } else if (type_node.is_type<language::string_type>()) { - data_type = ASTNodeDataType::string_t; + data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else if (type_node.is_type<language::type_name_id>()) { const std::string& type_name_id = type_node.string(); @@ -89,11 +96,11 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) { std::ostringstream os; os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '" - << dataTypeName(i_type_symbol->attributes().dataType()) << "'" << std::ends; + << dataTypeName(i_type_symbol->attributes().dataType()) << '\''; throw ParseError(os.str(), std::vector{type_node.begin()}); } - data_type = ASTNodeDataType{ASTNodeDataType::type_id_t, type_name_id}; + data_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(type_name_id); } if (name_node.is_type<language::name_list>()) { @@ -113,7 +120,7 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo } Assert(data_type != ASTNodeDataType::undefined_t); - return data_type; + type_node.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(data_type); } void @@ -128,51 +135,53 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const const ASTNode& test_node = *n.children[1]; if (not n.children[1]->is_type<language::for_test>()) { - ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t}; + ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::build<ASTNodeDataType::bool_t>()}; } // in the case of empty for_test (not simplified node), nothing to check! } - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else { if (n.has_content()) { if (n.is_type<language::import_instruction>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::module_name>()) { - n.m_data_type = ASTNodeDataType::string_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else if (n.is_type<language::true_kw>() or n.is_type<language::false_kw>()) { - n.m_data_type = ASTNodeDataType::bool_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); } else if (n.is_type<language::real>()) { - n.m_data_type = ASTNodeDataType::double_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); } else if (n.is_type<language::integer>()) { - n.m_data_type = ASTNodeDataType::int_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); } else if (n.is_type<language::vector_type>()) { n.m_data_type = getVectorDataType(n); - - } else if (n.is_type<language::tuple_expression>()) { - n.m_data_type = ASTNodeDataType::list_t; + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = getMatrixDataType(n); } else if (n.is_type<language::literal>()) { - n.m_data_type = ASTNodeDataType::string_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::var_declaration>()) { auto& name_node = *(n.children[0]); auto& type_node = *(n.children[1]); - type_node.m_data_type = _buildDeclarationNodeDataTypes(type_node, name_node); - n.m_data_type = type_node.m_data_type; + _buildDeclarationNodeDataTypes(type_node, name_node); + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::fct_declaration>()) { - n.children[0]->m_data_type = ASTNodeDataType::function_t; + n.children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::function_t>(); const std::string& symbol = n.children[0]->string(); - auto [i_symbol, success] = n.m_symbol_table->find(symbol, n.children[0]->begin()); + + auto [i_symbol, success] = n.m_symbol_table->find(symbol, n.children[0]->begin()); auto& function_table = n.m_symbol_table->functionTable(); uint64_t function_id = std::get<uint64_t>(i_symbol->attributes().value()); FunctionDescriptor& function_descriptor = function_table[function_id]; + this->_buildNodeDataTypes(function_descriptor.domainMappingNode()); + ASTNode& parameters_domain_node = *function_descriptor.domainMappingNode().children[0]; ASTNode& parameters_name_node = *function_descriptor.definitionNode().children[0]; @@ -196,34 +205,15 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const message << "note: number of product spaces (" << nb_parameter_domains << ") " << rang::fgB::yellow << parameters_domain_node.string() << rang::style::reset << rang::style::bold << " differs from number of variables (" << nb_parameter_names << ") " << rang::fgB::yellow - << parameters_name_node.string() << rang::style::reset << std::ends; + << parameters_name_node.string() << rang::style::reset; throw ParseError(message.str(), parameters_domain_node.begin()); } auto simple_type_allocator = [&](const ASTNode& type_node, ASTNode& symbol_node) { Assert(symbol_node.is_type<language::name>()); - ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; - if (type_node.is_type<language::B_set>()) { - data_type = ASTNodeDataType::bool_t; - } else if (type_node.is_type<language::Z_set>()) { - data_type = ASTNodeDataType::int_t; - } else if (type_node.is_type<language::N_set>()) { - data_type = ASTNodeDataType::unsigned_int_t; - } else if (type_node.is_type<language::R_set>()) { - data_type = ASTNodeDataType::double_t; - } else if (type_node.is_type<language::vector_type>()) { - data_type = getVectorDataType(type_node); - } else if (type_node.is_type<language::string_type>()) { - data_type = ASTNodeDataType::string_t; - } - - // LCOV_EXCL_START - if (data_type == ASTNodeDataType::undefined_t) { - throw ParseError("invalid parameter type", type_node.begin()); - } - // LCOV_EXCL_STOP + const ASTNodeDataType& data_type = type_node.m_data_type.contentType(); - symbol_node.m_data_type = data_type; + symbol_node.m_data_type = type_node.m_data_type.contentType(); const std::string& symbol = symbol_node.string(); std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table; @@ -236,28 +226,22 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const if (nb_parameter_domains == 1) { simple_type_allocator(parameters_domain_node, parameters_name_node); } else { + std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; + sub_data_type_list.reserve(nb_parameter_domains); + for (size_t i = 0; i < nb_parameter_domains; ++i) { simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]); + sub_data_type_list.push_back( + std::make_shared<const ASTNodeDataType>(parameters_name_node.children[i]->m_data_type)); } - parameters_name_node.m_data_type = ASTNodeDataType::list_t; + parameters_name_node.m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list); } - // build types for compound types - for (auto& child : parameters_domain_node.children) { - this->_buildNodeDataTypes(*child); - } - for (auto& child : parameters_name_node.children) { - this->_buildNodeDataTypes(*child); - } + this->_buildNodeDataTypes(function_descriptor.definitionNode()); ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1]; - this->_buildNodeDataTypes(image_domain_node); - for (auto& child : image_domain_node.children) { - this->_buildNodeDataTypes(*child); - } - const size_t nb_image_domains = (image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1; const size_t nb_image_expressions = @@ -269,8 +253,15 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const if (image_type.dimension() != nb_image_expressions) { std::ostringstream message; message << "expecting " << image_type.dimension() << " scalar expressions or an " - << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions" - << std::ends; + << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"; + throw ParseError(message.str(), image_domain_node.begin()); + } + } else if (image_domain_node.is_type<language::matrix_type>()) { + ASTNodeDataType image_type = getMatrixDataType(image_domain_node); + if (image_type.nbRows() * image_type.nbColumns() != nb_image_expressions) { + std::ostringstream message; + message << "expecting " << image_type.nbRows() * image_type.nbColumns() << " scalar expressions or an " + << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"; throw ParseError(message.str(), image_domain_node.begin()); } } else { @@ -278,44 +269,16 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const message << "number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow << image_domain_node.string() << rang::style::reset << rang::style::bold << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow - << image_expression_node.string() << rang::style::reset << std::ends; + << image_expression_node.string() << rang::style::reset; throw ParseError(message.str(), image_domain_node.begin()); } } - auto check_image_type = [&](const ASTNode& image_node) { - ASTNodeDataType value_type{ASTNodeDataType::undefined_t}; - if (image_node.is_type<language::B_set>()) { - value_type = ASTNodeDataType::bool_t; - } else if (image_node.is_type<language::Z_set>()) { - value_type = ASTNodeDataType::int_t; - } else if (image_node.is_type<language::N_set>()) { - value_type = ASTNodeDataType::unsigned_int_t; - } else if (image_node.is_type<language::R_set>()) { - value_type = ASTNodeDataType::double_t; - } else if (image_node.is_type<language::vector_type>()) { - value_type = getVectorDataType(image_node); - } else if (image_node.is_type<language::string_type>()) { - value_type = ASTNodeDataType::string_t; - } + this->_buildNodeDataTypes(image_expression_node); - // LCOV_EXCL_START - if (value_type == ASTNodeDataType::undefined_t) { - throw ParseError("invalid value type", image_node.begin()); - } - // LCOV_EXCL_STOP - }; + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(image_expression_node, image_domain_node.m_data_type); - if (image_domain_node.is_type<language::type_expression>()) { - for (size_t i = 0; i < image_domain_node.children.size(); ++i) { - check_image_type(*image_domain_node.children[i]); - } - image_domain_node.m_data_type = ASTNodeDataType::typename_t; - } else { - check_image_type(image_domain_node); - } - - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); return; } else if (n.is_type<language::name>()) { std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table; @@ -330,81 +293,214 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const } if (n.is_type<language::break_kw>() or n.is_type<language::continue_kw>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::eq_op>() or n.is_type<language::multiplyeq_op>() or n.is_type<language::divideeq_op>() or n.is_type<language::pluseq_op>() or n.is_type<language::minuseq_op>()) { - n.m_data_type = n.children[0]->m_data_type; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + + } else if (n.is_type<language::tuple_expression>()) { + std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; + sub_data_type_list.reserve(n.children.size()); + + for (size_t i = 0; i < n.children.size(); ++i) { + sub_data_type_list.push_back(std::make_shared<const ASTNodeDataType>(n.children[i]->m_data_type)); + } + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list); + } else if (n.is_type<language::type_mapping>() or n.is_type<language::function_definition>()) { - n.m_data_type = ASTNodeDataType::void_t; + for (auto& child : n.children) { + this->_buildNodeDataTypes(*child); + } + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + } else if (n.is_type<language::type_expression>()) { + std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; + sub_data_type_list.reserve(n.children.size()); + + auto check_sub_type = [&](const ASTNode& image_node) { + ASTNodeDataType value_type; + if (image_node.is_type<language::B_set>()) { + value_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + } else if (image_node.is_type<language::Z_set>()) { + value_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + } else if (image_node.is_type<language::N_set>()) { + value_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + } else if (image_node.is_type<language::R_set>()) { + value_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + } else if (image_node.is_type<language::vector_type>()) { + value_type = getVectorDataType(image_node); + } else if (image_node.is_type<language::matrix_type>()) { + value_type = getMatrixDataType(image_node); + } else if (image_node.is_type<language::string_type>()) { + value_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + } + + // LCOV_EXCL_START + if (value_type == ASTNodeDataType::undefined_t) { + throw ParseError("invalid value type", image_node.begin()); + } + // LCOV_EXCL_STOP + }; + + for (size_t i = 0; i < n.children.size(); ++i) { + check_sub_type(*n.children[i]); + sub_data_type_list.push_back(std::make_shared<const ASTNodeDataType>(n.children[i]->m_data_type)); + } + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>( + ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list)); + } else if (n.is_type<language::for_post>() or n.is_type<language::for_init>() or n.is_type<language::for_statement_block>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::for_test>()) { - n.m_data_type = ASTNodeDataType::bool_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); } else if (n.is_type<language::statement_block>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); } else if (n.is_type<language::if_statement>() or n.is_type<language::while_statement>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); const ASTNode& test_node = *n.children[0]; - ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t}; + ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::build<ASTNodeDataType::bool_t>()}; } else if (n.is_type<language::do_while_statement>()) { - n.m_data_type = ASTNodeDataType::void_t; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); const ASTNode& test_node = *n.children[1]; - ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t}; - - } else if (n.is_type<language::unary_not>()) { - n.m_data_type = ASTNodeDataType::bool_t; - - const ASTNode& operand_node = *n.children[0]; - ASTNodeNaturalConversionChecker{operand_node, ASTNodeDataType::bool_t}; + ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::build<ASTNodeDataType::bool_t>()}; + + } else if (n.is_type<language::unary_not>() or n.is_type<language::unary_minus>()) { + auto& operator_repository = OperatorRepository::instance(); + + auto optional_value_type = [&] { + if (n.is_type<language::unary_not>()) { + return operator_repository.getUnaryOperatorValueType( + unaryOperatorMangler<language::unary_not>(n.children[0]->m_data_type)); + } else if (n.is_type<language::unary_minus>()) { + return operator_repository.getUnaryOperatorValueType( + unaryOperatorMangler<language::unary_minus>(n.children[0]->m_data_type)); + } else { + // LCOV_EXCL_START + throw UnexpectedError("invalid unary operator type"); + // LCOV_EXCL_STOP + } + }(); - } else if (n.is_type<language::lesser_op>() or n.is_type<language::lesser_or_eq_op>() or - n.is_type<language::greater_op>() or n.is_type<language::greater_or_eq_op>() or - n.is_type<language::eqeq_op>() or n.is_type<language::not_eq_op>()) { - n.m_data_type = ASTNodeDataType::bool_t; - } else if (n.is_type<language::and_op>() or n.is_type<language::or_op>() or n.is_type<language::xor_op>()) { - n.m_data_type = ASTNodeDataType::bool_t; - - const ASTNode& lhs_node = *n.children[0]; - ASTNodeNaturalConversionChecker{lhs_node, ASTNodeDataType::bool_t}; - - const ASTNode& rhs_node = *n.children[1]; - ASTNodeNaturalConversionChecker{rhs_node, ASTNodeDataType::bool_t}; - - } else if (n.is_type<language::unary_minus>()) { - n.m_data_type = n.children[0]->m_data_type; - if ((n.children[0]->m_data_type == ASTNodeDataType::unsigned_int_t) or - (n.children[0]->m_data_type == ASTNodeDataType::bool_t)) { - n.m_data_type = ASTNodeDataType::int_t; + if (optional_value_type.has_value()) { + n.m_data_type = optional_value_type.value(); } else { - n.m_data_type = n.children[0]->m_data_type; + std::ostringstream message; + message << "undefined unary operator\n" + << "note: unexpected operand type " << rang::fgB::red << dataTypeName(n.children[0]->m_data_type) + << rang::style::reset; + throw ParseError(message.str(), n.begin()); } + } else if (n.is_type<language::unary_plusplus>() or n.is_type<language::unary_minusminus>() or n.is_type<language::post_plusplus>() or n.is_type<language::post_minusminus>()) { - n.m_data_type = n.children[0]->m_data_type; + auto& operator_repository = OperatorRepository::instance(); + + auto optional_value_type = [&] { + if (n.is_type<language::unary_plusplus>()) { + return operator_repository.getIncDecOperatorValueType( + incDecOperatorMangler<language::unary_plusplus>(n.children[0]->m_data_type)); + } else if (n.is_type<language::unary_minusminus>()) { + return operator_repository.getIncDecOperatorValueType( + incDecOperatorMangler<language::unary_minusminus>(n.children[0]->m_data_type)); + } else if (n.is_type<language::post_minusminus>()) { + return operator_repository.getIncDecOperatorValueType( + incDecOperatorMangler<language::post_minusminus>(n.children[0]->m_data_type)); + } else if (n.is_type<language::post_plusplus>()) { + return operator_repository.getIncDecOperatorValueType( + incDecOperatorMangler<language::post_plusplus>(n.children[0]->m_data_type)); + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected operator type"); + // LCOV_EXCL_STOP + } + }(); + + if (optional_value_type.has_value()) { + n.m_data_type = optional_value_type.value(); + } else { + std::ostringstream message; + message << "undefined increment/decrement operator\n" + << "note: unexpected operand type " << rang::fgB::red << dataTypeName(n.children[0]->m_data_type) + << rang::style::reset; + throw ParseError(message.str(), n.begin()); + } + } else if (n.is_type<language::plus_op>() or n.is_type<language::minus_op>() or - n.is_type<language::multiply_op>() or n.is_type<language::divide_op>()) { + n.is_type<language::multiply_op>() or n.is_type<language::divide_op>() or + n.is_type<language::lesser_op>() or n.is_type<language::lesser_or_eq_op>() or + n.is_type<language::greater_op>() or n.is_type<language::greater_or_eq_op>() or + n.is_type<language::eqeq_op>() or n.is_type<language::not_eq_op>() or n.is_type<language::and_op>() or + n.is_type<language::or_op>() or n.is_type<language::xor_op>()) { const ASTNodeDataType type_0 = n.children[0]->m_data_type; const ASTNodeDataType type_1 = n.children[1]->m_data_type; - if ((type_0 == ASTNodeDataType::bool_t) and (type_1 == ASTNodeDataType::bool_t)) { - n.m_data_type = ASTNodeDataType::int_t; + auto& operator_repository = OperatorRepository::instance(); + + auto optional_value_type = [&] { + if (n.is_type<language::plus_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::plus_op>(type_0, type_1)); + } else if (n.is_type<language::minus_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::minus_op>(type_0, type_1)); + } else if (n.is_type<language::multiply_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::multiply_op>(type_0, type_1)); + } else if (n.is_type<language::divide_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::divide_op>(type_0, type_1)); + + } else if (n.is_type<language::lesser_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::lesser_op>(type_0, type_1)); + } else if (n.is_type<language::lesser_or_eq_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::lesser_or_eq_op>(type_0, type_1)); + } else if (n.is_type<language::greater_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::greater_op>(type_0, type_1)); + } else if (n.is_type<language::greater_or_eq_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::greater_or_eq_op>(type_0, type_1)); + } else if (n.is_type<language::eqeq_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::eqeq_op>(type_0, type_1)); + } else if (n.is_type<language::not_eq_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::not_eq_op>(type_0, type_1)); + + } else if (n.is_type<language::and_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::and_op>(type_0, type_1)); + } else if (n.is_type<language::or_op>()) { + return operator_repository.getBinaryOperatorValueType(binaryOperatorMangler<language::or_op>(type_0, type_1)); + } else if (n.is_type<language::xor_op>()) { + return operator_repository.getBinaryOperatorValueType( + binaryOperatorMangler<language::xor_op>(type_0, type_1)); + + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected operator type"); + // LCOV_EXCL_STOP + } + }(); + + if (optional_value_type.has_value()) { + n.m_data_type = optional_value_type.value(); } else { - n.m_data_type = dataTypePromotion(type_0, type_1); - } - if (n.m_data_type == ASTNodeDataType::undefined_t) { std::ostringstream message; message << "undefined binary operator\n" - << "note: incompatible operand types " << n.children[0]->string() << " (" << dataTypeName(type_0) - << ") and " << n.children[1]->string() << " (" << dataTypeName(type_1) << ')' << std::ends; + << "note: incompatible operand types " << dataTypeName(type_0) << " and " << dataTypeName(type_1); throw ParseError(message.str(), n.begin()); } + } else if (n.is_type<language::function_evaluation>()) { if (n.children[0]->m_data_type == ASTNodeDataType::function_t) { - const std::string& function_name = n.children[0]->string(); + const std::string& function_name = n.children[0]->string(); + auto [i_function_symbol, success] = n.m_symbol_table->find(function_name, n.children[0]->begin()); auto& function_table = n.m_symbol_table->functionTable(); @@ -414,28 +510,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; - ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; - if (image_domain_node.is_type<language::type_expression>()) { - data_type = image_domain_node.m_data_type; - } else { - if (image_domain_node.is_type<language::B_set>()) { - data_type = ASTNodeDataType::bool_t; - } else if (image_domain_node.is_type<language::Z_set>()) { - data_type = ASTNodeDataType::int_t; - } else if (image_domain_node.is_type<language::N_set>()) { - data_type = ASTNodeDataType::unsigned_int_t; - } else if (image_domain_node.is_type<language::R_set>()) { - data_type = ASTNodeDataType::double_t; - } else if (image_domain_node.is_type<language::vector_type>()) { - data_type = getVectorDataType(image_domain_node); - } else if (image_domain_node.is_type<language::string_type>()) { - data_type = ASTNodeDataType::string_t; - } - } - - Assert(data_type != ASTNodeDataType::undefined_t); // LCOV_EXCL_LINE - - n.m_data_type = data_type; + n.m_data_type = image_domain_node.m_data_type.contentType(); } else if (n.children[0]->m_data_type == ASTNodeDataType::builtin_function_t) { const std::string builtin_function_name = n.children[0]->string(); auto& symbol_table = *n.m_symbol_table; @@ -452,32 +527,71 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const std::ostringstream message; message << "invalid function call\n" << "note: '" << n.children[0]->string() << "' (type: " << dataTypeName(n.children[0]->m_data_type) - << ") is not a function!" << std::ends; + << ") is not a function!"; throw ParseError(message.str(), n.begin()); } } else if (n.is_type<language::subscript_expression>()) { - Assert(n.children.size() == 2, "invalid number of sub-expressions in array subscript expression"); auto& array_expression = *n.children[0]; - auto& index_expression = *n.children[1]; - ASTNodeNaturalConversionChecker{index_expression, ASTNodeDataType::int_t}; - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + auto& index_expression = *n.children[1]; + ASTNodeNaturalConversionChecker{index_expression, ASTNodeDataType::build<ASTNodeDataType::int_t>()}; + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + if (n.children.size() != 2) { + std::ostringstream message; + message << "invalid index type: " << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) + << rang::style::reset << " requires a single integer"; + throw ParseError(message.str(), index_expression.begin()); + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + for (size_t i = 1; i < n.children.size(); ++i) { + ASTNodeNaturalConversionChecker{*n.children[i], ASTNodeDataType::build<ASTNodeDataType::int_t>()}; + } + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + if (n.children.size() != 3) { + std::ostringstream message; + message << "invalid index type: " << rang::fgB::yellow << dataTypeName(n.children[0]->m_data_type) + << rang::style::reset << " requires two integers"; + throw ParseError(message.str(), n.children[1]->begin()); + } + + } else { std::ostringstream message; - message << "invalid types '" << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) - << rang::style::reset << '[' << dataTypeName(index_expression.m_data_type) << ']' - << "' for array subscript" << std::ends; + message << "invalid subscript expression: " << rang::fgB::yellow << dataTypeName(array_expression.m_data_type) + << rang::style::reset << " cannot be indexed"; throw ParseError(message.str(), n.begin()); - } else { - n.m_data_type = ASTNodeDataType::double_t; } - } else if (n.is_type<language::B_set>() or n.is_type<language::Z_set>() or n.is_type<language::N_set>() or - n.is_type<language::R_set>() or n.is_type<language::string_type>() or - n.is_type<language::vector_type>() or n.is_type<language::type_name_id>()) { - n.m_data_type = ASTNodeDataType::typename_t; + } else if (n.is_type<language::B_set>()) { + n.m_data_type = + ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::bool_t>()); + } else if (n.is_type<language::Z_set>()) { + n.m_data_type = + ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::int_t>()); + } else if (n.is_type<language::N_set>()) { + n.m_data_type = + ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>()); + } else if (n.is_type<language::string_type>()) { + n.m_data_type = + ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::string_t>()); + } else if (n.is_type<language::R_set>()) { + n.m_data_type = + ASTNodeDataType::build<ASTNodeDataType::typename_t>(ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } else if (n.is_type<language::vector_type>()) { + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getVectorDataType(n)); + } else if (n.is_type<language::matrix_type>()) { + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(getMatrixDataType(n)); } else if (n.is_type<language::name_list>() or n.is_type<language::lvalue_list>() or n.is_type<language::function_argument_list>() or n.is_type<language::expression_list>()) { - n.m_data_type = ASTNodeDataType::list_t; + std::vector<std::shared_ptr<const ASTNodeDataType>> sub_data_type_list; + sub_data_type_list.reserve(n.children.size()); + + for (size_t i = 0; i < n.children.size(); ++i) { + sub_data_type_list.push_back(std::make_shared<const ASTNodeDataType>(n.children[i]->m_data_type)); + } + + n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list); } } } @@ -485,17 +599,9 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNodeDataTypeBuilder::ASTNodeDataTypeBuilder(ASTNode& node) { Assert(node.is_root()); - node.m_data_type = ASTNodeDataType::void_t; + node.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); this->_buildNodeDataTypes(node); - FunctionTable& function_table = node.m_symbol_table->functionTable(); - for (size_t function_id = 0; function_id < function_table.size(); ++function_id) { - FunctionDescriptor& function_descriptor = function_table[function_id]; - ASTNode& function_expression = function_descriptor.definitionNode(); - - this->_buildNodeDataTypes(function_expression); - } - std::cout << " - build node data types\n"; } diff --git a/src/language/ast/ASTNodeDataTypeBuilder.hpp b/src/language/ast/ASTNodeDataTypeBuilder.hpp index 93643c82d80c9c222cb12b482416a661ff20fb31..be2c35adfb2ab81479ef27d385be880b35b0538c 100644 --- a/src/language/ast/ASTNodeDataTypeBuilder.hpp +++ b/src/language/ast/ASTNodeDataTypeBuilder.hpp @@ -6,7 +6,7 @@ class ASTNodeDataTypeBuilder { private: - ASTNodeDataType _buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNode& name_node) const; + void _buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNode& name_node) const; void _buildNodeDataTypes(ASTNode& node) const; diff --git a/src/language/ast/ASTNodeDataTypeFlattener.cpp b/src/language/ast/ASTNodeDataTypeFlattener.cpp index da2d064c6dd677e9eae4444cde59700c3a61ec2d..51d055836393b1ee680b7d983756bee39020d992 100644 --- a/src/language/ast/ASTNodeDataTypeFlattener.cpp +++ b/src/language/ast/ASTNodeDataTypeFlattener.cpp @@ -11,7 +11,7 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT ASTNodeDataTypeFlattener{*child_node, flattened_datatype_list}; } } else if (node.is_type<language::function_evaluation>()) { - if (node.m_data_type != ASTNodeDataType::typename_t) { + if (node.m_data_type != ASTNodeDataType::list_t) { flattened_datatype_list.push_back({node.m_data_type, node}); } else { ASTNode& function_name_node = *node.children[0]; @@ -28,24 +28,8 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT ASTNode& function_image_domain = *function_descriptor.domainMappingNode().children[1]; for (auto& image_sub_domain : function_image_domain.children) { - ASTNodeDataType data_type = ASTNodeDataType::undefined_t; - - if (image_sub_domain->is_type<language::B_set>()) { - data_type = ASTNodeDataType::bool_t; - } else if (image_sub_domain->is_type<language::Z_set>()) { - data_type = ASTNodeDataType::int_t; - } else if (image_sub_domain->is_type<language::N_set>()) { - data_type = ASTNodeDataType::unsigned_int_t; - } else if (image_sub_domain->is_type<language::R_set>()) { - data_type = ASTNodeDataType::double_t; - } else if (image_sub_domain->is_type<language::vector_type>()) { - data_type = getVectorDataType(*image_sub_domain); - } else if (image_sub_domain->is_type<language::string_type>()) { - data_type = ASTNodeDataType::string_t; - } - - Assert(data_type != ASTNodeDataType::undefined_t); - flattened_datatype_list.push_back({data_type, node}); + Assert(image_sub_domain->m_data_type == ASTNodeDataType::typename_t); + flattened_datatype_list.push_back({image_sub_domain->m_data_type.contentType(), node}); } break; } diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp index 072a2af6015eed0dab37468c0c5e7211ec887eff..6484b85b0fed5626df7f46ade79c479cfe5f23ae 100644 --- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp @@ -2,11 +2,13 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/FunctionProcessor.hpp> +#include <language/node_processor/TupleToTinyMatrixProcessor.hpp> #include <language/node_processor/TupleToTinyVectorProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/FunctionTable.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/Exceptions.hpp> template <typename SymbolType> std::unique_ptr<IFunctionArgumentConverter> @@ -36,7 +38,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy // LCOV_EXCL_START default: { throw ParseError("unexpected error: invalid argument type 0", - std::vector{node_sub_data_type.m_parent_node.begin()}); + std::vector{node_sub_data_type.m_parent_node.begin()}); } // LCOV_EXCL_STOP } @@ -52,7 +54,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid argument dimension", - std::vector{node_sub_data_type.m_parent_node.begin()}); + std::vector{node_sub_data_type.m_parent_node.begin()}); // LCOV_EXCL_STOP } } @@ -62,7 +64,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid argument dimension", - std::vector{node_sub_data_type.m_parent_node.begin()}); + std::vector{node_sub_data_type.m_parent_node.begin()}); // LCOV_EXCL_STOP } } @@ -77,7 +79,49 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy // LCOV_EXCL_START default: { throw ParseError("unexpected error: invalid argument type", - std::vector{node_sub_data_type.m_parent_node.begin()}); + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + }; + + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + switch (node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (node_sub_data_type.m_parent_node.children.size() == parameter_v.dimension()) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); } // LCOV_EXCL_STOP } @@ -116,9 +160,28 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy return get_function_argument_converter_for_vector(TinyVector<3>{}); } } - [[fallthrough]]; + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP } + case ASTNodeDataType::matrix_t: { + Assert(parameter_symbol.attributes().dataType().nbRows() == parameter_symbol.attributes().dataType().nbColumns()); + switch (parameter_symbol.attributes().dataType().nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); + // LCOV_EXCL_STOP + } // LCOV_EXCL_START default: { throw ParseError("unexpected error: undefined parameter type", std::vector{m_node.begin()}); @@ -218,7 +281,7 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r // LCOV_EXCL_START default: { throw ParseError("unexpected error: undefined expression value type for function", - std::vector{node.children[1]->begin()}); + std::vector{node.children[1]->begin()}); } // LCOV_EXCL_STOP } @@ -233,7 +296,52 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid dimension for returned vector", - std::vector{function_component_expression.begin()}); + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (function_component_expression.children.size() == return_v.dimension()) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, AggregateDataVariant>>( + function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (function_component_expression.is_type<language::integer>()) { + if (std::stoi(function_component_expression.string()) == 0) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ZeroType>>(function_component_expression); + } + } + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + // LCOV_EXCL_STOP + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined expression value type for function", + std::vector{function_component_expression.begin()}); + } + // LCOV_EXCL_STOP + } + }; + + auto get_function_processor_for_expression_matrix = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> { + using ReturnT = std::decay_t<decltype(return_v)>; + switch (function_component_expression.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((function_component_expression.m_data_type.nbRows() == return_v.nbRows()) and + (function_component_expression.m_data_type.nbColumns() == return_v.nbColumns())) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ReturnT>>(function_component_expression); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension for returned vector", + std::vector{function_component_expression.begin()}); // LCOV_EXCL_STOP } } @@ -244,7 +352,7 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid dimension for returned vector", - std::vector{function_component_expression.begin()}); + std::vector{function_component_expression.begin()}); // LCOV_EXCL_STOP } } @@ -256,13 +364,13 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r } // LCOV_EXCL_START throw ParseError("unexpected error: undefined expression value type for function", - std::vector{function_component_expression.begin()}); + std::vector{function_component_expression.begin()}); // LCOV_EXCL_STOP } // LCOV_EXCL_START default: { throw ParseError("unexpected error: undefined expression value type for function", - std::vector{function_component_expression.begin()}); + std::vector{function_component_expression.begin()}); } // LCOV_EXCL_STOP } @@ -304,6 +412,30 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType& r // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(return_value_type.nbRows() == return_value_type.nbColumns()); + + switch (return_value_type.nbRows()) { + case 1: { + if (function_component_expression.m_data_type == ASTNodeDataType::matrix_t) { + return get_function_processor_for_expression_matrix(TinyMatrix<1>{}); + } else { + return get_function_processor_for_expression_value(TinyMatrix<1>{}); + } + } + case 2: { + return get_function_processor_for_expression_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_processor_for_expression_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid dimension in returned type", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_processor_for_expression_value(std::string{}); } @@ -330,32 +462,12 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node std::unique_ptr function_processor = this->_buildArgumentConverter(function_descriptor, node); - auto add_component_expression = [&](ASTNode& expression_node, ASTNode& domain_node) { - ASTNodeDataType return_value_type = ASTNodeDataType::undefined_t; - - ASTNode& image_domain_node = domain_node; + auto add_component_expression = [&](ASTNode& expression_node, const ASTNode& image_domain_node) { + Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t); - if (image_domain_node.is_type<language::B_set>()) { - return_value_type = ASTNodeDataType::bool_t; - } else if (image_domain_node.is_type<language::Z_set>()) { - return_value_type = ASTNodeDataType::int_t; - } else if (image_domain_node.is_type<language::N_set>()) { - return_value_type = ASTNodeDataType::unsigned_int_t; - } else if (image_domain_node.is_type<language::R_set>()) { - return_value_type = ASTNodeDataType::double_t; - } else if (image_domain_node.is_type<language::vector_type>()) { - return_value_type = getVectorDataType(image_domain_node); - } else if (image_domain_node.is_type<language::string_type>()) { - return_value_type = ASTNodeDataType::string_t; - } + const ASTNodeDataType return_value_type = image_domain_node.m_data_type.contentType(); - Assert(return_value_type != ASTNodeDataType::undefined_t); - - if ((return_value_type == ASTNodeDataType::vector_t) and (return_value_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{expression_node, ASTNodeDataType::double_t}; - } else { - ASTNodeNaturalConversionChecker{expression_node, return_value_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{expression_node, return_value_type}; function_processor->addFunctionExpressionProcessor( this->_getFunctionProcessor(return_value_type, node, expression_node)); @@ -364,20 +476,21 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ASTNode& function_image_domain = *function_descriptor.domainMappingNode().children[1]; ASTNode& function_expression = *function_descriptor.definitionNode().children[1]; + Assert(function_image_domain.m_data_type == ASTNodeDataType::typename_t); + const ASTNodeDataType function_return_type = function_image_domain.m_data_type.contentType(); + if (function_image_domain.is_type<language::vector_type>()) { ASTNodeDataType vector_type = getVectorDataType(function_image_domain); - if ((vector_type.dimension() == 1) and (function_expression.m_data_type != ASTNodeDataType::vector_t)) { - ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::double_t}; - } else { - ASTNodeNaturalConversionChecker{function_expression, vector_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, vector_type}; + if (function_expression.is_type<language::expression_list>()) { Assert(vector_type.dimension() == function_expression.children.size()); for (size_t i = 0; i < vector_type.dimension(); ++i) { function_processor->addFunctionExpressionProcessor( - this->_getFunctionProcessor(ASTNodeDataType::double_t, node, *function_expression.children[i])); + this->_getFunctionProcessor(ASTNodeDataType::build<ASTNodeDataType::double_t>(), node, + *function_expression.children[i])); } switch (vector_type.dimension()) { @@ -433,6 +546,73 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node node.m_node_processor = std::move(function_processor); } + } else if (function_image_domain.is_type<language::matrix_type>()) { + ASTNodeDataType matrix_type = getMatrixDataType(function_image_domain); + + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, matrix_type}; + + if (function_expression.is_type<language::expression_list>()) { + Assert(matrix_type.nbRows() * matrix_type.nbColumns() == function_expression.children.size()); + + for (size_t i = 0; i < matrix_type.nbRows() * matrix_type.nbColumns(); ++i) { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(ASTNodeDataType::build<ASTNodeDataType::double_t>(), node, + *function_expression.children[i])); + } + + switch (matrix_type.nbRows()) { + case 2: { + node.m_node_processor = + std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 2>>(node, std::move(function_processor)); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<TupleToTinyMatrixProcessor<FunctionProcessor, 3>>(node, std::move(function_processor)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid matrix_t dimensions", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (function_expression.is_type<language::integer>()) { + if (std::stoi(function_expression.string()) == 0) { + switch (matrix_type.nbRows()) { + case 1: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<1>, ZeroType>>(function_expression); + break; + } + case 2: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<2>, ZeroType>>(function_expression); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<FunctionExpressionProcessor<TinyMatrix<3>, ZeroType>>(function_expression); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("invalid matrix dimensions"); + } + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting 0"); + // LCOV_EXCL_STOP + } + } else { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(matrix_type, node, function_expression)); + + node.m_node_processor = std::move(function_processor); + } + } else { if (function_expression.is_type<language::expression_list>()) { ASTNode& image_domain_node = function_image_domain; diff --git a/src/language/ast/ASTNodeIncDecExpressionBuilder.cpp b/src/language/ast/ASTNodeIncDecExpressionBuilder.cpp index 71e5c3c4b16d8535091b98cab298ff6624661508..88b1758e757797d0844a06185105480e341635c8 100644 --- a/src/language/ast/ASTNodeIncDecExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeIncDecExpressionBuilder.cpp @@ -1,54 +1,47 @@ #include <language/ast/ASTNodeIncDecExpressionBuilder.hpp> #include <language/PEGGrammar.hpp> -#include <language/node_processor/IncDecExpressionProcessor.hpp> +#include <language/utils/IncDecOperatorMangler.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> ASTNodeIncDecExpressionBuilder::ASTNodeIncDecExpressionBuilder(ASTNode& n) { - auto set_inc_dec_operator_processor = [](ASTNode& n, const auto& operator_v) { - auto set_inc_dec_processor_for_value = [&](const ASTNodeDataType& data_type) { - using OperatorT = std::decay_t<decltype(operator_v)>; - switch (data_type) { - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<IncDecExpressionProcessor<OperatorT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<IncDecExpressionProcessor<OperatorT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<IncDecExpressionProcessor<OperatorT, double>>(n); - break; - } - default: { - throw ParseError("unexpected error: undefined data type for unary operator", std::vector{n.begin()}); - } - } - }; - - if (not n.children[0]->is_type<language::name>()) { - if (n.children[0]->is_type<language::post_minusminus>() or n.children[0]->is_type<language::post_plusplus>() or - n.children[0]->is_type<language::unary_minusminus>() or n.children[0]->is_type<language::unary_plusplus>()) { - throw ParseError("chaining ++ or -- operators is not allowed", std::vector{n.children[0]->begin()}); - } else { - throw ParseError("invalid operand type for unary operator", std::vector{n.children[0]->begin()}); - } + const ASTNodeDataType& data_type = n.children[0]->m_data_type; + + if (not n.children[0]->is_type<language::name>()) { + std::ostringstream error_message; + error_message << "invalid operand type. ++/-- operators only apply to variables"; + + throw ParseError(error_message.str(), std::vector{n.children[0]->begin()}); + } + + const std::string inc_dec_operator_name = [&] { + if (n.is_type<language::unary_minusminus>()) { + return incDecOperatorMangler<language::unary_minusminus>(data_type); + } else if (n.is_type<language::unary_plusplus>()) { + return incDecOperatorMangler<language::unary_plusplus>(data_type); + } else if (n.is_type<language::post_minusminus>()) { + return incDecOperatorMangler<language::post_minusminus>(data_type); + } else if (n.is_type<language::post_plusplus>()) { + return incDecOperatorMangler<language::post_plusplus>(data_type); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined inc/dec operator", std::vector{n.begin()}); + // LCOV_EXCL_STOP } + }(); + + const auto& optional_processor_builder = + OperatorRepository::instance().getIncDecProcessorBuilder(inc_dec_operator_name); - set_inc_dec_processor_for_value(n.m_data_type); - }; - - if (n.is_type<language::unary_minusminus>()) { - set_inc_dec_operator_processor(n, language::unary_minusminus{}); - } else if (n.is_type<language::unary_plusplus>()) { - set_inc_dec_operator_processor(n, language::unary_plusplus{}); - } else if (n.is_type<language::post_minusminus>()) { - set_inc_dec_operator_processor(n, language::post_minusminus{}); - } else if (n.is_type<language::post_plusplus>()) { - set_inc_dec_operator_processor(n, language::post_plusplus{}); + if (optional_processor_builder.has_value()) { + n.m_node_processor = optional_processor_builder.value()->getNodeProcessor(n); } else { - throw ParseError("unexpected error: undefined increment/decrement operator", std::vector{n.begin()}); + std::ostringstream error_message; + error_message << "undefined affectation type: "; + error_message << rang::fgB::red << inc_dec_operator_name << rang::fg::reset; + + throw ParseError(error_message.str(), std::vector{n.children[0]->begin()}); } } diff --git a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp index daa5a0bd371150ffd3d33a60993a3d4d10c48d96..b135bbcd3f67127c23c044880d96ad28ee6527c5 100644 --- a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp @@ -2,8 +2,8 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNodeDataTypeFlattener.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> #include <language/node_processor/AffectationProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> #include <language/utils/ParseError.hpp> template <typename OperatorT> @@ -77,6 +77,44 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; + auto add_affectation_processor_for_matrix_data = [&](const auto& value, + const ASTNodeSubDataType& node_sub_data_type) { + using ValueT = std::decay_t<decltype(value)>; + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else { + add_affectation_processor_for_data(value, node_sub_data_type); + } + } else if constexpr (std::is_same_v<ValueT, TinyMatrix<2>> or std::is_same_v<ValueT, TinyMatrix<3>>) { + if ((node_sub_data_type.m_data_type == ASTNodeDataType::matrix_t) and + (node_sub_data_type.m_data_type.nbRows() == value.nbRows()) and + (node_sub_data_type.m_data_type.nbColumns() == value.nbColumns())) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else if ((node_sub_data_type.m_data_type == ASTNodeDataType::list_t) and + (node_sub_data_type.m_parent_node.children.size() == value.nbRows() * value.nbColumns())) { + list_affectation_processor->template add<ValueT, AggregateDataVariant>(value_node); + } else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { + list_affectation_processor->template add<ValueT, ZeroType>(value_node); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid operand value", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid dimension", std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } else { + throw ParseError("unexpected error: invalid value type", std::vector{node_sub_data_type.m_parent_node.begin()}); + } + }; + auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) { if constexpr (std::is_same_v<OperatorT, language::eq_op>) { switch (node_sub_data_type.m_data_type) { @@ -176,6 +214,29 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } break; } + case ASTNodeDataType::matrix_t: { + Assert(value_type.nbRows() == value_type.nbColumns()); + switch (value_type.nbRows()) { + case 1: { + add_affectation_processor_for_matrix_data(TinyMatrix<1>{}, node_sub_data_type); + break; + } + case 2: { + add_affectation_processor_for_matrix_data(TinyMatrix<2>{}, node_sub_data_type); + break; + } + case 3: { + add_affectation_processor_for_matrix_data(TinyMatrix<3>{}, node_sub_data_type); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("invalid dimension", std::vector{value_node.begin()}); + } + // LCOV_EXCL_STOP + } + break; + } case ASTNodeDataType::string_t: { add_affectation_processor_for_string_data(node_sub_data_type); break; @@ -188,12 +249,7 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; - if ((value_node.m_data_type != rhs_node_sub_data_type.m_data_type) and - (value_node.m_data_type == ASTNodeDataType::vector_t) and (value_node.m_data_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, ASTNodeDataType::double_t}; - } else { - ASTNodeNaturalConversionChecker{rhs_node_sub_data_type, value_node.m_data_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(rhs_node_sub_data_type, value_node.m_data_type); add_affectation_processor_for_value(value_node.m_data_type, rhs_node_sub_data_type); } @@ -211,7 +267,7 @@ ASTNodeListAffectationExpressionBuilder::_buildListAffectationProcessor() ASTNode& name_list_node = *m_node.children[0]; if (name_list_node.children.size() != flattened_rhs_data_type_list.size()) { - throw ParseError("incompatible list sizes in affectation", std::vector{m_node.begin()}); + throw ParseError("incompatible list sizes in affectation", std::vector{m_node.children[0]->begin()}); } using ListAffectationProcessorT = ListAffectationProcessor<OperatorT>; diff --git a/src/language/ast/ASTNodeNaturalConversionChecker.cpp b/src/language/ast/ASTNodeNaturalConversionChecker.cpp deleted file mode 100644 index 0faabac4fa7d0cb197ba3b8f35ffcfadce3f162a..0000000000000000000000000000000000000000 --- a/src/language/ast/ASTNodeNaturalConversionChecker.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> - -#include <language/PEGGrammar.hpp> -#include <language/utils/ParseError.hpp> -#include <utils/Exceptions.hpp> - -void -ASTNodeNaturalConversionChecker::_checkIsNaturalTypeConversion(const ASTNode& node, - const ASTNodeDataType& data_type, - const ASTNodeDataType& target_data_type) const -{ - if (not isNaturalConversion(data_type, target_data_type)) { - std::ostringstream error_message; - error_message << "invalid implicit conversion: "; - error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type) - << rang::fg::reset; - - if ((data_type == ASTNodeDataType::undefined_t) or (target_data_type == ASTNodeDataType::undefined_t)) { - throw UnexpectedError(error_message.str()); - } else { - throw ParseError(error_message.str(), node.begin()); - } - } -} - -void -ASTNodeNaturalConversionChecker::_checkIsNaturalExpressionConversion(const ASTNode& node, - const ASTNodeDataType& data_type, - const ASTNodeDataType& target_data_type) const -{ - if (target_data_type == ASTNodeDataType::vector_t) { - switch (node.m_data_type) { - case ASTNodeDataType::list_t: { - if (node.children.size() != target_data_type.dimension()) { - throw ParseError("incompatible dimensions in affectation", std::vector{node.begin()}); - } - for (const auto& child : node.children) { - this->_checkIsNaturalExpressionConversion(*child, child->m_data_type, ASTNodeDataType::double_t); - } - - break; - } - case ASTNodeDataType::vector_t: { - if (data_type.dimension() != target_data_type.dimension()) { - throw ParseError("incompatible dimensions in affectation", std::vector{node.begin()}); - } - break; - } - case ASTNodeDataType::int_t: { - if (node.is_type<language::integer>()) { - if (std::stoi(node.string()) == 0) { - break; - } - } - [[fallthrough]]; - } - default: { - this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); - } - } - } else if (target_data_type == ASTNodeDataType::tuple_t) { - const ASTNodeDataType& target_content_type = target_data_type.contentType(); - if (node.m_data_type == ASTNodeDataType::tuple_t) { - this->_checkIsNaturalExpressionConversion(node, data_type.contentType(), target_content_type); - } else if (node.m_data_type == ASTNodeDataType::list_t) { - if ((target_data_type.contentType() == ASTNodeDataType::vector_t) and - (target_data_type.contentType().dimension() == 1)) { - for (const auto& child : node.children) { - if (not isNaturalConversion(child->m_data_type, target_data_type)) { - this->_checkIsNaturalExpressionConversion(*child, child->m_data_type, ASTNodeDataType::double_t); - } - } - } else { - for (const auto& child : node.children) { - this->_checkIsNaturalExpressionConversion(*child, child->m_data_type, target_content_type); - } - } - } else { - if ((target_data_type.contentType() == ASTNodeDataType::vector_t) and - (target_data_type.contentType().dimension() == 1)) { - if (not isNaturalConversion(data_type, target_data_type)) { - this->_checkIsNaturalExpressionConversion(node, data_type, ASTNodeDataType::double_t); - } - } else { - this->_checkIsNaturalExpressionConversion(node, data_type, target_content_type); - } - } - } else { - this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); - } -} - -ASTNodeNaturalConversionChecker::ASTNodeNaturalConversionChecker(const ASTNode& data_node, - const ASTNodeDataType& target_data_type) -{ - this->_checkIsNaturalExpressionConversion(data_node, data_node.m_data_type, target_data_type); -} - -ASTNodeNaturalConversionChecker::ASTNodeNaturalConversionChecker(const ASTNodeSubDataType& data_node_sub_data_type, - const ASTNodeDataType& target_data_type) -{ - this->_checkIsNaturalExpressionConversion(data_node_sub_data_type.m_parent_node, data_node_sub_data_type.m_data_type, - target_data_type); -} diff --git a/src/language/ast/ASTNodeSubDataType.hpp b/src/language/ast/ASTNodeSubDataType.hpp index d7844d367ae686fe88e05bc119d198288eca45c6..358aca9a1419309796b9fcb9baf8eb9825e7a4eb 100644 --- a/src/language/ast/ASTNodeSubDataType.hpp +++ b/src/language/ast/ASTNodeSubDataType.hpp @@ -2,7 +2,7 @@ #define AST_NODE_SUB_DATA_TYPE_HPP #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> struct ASTNodeSubDataType { diff --git a/src/language/ast/ASTNodeUnaryOperatorExpressionBuilder.cpp b/src/language/ast/ASTNodeUnaryOperatorExpressionBuilder.cpp index 7d94c3f58fe1762b582a0d984edc797b0fafd021..e011e693000d5b68de6fdd0ea98b4e7afe01a6f6 100644 --- a/src/language/ast/ASTNodeUnaryOperatorExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeUnaryOperatorExpressionBuilder.cpp @@ -1,100 +1,35 @@ #include <language/ast/ASTNodeUnaryOperatorExpressionBuilder.hpp> #include <language/PEGGrammar.hpp> -#include <language/node_processor/UnaryExpressionProcessor.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/ParseError.hpp> +#include <language/utils/UnaryOperatorMangler.hpp> ASTNodeUnaryOperatorExpressionBuilder::ASTNodeUnaryOperatorExpressionBuilder(ASTNode& n) { - auto set_unary_operator_processor = [](ASTNode& n, const auto& operator_v) { - using OperatorT = std::decay_t<decltype(operator_v)>; + const ASTNodeDataType& data_type = n.children[0]->m_data_type; - auto set_unary_operator_processor_for_data = [&](const auto& value, const ASTNodeDataType& data_type) { - using ValueT = std::decay_t<decltype(value)>; - switch (data_type) { - case ASTNodeDataType::bool_t: { - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, bool>>(n); - break; - } - case ASTNodeDataType::unsigned_int_t: { - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, uint64_t>>(n); - break; - } - case ASTNodeDataType::int_t: { - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, int64_t>>(n); - break; - } - case ASTNodeDataType::double_t: { - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, double>>(n); - break; - } - default: { - throw ParseError("unexpected error: invalid operand type for unary operator", - std::vector{n.children[0]->begin()}); - } - } - }; + const std::string unary_operator_name = [&] { + if (n.is_type<language::unary_minus>()) { + return unaryOperatorMangler<language::unary_minus>(data_type); + } else if (n.is_type<language::unary_not>()) { + return unaryOperatorMangler<language::unary_not>(data_type); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: undefined unary operator", std::vector{n.begin()}); + // LCOV_EXCL_STOP + } + }(); - auto set_unary_operator_processor_for_value = [&](const ASTNodeDataType& value_type) { - const ASTNodeDataType data_type = n.children[0]->m_data_type; - switch (value_type) { - case ASTNodeDataType::bool_t: { - set_unary_operator_processor_for_data(bool{}, data_type); - break; - } - case ASTNodeDataType::int_t: { - set_unary_operator_processor_for_data(int64_t{}, data_type); - break; - } - case ASTNodeDataType::double_t: { - set_unary_operator_processor_for_data(double{}, data_type); - break; - } - case ASTNodeDataType::vector_t: { - if constexpr (std::is_same_v<OperatorT, language::unary_minus>) { - switch (data_type.dimension()) { - case 1: { - using ValueT = TinyVector<1>; - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, ValueT>>(n); - break; - } - case 2: { - using ValueT = TinyVector<2>; - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, ValueT>>(n); - break; - } - case 3: { - using ValueT = TinyVector<3>; - n.m_node_processor = std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, ValueT>>(n); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension", std::vector{n.begin()}); - } - // LCOV_EXCL_STOP - } - } else { - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid unary operator for vector data", std::vector{n.begin()}); - // LCOV_EXCL_STOP - } - break; - } - default: { - throw ParseError("undefined value type for unary operator", std::vector{n.begin()}); - } - } - }; + const auto& optional_processor_builder = OperatorRepository::instance().getUnaryProcessorBuilder(unary_operator_name); - set_unary_operator_processor_for_value(n.m_data_type); - }; - - if (n.is_type<language::unary_minus>()) { - set_unary_operator_processor(n, language::unary_minus{}); - } else if (n.is_type<language::unary_not>()) { - set_unary_operator_processor(n, language::unary_not{}); + if (optional_processor_builder.has_value()) { + n.m_node_processor = optional_processor_builder.value()->getNodeProcessor(n); } else { - throw ParseError("unexpected error: undefined unary operator", std::vector{n.begin()}); + std::ostringstream error_message; + error_message << "undefined unary operator type: "; + error_message << rang::fgB::red << unary_operator_name << rang::fg::reset; + + throw ParseError(error_message.str(), std::vector{n.children[0]->begin()}); } } diff --git a/src/language/ast/ASTSymbolInitializationChecker.cpp b/src/language/ast/ASTSymbolInitializationChecker.cpp index 3a8c846a3626a3acb815a401767d898b378f546c..c37d4ba6bbe45563a71c876c5d1b70e992762729 100644 --- a/src/language/ast/ASTSymbolInitializationChecker.cpp +++ b/src/language/ast/ASTSymbolInitializationChecker.cpp @@ -22,7 +22,7 @@ ASTSymbolInitializationChecker::_checkSymbolInitialization(ASTNode& node) Assert(def_name_node.is_type<language::name>()); if (decl_name_node.string() != def_name_node.string()) { std::ostringstream os; - os << "invalid identifier, expecting '" << decl_name_node.string() << "'" << std::ends; + os << "invalid identifier, expecting '" << decl_name_node.string() << '\''; throw ParseError(os.str(), std::vector{def_name_node.begin()}); } }; @@ -37,22 +37,19 @@ ASTSymbolInitializationChecker::_checkSymbolInitialization(ASTNode& node) if (node.children.size() == 4) { ASTNode& decl_name_list_node = *node.children[0]; ASTNode& def_name_list_node = *node.children[2]; - Assert(def_name_list_node.is_type<language::name_list>()); - ASTNode& expression_list_node = *node.children[3]; - Assert(expression_list_node.is_type<language::expression_list>()); + + if (not def_name_list_node.is_type<language::name_list>()) { + throw ParseError("expecting a list of identifiers", std::vector{def_name_list_node.begin()}); + } if (decl_name_list_node.children.size() != def_name_list_node.children.size()) { std::ostringstream os; os << "invalid number of definition identifiers, expecting " << decl_name_list_node.children.size() - << " found " << def_name_list_node.children.size() << std::ends; + << " found " << def_name_list_node.children.size(); throw ParseError(os.str(), std::vector{def_name_list_node.begin()}); } - if (def_name_list_node.children.size() != expression_list_node.children.size()) { - std::ostringstream os; - os << "invalid number of definition expressions, expecting " << decl_name_list_node.children.size() - << " found " << expression_list_node.children.size() << std::ends; - throw ParseError(os.str(), std::vector{expression_list_node.begin()}); - } + + ASTNode& expression_list_node = *node.children[3]; this->_checkSymbolInitialization(expression_list_node); for (size_t i = 0; i < decl_name_list_node.children.size(); ++i) { diff --git a/src/language/ast/CMakeLists.txt b/src/language/ast/CMakeLists.txt index f4c31774302ccff20b3d24f686fd7e381ff43822..a901d7c09679c37d54420f51d347537fe5e18aef 100644 --- a/src/language/ast/CMakeLists.txt +++ b/src/language/ast/CMakeLists.txt @@ -9,7 +9,6 @@ add_library(PugsLanguageAST ASTNodeBuiltinFunctionExpressionBuilder.cpp ASTNodeDataTypeBuilder.cpp ASTNodeDataTypeChecker.cpp - ASTNodeDataType.cpp ASTNodeDataTypeFlattener.cpp ASTNodeDeclarationToAffectationConverter.cpp ASTNodeEmptyBlockCleaner.cpp @@ -19,7 +18,6 @@ add_library(PugsLanguageAST ASTNodeIncDecExpressionBuilder.cpp ASTNodeJumpPlacementChecker.cpp ASTNodeListAffectationExpressionBuilder.cpp - ASTNodeNaturalConversionChecker.cpp ASTNodeUnaryOperatorExpressionBuilder.cpp ASTSymbolInitializationChecker.cpp ASTSymbolTableBuilder.cpp diff --git a/src/language/modules/BuiltinModule.cpp b/src/language/modules/BuiltinModule.cpp index 0ab7911394a9ebb81454c3381a380ea4aca9c03e..22d23f8033604698f44ea18e42c63e9ba48ba345 100644 --- a/src/language/modules/BuiltinModule.cpp +++ b/src/language/modules/BuiltinModule.cpp @@ -6,6 +6,8 @@ #include <memory> +BuiltinModule::BuiltinModule(bool is_mandatory) : m_is_mandatory{is_mandatory} {} + void BuiltinModule::_addBuiltinFunction(const std::string& name, std::shared_ptr<IBuiltinFunctionEmbedder> builtin_function_embedder) diff --git a/src/language/modules/BuiltinModule.hpp b/src/language/modules/BuiltinModule.hpp index bf8743c9e5a9913dd6939e3d6cbe18495bccc27e..2be1d162578fb0d111e6eac7ae93e10dae981bce 100644 --- a/src/language/modules/BuiltinModule.hpp +++ b/src/language/modules/BuiltinModule.hpp @@ -1,8 +1,8 @@ #ifndef BUILTIN_MODULE_HPP #define BUILTIN_MODULE_HPP -#include <language/ast/ASTNodeDataType.hpp> #include <language/modules/IModule.hpp> +#include <language/utils/ASTNodeDataType.hpp> class IBuiltinFunctionEmbedder; class TypeDescriptor; @@ -18,7 +18,15 @@ class BuiltinModule : public IModule void _addTypeDescriptor(const ASTNodeDataType& type); + const bool m_is_mandatory; + public: + bool + isMandatory() const final + { + return m_is_mandatory; + } + const NameBuiltinFunctionMap& getNameBuiltinFunctionMap() const final { @@ -31,7 +39,7 @@ class BuiltinModule : public IModule return m_name_type_map; } - BuiltinModule() = default; + BuiltinModule(bool is_mandatory = false); ~BuiltinModule() = default; }; diff --git a/src/language/modules/CMakeLists.txt b/src/language/modules/CMakeLists.txt index f8cf2f0b445882a672d54e60dbc75858bbb246c7..020727478c5e64bac9ec0a780a0cfb80541fb28a 100644 --- a/src/language/modules/CMakeLists.txt +++ b/src/language/modules/CMakeLists.txt @@ -2,10 +2,13 @@ add_library(PugsLanguageModules BuiltinModule.cpp + CoreModule.cpp + LinearSolverModule.cpp MathModule.cpp MeshModule.cpp ModuleRepository.cpp SchemeModule.cpp + UtilsModule.cpp VTKModule.cpp ) diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..14a9548238c91682aac53850a6cdf6f36b4d9e93 --- /dev/null +++ b/src/language/modules/CoreModule.cpp @@ -0,0 +1,44 @@ +#include <language/modules/CoreModule.hpp> + +#include <language/modules/CoreModule.hpp> +#include <language/modules/ModuleRepository.hpp> +#include <language/utils/ASTExecutionInfo.hpp> +#include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <utils/PugsUtils.hpp> + +CoreModule::CoreModule() : BuiltinModule(true) +{ + this->_addBuiltinFunction("getPugsVersion", std::make_shared<BuiltinFunctionEmbedder<std::string(void)>>( + + []() -> std::string { return pugsVersion(); } + + )); + + this->_addBuiltinFunction("getPugsBuildInfo", std::make_shared<BuiltinFunctionEmbedder<std::string(void)>>( + + []() -> std::string { return pugsBuildInfo(); } + + )); + + this->_addBuiltinFunction("getAvailableModules", std::make_shared<BuiltinFunctionEmbedder<std::string()>>( + + []() -> std::string { + const ModuleRepository& repository = + ASTExecutionInfo::current().moduleRepository(); + + return repository.getAvailableModules(); + } + + )); + + this->_addBuiltinFunction("getModuleInfo", std::make_shared<BuiltinFunctionEmbedder<std::string(const std::string&)>>( + + [](const std::string& module_name) -> std::string { + const ModuleRepository& repository = + ASTExecutionInfo::current().moduleRepository(); + + return repository.getModuleInfo(module_name); + } + + )); +} diff --git a/src/language/modules/CoreModule.hpp b/src/language/modules/CoreModule.hpp new file mode 100644 index 0000000000000000000000000000000000000000..963719be2bce1b841fe9cee05db5b957ad53ba2b --- /dev/null +++ b/src/language/modules/CoreModule.hpp @@ -0,0 +1,19 @@ +#ifndef CORE_MODULE_HPP +#define CORE_MODULE_HPP + +#include <language/modules/BuiltinModule.hpp> + +class CoreModule : public BuiltinModule +{ + public: + std::string_view + name() const final + { + return "core"; + } + + CoreModule(); + ~CoreModule() = default; +}; + +#endif // CORE_MODULE_HPP diff --git a/src/language/modules/IModule.hpp b/src/language/modules/IModule.hpp index b839ed496d8d032f660a6d21c484c7ca342c5f98..48e4b15360de922a83025921d1058095cc205d89 100644 --- a/src/language/modules/IModule.hpp +++ b/src/language/modules/IModule.hpp @@ -19,6 +19,8 @@ class IModule IModule(IModule&&) = default; IModule& operator=(IModule&&) = default; + virtual bool isMandatory() const = 0; + virtual const NameBuiltinFunctionMap& getNameBuiltinFunctionMap() const = 0; virtual const NameTypeMap& getNameTypeMap() const = 0; diff --git a/src/language/modules/LinearSolverModule.cpp b/src/language/modules/LinearSolverModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5bb2cf90058b9232bd84e2f2bbdd71eb02ef3c7 --- /dev/null +++ b/src/language/modules/LinearSolverModule.cpp @@ -0,0 +1,92 @@ +#include <language/modules/LinearSolverModule.hpp> + +#include <algebra/LinearSolver.hpp> +#include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/TypeDescriptor.hpp> + +LinearSolverModule::LinearSolverModule() +{ + this->_addBuiltinFunction("setLSVerbosity", std::make_shared<BuiltinFunctionEmbedder<void(const bool&)>>( + + [](const bool& verbose) -> void { + LinearSolverOptions::default_options.verbose() = verbose; + } + + )); + + this->_addBuiltinFunction("setLSEpsilon", std::make_shared<BuiltinFunctionEmbedder<void(const double&)>>( + + [](const double& epsilon) -> void { + LinearSolverOptions::default_options.epsilon() = epsilon; + } + + )); + + this->_addBuiltinFunction("setLSMaxIter", std::make_shared<BuiltinFunctionEmbedder<void(const uint64_t&)>>( + + [](const uint64_t& max_iter) -> void { + LinearSolverOptions::default_options.maximumIteration() = max_iter; + } + + )); + + this->_addBuiltinFunction("setLSLibrary", std::make_shared<BuiltinFunctionEmbedder<void(const std::string&)>>( + + [](const std::string& library_name) -> void { + LinearSolverOptions::default_options.library() = + getLSEnumFromName<LSLibrary>(library_name); + } + + )); + + this->_addBuiltinFunction("setLSMethod", std::make_shared<BuiltinFunctionEmbedder<void(const std::string&)>>( + + [](const std::string& method_name) -> void { + LinearSolverOptions::default_options.method() = + getLSEnumFromName<LSMethod>(method_name); + } + + )); + + this->_addBuiltinFunction("setLSPrecond", std::make_shared<BuiltinFunctionEmbedder<void(const std::string&)>>( + + [](const std::string& precond_name) -> void { + LinearSolverOptions::default_options.precond() = + getLSEnumFromName<LSPrecond>(precond_name); + } + + )); + + this->_addBuiltinFunction("getLSOptions", std::make_shared<BuiltinFunctionEmbedder<std::string()>>( + + []() -> std::string { + std::ostringstream os; + os << rang::fgB::yellow << "Linear solver options" << rang::style::reset + << '\n'; + os << LinearSolverOptions::default_options; + return os.str(); + } + + )); + + this->_addBuiltinFunction("getLSAvailable", std::make_shared<BuiltinFunctionEmbedder<std::string()>>( + + []() -> std::string { + std::ostringstream os; + + os << rang::fgB::yellow << "Available linear solver options" + << rang::style::reset << '\n'; + os << rang::fgB::blue << " libraries" << rang::style::reset << '\n'; + + printLSEnumListNames<LSLibrary>(os); + os << rang::fgB::blue << " methods" << rang::style::reset << '\n'; + printLSEnumListNames<LSMethod>(os); + os << rang::fgB::blue << " preconditioners" << rang::style::reset + << '\n'; + printLSEnumListNames<LSPrecond>(os); + + return os.str(); + } + + )); +} diff --git a/src/language/modules/LinearSolverModule.hpp b/src/language/modules/LinearSolverModule.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5dc7ae64efd3ed43f554e821693f761b0a66d412 --- /dev/null +++ b/src/language/modules/LinearSolverModule.hpp @@ -0,0 +1,19 @@ +#ifndef LINEAR_SOLVER_MODULE_HPP +#define LINEAR_SOLVER_MODULE_HPP + +#include <language/modules/BuiltinModule.hpp> + +class LinearSolverModule : public BuiltinModule +{ + public: + std::string_view + name() const final + { + return "linear_solver"; + } + + LinearSolverModule(); + ~LinearSolverModule() = default; +}; + +#endif // LINEAR_SOLVER_MODULE_HPP diff --git a/src/language/modules/MeshModule.hpp b/src/language/modules/MeshModule.hpp index f2c3d3d6360e5fd378d128d3b6b57e0de5a04509..ebb1383107ffd288001a48a5abced722ef37a4bd 100644 --- a/src/language/modules/MeshModule.hpp +++ b/src/language/modules/MeshModule.hpp @@ -8,7 +8,8 @@ class IMesh; template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IMesh>> = {ASTNodeDataType::type_id_t, "mesh"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IMesh>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("mesh"); class MeshModule : public BuiltinModule { diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index aa5c1fce3ee96c5318d930d75f2047df0d209735..65bfcafd1f1c85f077a97f6e932fa97da0079537 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -1,15 +1,22 @@ #include <language/modules/ModuleRepository.hpp> #include <language/ast/ASTNode.hpp> +#include <language/modules/CoreModule.hpp> +#include <language/modules/LinearSolverModule.hpp> #include <language/modules/MathModule.hpp> #include <language/modules/MeshModule.hpp> #include <language/modules/SchemeModule.hpp> +#include <language/modules/UtilsModule.hpp> #include <language/modules/VTKModule.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/ParseError.hpp> #include <language/utils/SymbolTable.hpp> +#include <language/utils/TypeDescriptor.hpp> #include <utils/PugsAssert.hpp> +#include <algorithm> + void ModuleRepository::_subscribe(std::unique_ptr<IModule> m) { @@ -19,30 +26,32 @@ ModuleRepository::_subscribe(std::unique_ptr<IModule> m) ModuleRepository::ModuleRepository() { + this->_subscribe(std::make_unique<CoreModule>()); + this->_subscribe(std::make_unique<LinearSolverModule>()); this->_subscribe(std::make_unique<MathModule>()); this->_subscribe(std::make_unique<MeshModule>()); - this->_subscribe(std::make_unique<VTKModule>()); this->_subscribe(std::make_unique<SchemeModule>()); + this->_subscribe(std::make_unique<UtilsModule>()); + this->_subscribe(std::make_unique<VTKModule>()); } template <typename NameEmbedderMapT, typename EmbedderTableT> void -ModuleRepository::_populateEmbedderTableT(const ASTNode& module_name_node, +ModuleRepository::_populateEmbedderTableT(const ASTNode& module_node, + const std::string& module_name, const NameEmbedderMapT& name_embedder_map, const ASTNodeDataType& data_type, SymbolTable& symbol_table, EmbedderTableT& embedder_table) { - const std::string& module_name = module_name_node.string(); - for (auto [symbol_name, embedded] : name_embedder_map) { - auto [i_symbol, success] = symbol_table.add(symbol_name, module_name_node.begin()); + auto [i_symbol, success] = symbol_table.add(symbol_name, module_node.begin()); if (not success) { std::ostringstream error_message; error_message << "importing module '" << module_name << "', cannot add symbol '" << symbol_name << "', it is already defined!"; - throw ParseError(error_message.str(), module_name_node.begin()); + throw ParseError(error_message.str(), module_node.begin()); } i_symbol->attributes().setDataType(data_type); @@ -62,13 +71,88 @@ ModuleRepository::populateSymbolTable(const ASTNode& module_name_node, SymbolTab if (i_module != m_module_set.end()) { const IModule& populating_module = *i_module->second; - this->_populateEmbedderTableT(module_name_node, populating_module.getNameBuiltinFunctionMap(), - ASTNodeDataType::builtin_function_t, symbol_table, + this->_populateEmbedderTableT(module_name_node, module_name, populating_module.getNameBuiltinFunctionMap(), + ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(), symbol_table, symbol_table.builtinFunctionEmbedderTable()); - this->_populateEmbedderTableT(module_name_node, populating_module.getNameTypeMap(), ASTNodeDataType::type_name_id_t, - symbol_table, symbol_table.typeEmbedderTable()); + this->_populateEmbedderTableT(module_name_node, module_name, populating_module.getNameTypeMap(), + ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>(), symbol_table, + symbol_table.typeEmbedderTable()); + + for (auto [symbol_name, embedded] : populating_module.getNameTypeMap()) { + BasicAffectationRegisterFor<EmbeddedData>(ASTNodeDataType::build<ASTNodeDataType::type_id_t>(symbol_name)); + } + } else { throw ParseError(std::string{"could not find module "} + module_name, std::vector{module_name_node.begin()}); } } + +void +ModuleRepository::populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table) +{ + for (auto&& [module_name, i_module] : m_module_set) { + if (i_module->isMandatory()) { + this->_populateEmbedderTableT(root_node, module_name, i_module->getNameBuiltinFunctionMap(), + ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(), symbol_table, + symbol_table.builtinFunctionEmbedderTable()); + + this->_populateEmbedderTableT(root_node, module_name, i_module->getNameTypeMap(), + ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>(), symbol_table, + symbol_table.typeEmbedderTable()); + } + } +} + +std::string +ModuleRepository::getAvailableModules() const +{ + std::stringstream os; + os << rang::fgB::yellow << "Available modules" << rang::fg::blue << " [modules tagged with a " << rang::style::reset + << rang::style::bold << '*' << rang::style::reset << rang::fg::blue << " are automatically imported]" + << rang::style::reset << '\n'; + for (auto& [name, i_module] : m_module_set) { + if (i_module->isMandatory()) { + os << rang::style::bold << " *" << rang::style::reset; + } else { + os << " "; + } + os << rang::fgB::green << name << rang::style::reset << '\n'; + } + + return os.str(); +} + +std::string +ModuleRepository::getModuleInfo(const std::string& module_name) const +{ + std::stringstream os; + auto i_module = m_module_set.find(module_name); + if (i_module != m_module_set.end()) { + os << rang::fgB::yellow << "Module '" << rang::fgB::blue << module_name << rang::fgB::yellow << "' provides" + << rang::style::reset << '\n'; + const auto& builtin_function_map = i_module->second->getNameBuiltinFunctionMap(); + if (builtin_function_map.size() > 0) { + os << " functions\n"; + for (auto& [name, function] : builtin_function_map) { + os << " " << rang::fgB::green << name << rang::style::reset << ": "; + os << dataTypeName(function->getParameterDataTypes()); + os << rang::fgB::yellow << " -> " << rang::style::reset; + os << dataTypeName(function->getReturnDataType()) << '\n'; + } + } + + const auto& builtin_type_map = i_module->second->getNameTypeMap(); + if (builtin_type_map.size() > 0) { + os << " types\n"; + for (auto& [name, descriptor] : builtin_type_map) { + os << " " << rang::fgB::green << name << rang::style::reset << '\n'; + } + } + + } else { + throw NormalError(std::string{"could not find module "} + module_name); + } + + return os.str(); +} diff --git a/src/language/modules/ModuleRepository.hpp b/src/language/modules/ModuleRepository.hpp index f928a1552563612ed77434e4295a2d31fb2969c1..c224f6491521ad9757ad5253011604e863a10503 100644 --- a/src/language/modules/ModuleRepository.hpp +++ b/src/language/modules/ModuleRepository.hpp @@ -19,7 +19,8 @@ class ModuleRepository void _subscribe(std::unique_ptr<IModule> a); template <typename NameEmbedderMapT, typename EmbedderTableT> - void _populateEmbedderTableT(const ASTNode& module_name_node, + void _populateEmbedderTableT(const ASTNode& module_node, + const std::string& module_name, const NameEmbedderMapT& name_embedder_map, const ASTNodeDataType& data_type, SymbolTable& symbol_table, @@ -27,6 +28,10 @@ class ModuleRepository public: void populateSymbolTable(const ASTNode& module_name_node, SymbolTable& symbol_table); + void populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table); + + std::string getAvailableModules() const; + std::string getModuleInfo(const std::string& module_name) const; const ModuleRepository& operator=(const ModuleRepository&) = delete; const ModuleRepository& operator=(ModuleRepository&&) = delete; diff --git a/src/language/modules/SchemeModule.hpp b/src/language/modules/SchemeModule.hpp index b140295b486a33128295fb9829d2c6d540277948..963ed4af02a3c8e7895c89f705d15bac1ab8d521 100644 --- a/src/language/modules/SchemeModule.hpp +++ b/src/language/modules/SchemeModule.hpp @@ -9,12 +9,12 @@ class IBoundaryDescriptor; template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>> = - {ASTNodeDataType::type_id_t, "boundary"}; + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary"); class IBoundaryConditionDescriptor; template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryConditionDescriptor>> = - {ASTNodeDataType::type_id_t, "boundary_condition"}; + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary_condition"); template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const BasisType>> = {ASTNodeDataType::type_id_t, diff --git a/src/language/modules/UtilsModule.cpp b/src/language/modules/UtilsModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b2271158c1548d05b39c379782abb3cc4754d66 --- /dev/null +++ b/src/language/modules/UtilsModule.cpp @@ -0,0 +1,69 @@ +#include <language/modules/UtilsModule.hpp> + +#include <language/utils/ASTDotPrinter.hpp> +#include <language/utils/ASTExecutionInfo.hpp> +#include <language/utils/ASTPrinter.hpp> +#include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <fstream> + +UtilsModule::UtilsModule() +{ + this->_addBuiltinFunction("getAST", std::make_shared<BuiltinFunctionEmbedder<std::string(void)>>( + + []() -> std::string { + const auto& root_node = ASTExecutionInfo::current().rootNode(); + + std::ostringstream os; + os << ASTPrinter{root_node}; + + return os.str(); + } + + )); + + this->_addBuiltinFunction("saveASTDot", std::make_shared<BuiltinFunctionEmbedder<void(const std::string&)>>( + + [](const std::string& dot_filename) -> void { + const auto& root_node = ASTExecutionInfo::current().rootNode(); + + std::ofstream fout(dot_filename); + + if (not fout) { + std::ostringstream os; + os << "could not create file '" << dot_filename << "'\n"; + throw NormalError(os.str()); + } + + ASTDotPrinter dot_printer{root_node}; + fout << dot_printer; + + if (not fout) { + std::ostringstream os; + os << "could not write AST to '" << dot_filename << "'\n"; + throw NormalError(os.str()); + } + } + + )); + + this->_addBuiltinFunction("getFunctionAST", + std::make_shared<BuiltinFunctionEmbedder<std::string(const FunctionSymbolId&)>>( + + [](const FunctionSymbolId& function_symbol_id) -> std::string { + auto& function_table = function_symbol_id.symbolTable().functionTable(); + + const auto& function_descriptor = function_table[function_symbol_id.id()]; + + std::ostringstream os; + os << function_descriptor.name() << ": domain mapping\n"; + os << ASTPrinter(function_descriptor.domainMappingNode()); + os << function_descriptor.name() << ": definition\n"; + os << ASTPrinter(function_descriptor.definitionNode()); + + return os.str(); + } + + )); +} diff --git a/src/language/modules/UtilsModule.hpp b/src/language/modules/UtilsModule.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f7580a0422d3d8dabbd8c931b99dbfc7601f84da --- /dev/null +++ b/src/language/modules/UtilsModule.hpp @@ -0,0 +1,19 @@ +#ifndef UTILS_MODULE_HPP +#define UTILS_MODULE_HPP + +#include <language/modules/BuiltinModule.hpp> + +class UtilsModule : public BuiltinModule +{ + public: + std::string_view + name() const final + { + return "utils"; + } + + UtilsModule(); + ~UtilsModule() = default; +}; + +#endif // UTILS_MODULE_HPP diff --git a/src/language/node_processor/ASTNodeExpressionListProcessor.hpp b/src/language/node_processor/ASTNodeExpressionListProcessor.hpp index 7ed93683d1b6bcfd54e624385429bce4319976d5..713eda89aec0804efb4bcd323cb9a65c10b8967d 100644 --- a/src/language/node_processor/ASTNodeExpressionListProcessor.hpp +++ b/src/language/node_processor/ASTNodeExpressionListProcessor.hpp @@ -42,7 +42,7 @@ class ASTNodeExpressionListProcessor final : public INodeProcessor for (auto& child : m_node.children) { if (child->is_type<language::function_evaluation>()) { - if (child->m_data_type != ASTNodeDataType::typename_t) { + if (child->m_data_type != ASTNodeDataType::list_t) { ++number_of_values; } else { ASTNode& function_name_node = *child->children[0]; diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp index 6898bb2c2b25e4fe44a8f6e9ad271930086f0a81..959578619b023bc28236b1ae52141b018aad8c69 100644 --- a/src/language/node_processor/AffectationProcessor.hpp +++ b/src/language/node_processor/AffectationProcessor.hpp @@ -105,7 +105,7 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs = std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; - os << std::get<DataT>(rhs) << std::ends; + os << std::get<DataT>(rhs); m_lhs = os.str(); } } else { @@ -115,7 +115,7 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs += std::to_string(std::get<DataT>(rhs)); } else { std::ostringstream os; - os << std::get<DataT>(rhs) << std::ends; + os << std::get<DataT>(rhs); m_lhs += os.str(); } } @@ -125,18 +125,40 @@ class AffectationExecutor final : public IAffectationExecutor m_lhs = std::get<DataT>(rhs); } else if constexpr (std::is_same_v<DataT, AggregateDataVariant>) { const AggregateDataVariant& v = std::get<AggregateDataVariant>(rhs); - static_assert(is_tiny_vector_v<ValueT>, "expecting lhs TinyVector"); - for (size_t i = 0; i < m_lhs.dimension(); ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - m_lhs[i] = vi; - } else { - throw UnexpectedError("unexpected rhs type in affectation"); - } - }, - v[i]); + if constexpr (is_tiny_vector_v<ValueT>) { + for (size_t i = 0; i < m_lhs.dimension(); ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + m_lhs[i] = vi; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[i]); + } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + for (size_t i = 0, l = 0; i < m_lhs.nbRows(); ++i) { + for (size_t j = 0; j < m_lhs.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Aij_T, double>) { + m_lhs(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + } else { + static_assert(is_tiny_matrix_v<ValueT> or is_tiny_vector_v<ValueT>, "invalid rhs type"); } } else if constexpr (std::is_same_v<TinyVector<1>, ValueT>) { std::visit( @@ -145,10 +167,27 @@ class AffectationExecutor final : public IAffectationExecutor if constexpr (std::is_convertible_v<Vi_T, double>) { m_lhs = v; } else { + // LCOV_EXCL_START throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP } }, rhs); + } else if constexpr (std::is_same_v<TinyMatrix<1>, ValueT>) { + std::visit( + [&](auto&& v) { + using Vi_T = std::decay_t<decltype(v)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + m_lhs = v; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected rhs type in affectation"); + // LCOV_EXCL_STOP + } + }, + rhs); + } else { + throw UnexpectedError("invalid value type"); } } else { AffOp<OperatorT>().eval(m_lhs, std::get<DataT>(rhs)); @@ -164,7 +203,93 @@ class AffectationExecutor final : public IAffectationExecutor }; template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT> -class ComponentAffectationExecutor final : public IAffectationExecutor +class MatrixComponentAffectationExecutor final : public IAffectationExecutor +{ + private: + ArrayT& m_lhs_array; + ASTNode& m_index0_expression; + ASTNode& m_index1_expression; + + static inline const bool m_is_defined{[] { + if constexpr (not std::is_same_v<typename ArrayT::data_type, ValueT>) { + return false; + } else if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) { + if constexpr (not std::is_same_v<OperatorT, language::eq_op>) { + return false; + } + } + return true; + }()}; + + public: + MatrixComponentAffectationExecutor(ASTNode& node, + ArrayT& lhs_array, + ASTNode& index0_expression, + ASTNode& index1_expression) + : m_lhs_array{lhs_array}, m_index0_expression{index0_expression}, m_index1_expression{index1_expression} + { + // LCOV_EXCL_START + if constexpr (not m_is_defined) { + throw ParseError("unexpected error: invalid operands to affectation expression", std::vector{node.begin()}); + } + // LCOV_EXCL_STOP + } + + PUGS_INLINE void + affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) + { + if constexpr (m_is_defined) { + auto get_index_value = [&](DataVariant&& value_variant) -> int64_t { + int64_t index_value = 0; + std::visit( + [&](auto&& value) { + using IndexValueT = std::decay_t<decltype(value)>; + if constexpr (std::is_integral_v<IndexValueT>) { + index_value = value; + } else { + // LCOV_EXCL_START + throw UnexpectedError("invalid index type"); + // LCOV_EXCL_STOP + } + }, + value_variant); + return index_value; + }; + + const int64_t index0_value = get_index_value(m_index0_expression.execute(exec_policy)); + const int64_t index1_value = get_index_value(m_index1_expression.execute(exec_policy)); + + if constexpr (std::is_same_v<ValueT, std::string>) { + if constexpr (std::is_same_v<OperatorT, language::eq_op>) { + if constexpr (std::is_same_v<std::string, DataT>) { + m_lhs_array(index0_value, index1_value) = std::get<DataT>(rhs); + } else { + m_lhs_array(index0_value, index1_value) = std::to_string(std::get<DataT>(rhs)); + } + } else { + if constexpr (std::is_same_v<std::string, DataT>) { + m_lhs_array(index0_value, index1_value) += std::get<std::string>(rhs); + } else { + m_lhs_array(index0_value, index1_value) += std::to_string(std::get<DataT>(rhs)); + } + } + } else { + if constexpr (std::is_same_v<OperatorT, language::eq_op>) { + if constexpr (std::is_same_v<ValueT, DataT>) { + m_lhs_array(index0_value, index1_value) = std::get<DataT>(rhs); + } else { + m_lhs_array(index0_value, index1_value) = static_cast<ValueT>(std::get<DataT>(rhs)); + } + } else { + AffOp<OperatorT>().eval(m_lhs_array(index0_value, index1_value), std::get<DataT>(rhs)); + } + } + } + } +}; + +template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT> +class VectorComponentAffectationExecutor final : public IAffectationExecutor { private: ArrayT& m_lhs_array; @@ -182,7 +307,7 @@ class ComponentAffectationExecutor final : public IAffectationExecutor }()}; public: - ComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression) + VectorComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression) : m_lhs_array{lhs_array}, m_index_expression{index_expression} { // LCOV_EXCL_START @@ -285,53 +410,99 @@ class AffectationProcessor final : public INodeProcessor Assert(found); DataVariant& value = i_symbol->attributes().value(); - // LCOV_EXCL_START - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { - throw ParseError("unexpected error: invalid lhs (expecting R^d)", - std::vector{array_subscript_expression.begin()}); - } - // LCOV_EXCL_STOP + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + Assert(array_subscript_expression.children.size() == 2); - auto& index_expression = *array_subscript_expression.children[1]; + auto& index_expression = *array_subscript_expression.children[1]; - switch (array_expression.m_data_type.dimension()) { - case 1: { - using ArrayTypeT = TinyVector<1>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + switch (array_expression.m_data_type.dimension()) { + case 1: { + using ArrayTypeT = TinyVector<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } - case 2: { - using ArrayTypeT = TinyVector<2>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 2: { + using ArrayTypeT = TinyVector<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } - case 3: { - using ArrayTypeT = TinyVector<3>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 3: { + using ArrayTypeT = TinyVector<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = + std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor = - std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression); - break; - } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_subscript_expression.children.size() == 3); + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + auto& index0_expression = *array_subscript_expression.children[1]; + auto& index1_expression = *array_subscript_expression.children[2]; + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + using ArrayTypeT = TinyMatrix<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + case 2: { + using ArrayTypeT = TinyMatrix<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + case 3: { + using ArrayTypeT = TinyMatrix<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor = std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), + index0_expression, index1_expression); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else { // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); - } + throw UnexpectedError("invalid subscript expression"); // LCOV_EXCL_STOP } - } else { // LCOV_EXCL_START throw ParseError("unexpected error: invalid lhs", std::vector{node.children[0]->begin()}); @@ -388,6 +559,55 @@ class AffectationToTinyVectorFromListProcessor final : public INodeProcessor }; template <typename OperatorT, typename ValueT> +class AffectationToTinyMatrixFromListProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + DataVariant* m_lhs; + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy)); + + static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator for list to vectors"); + + ValueT v; + for (size_t i = 0, l = 0; i < v.nbRows(); ++i) { + for (size_t j = 0; j < v.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& child_value) { + using T = std::decay_t<decltype(child_value)>; + if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or + std::is_same_v<T, double>) { + v(i, j) = child_value; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); + // LCOV_EXCL_STOP + } + }, + children_values[l]); + } + } + + *m_lhs = v; + return {}; + } + + AffectationToTinyMatrixFromListProcessor(ASTNode& node) : m_node{node} + { + const std::string& symbol = m_node.children[0]->string(); + auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin()); + Assert(found); + + m_lhs = &i_symbol->attributes().value(); + } +}; + +template <typename ValueT> class AffectationToTupleProcessor final : public INodeProcessor { private: @@ -399,7 +619,6 @@ class AffectationToTupleProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator to tuples"); DataVariant value = m_node.children[1]->execute(exec_policy); std::visit( @@ -414,11 +633,22 @@ class AffectationToTupleProcessor final : public INodeProcessor *m_lhs = std::vector{std::move(std::to_string(v))}; } else { std::ostringstream os; - os << v << std::ends; - *m_lhs = std::vector{os.str()}; + os << v; + *m_lhs = std::vector<std::string>{os.str()}; + } + } else if constexpr (is_tiny_vector_v<ValueT> or is_tiny_matrix_v<ValueT>) { + if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_arithmetic_v<T>) { + *m_lhs = std::vector<TinyVector<1>>{TinyVector<1>{static_cast<double>(v)}}; + } else if constexpr (std::is_same_v<ValueT, TinyMatrix<1>> and std::is_arithmetic_v<T>) { + *m_lhs = std::vector<TinyMatrix<1>>{TinyMatrix<1>{static_cast<double>(v)}}; + } else if constexpr (std::is_same_v<T, int64_t>) { + Assert(v == 0); + *m_lhs = std::vector<ValueT>{ValueT{zero}}; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); + // LCOV_EXCL_STOP } - } else if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_arithmetic_v<T>) { - *m_lhs = std::vector{TinyVector<1>{static_cast<double>(v)}}; } else { // LCOV_EXCL_START throw ParseError("unexpected error: unexpected right hand side type in affectation", m_node.begin()); @@ -440,7 +670,7 @@ class AffectationToTupleProcessor final : public INodeProcessor } }; -template <typename OperatorT, typename ValueT> +template <typename ValueT> class AffectationToTupleFromListProcessor final : public INodeProcessor { private: @@ -465,7 +695,7 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor tuple_value[i] = std::to_string(child_value); } else { std::ostringstream os; - os << child_value << std::ends; + os << child_value; tuple_value[i] = os.str(); } } else if constexpr (is_tiny_vector_v<ValueT>) { @@ -487,10 +717,49 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor }, child_value[j]); } - } else if constexpr (std::is_same_v<T, int64_t>) { - // in this case a 0 is given - Assert(child_value == 0); - tuple_value[i] = ZeroType{}; + } else if constexpr (std::is_arithmetic_v<T>) { + if constexpr (std::is_same_v<ValueT, TinyVector<1>>) { + tuple_value[i][0] = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } + } else if constexpr (is_tiny_matrix_v<ValueT>) { + if constexpr (std::is_same_v<T, AggregateDataVariant>) { + ValueT& A = tuple_value[i]; + Assert(A.nbRows() * A.nbColumns() == child_value.size()); + for (size_t j = 0, l = 0; j < A.nbRows(); ++j) { + for (size_t k = 0; k < A.nbColumns(); ++k, ++l) { + std::visit( + [&](auto&& Ajk) { + using Ti = std::decay_t<decltype(Ajk)>; + if constexpr (std::is_convertible_v<Ti, typename ValueT::data_type>) { + A(j, k) = Ajk; + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: unexpected right hand side type in affectation", + m_node.children[1]->children[i]->begin()); + // LCOV_EXCL_STOP + } + }, + child_value[l]); + } + } + } else if constexpr (std::is_arithmetic_v<T>) { + if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) { + tuple_value[i](0, 0) = child_value; + } else { + // in this case a 0 is given + Assert(child_value == 0); + tuple_value[i] = ZeroType{}; + } } else { // LCOV_EXCL_START throw ParseError("unexpected error: unexpected right hand side type in affectation", @@ -530,7 +799,7 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor } else { for (size_t i = 0; i < values.size(); ++i) { std::ostringstream sout; - sout << values[i] << std::ends; + sout << values[i]; v[i] = sout.str(); } } @@ -548,8 +817,6 @@ class AffectationToTupleFromListProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator for list to tuple"); - std::visit( [&](auto&& value_list) { using ValueListT = std::decay_t<decltype(value_list)>; @@ -643,51 +910,98 @@ class ListAffectationProcessor final : public INodeProcessor Assert(found); DataVariant& value = i_symbol->attributes().value(); - if (array_expression.m_data_type != ASTNodeDataType::vector_t) { - // LCOV_EXCL_START - throw ParseError("unexpected error: invalid lhs (expecting R^d)", - std::vector{array_subscript_expression.begin()}); - // LCOV_EXCL_STOP - } + if (array_expression.m_data_type == ASTNodeDataType::vector_t) { + Assert(array_subscript_expression.children.size() == 2); - auto& index_expression = *array_subscript_expression.children[1]; + auto& index_expression = *array_subscript_expression.children[1]; - switch (array_expression.m_data_type.dimension()) { - case 1: { - using ArrayTypeT = TinyVector<1>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + switch (array_expression.m_data_type.dimension()) { + case 1: { + using ArrayTypeT = TinyVector<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - case 2: { - using ArrayTypeT = TinyVector<2>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 2: { + using ArrayTypeT = TinyVector<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - case 3: { - using ArrayTypeT = TinyVector<3>; - if (not std::holds_alternative<ArrayTypeT>(value)) { - value = ArrayTypeT{}; + case 3: { + using ArrayTypeT = TinyVector<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP + } + } else if (array_expression.m_data_type == ASTNodeDataType::matrix_t) { + Assert(array_subscript_expression.children.size() == 3); + + auto& index0_expression = *array_subscript_expression.children[1]; + auto& index1_expression = *array_subscript_expression.children[2]; + + Assert(array_expression.m_data_type.nbRows() == array_expression.m_data_type.nbColumns()); + + switch (array_expression.m_data_type.nbRows()) { + case 1: { + using ArrayTypeT = TinyMatrix<1>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + case 2: { + using ArrayTypeT = TinyMatrix<2>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + case 3: { + using ArrayTypeT = TinyMatrix<3>; + if (not std::holds_alternative<ArrayTypeT>(value)) { + value = ArrayTypeT{}; + } + using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; + m_affectation_executor_list.emplace_back( + std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression, + index1_expression)); + break; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid vector dimension", + std::vector{array_subscript_expression.begin()}); + } + // LCOV_EXCL_STOP } - using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>; - m_affectation_executor_list.emplace_back( - std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression)); - break; - } - // LCOV_EXCL_START - default: { - throw ParseError("unexpected error: invalid vector dimension", std::vector{array_subscript_expression.begin()}); - } - // LCOV_EXCL_STOP } } else { // LCOV_EXCL_START diff --git a/src/language/node_processor/ArraySubscriptProcessor.hpp b/src/language/node_processor/ArraySubscriptProcessor.hpp index 3a71c01846b9e17ce28818d732716aea764d2496..43eb962b9d724b19b4cfe7ab808f52dd73d1fce2 100644 --- a/src/language/node_processor/ArraySubscriptProcessor.hpp +++ b/src/language/node_processor/ArraySubscriptProcessor.hpp @@ -15,9 +15,7 @@ class ArraySubscriptProcessor : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - auto& index_expression = *m_array_subscript_expression.children[1]; - - const int64_t index_value = [&](DataVariant&& value_variant) -> int64_t { + auto get_index_value = [&](DataVariant&& value_variant) -> int64_t { int64_t index_value = 0; std::visit( [&](auto&& value) { @@ -26,20 +24,39 @@ class ArraySubscriptProcessor : public INodeProcessor index_value = value; } else { // LCOV_EXCL_START - throw ParseError("unexpected error: invalid index type", std::vector{index_expression.begin()}); + throw UnexpectedError("invalid index type"); // LCOV_EXCL_STOP } }, value_variant); return index_value; - }(index_expression.execute(exec_policy)); + }; + + if constexpr (is_tiny_vector_v<ArrayTypeT>) { + auto& index_expression = *m_array_subscript_expression.children[1]; + + const int64_t index_value = get_index_value(index_expression.execute(exec_policy)); + + auto& array_expression = *m_array_subscript_expression.children[0]; + + auto&& array_value = array_expression.execute(exec_policy); + ArrayTypeT& array = std::get<ArrayTypeT>(array_value); + + return array[index_value]; + } else if constexpr (is_tiny_matrix_v<ArrayTypeT>) { + auto& index0_expression = *m_array_subscript_expression.children[1]; + auto& index1_expression = *m_array_subscript_expression.children[2]; + + const int64_t index0_value = get_index_value(index0_expression.execute(exec_policy)); + const int64_t index1_value = get_index_value(index1_expression.execute(exec_policy)); - auto& array_expression = *m_array_subscript_expression.children[0]; + auto& array_expression = *m_array_subscript_expression.children[0]; - auto&& array_value = array_expression.execute(exec_policy); - ArrayTypeT& array = std::get<ArrayTypeT>(array_value); + auto&& array_value = array_expression.execute(exec_policy); + ArrayTypeT& array = std::get<ArrayTypeT>(array_value); - return array[index_value]; + return array(index0_value, index1_value); + } } ArraySubscriptProcessor(ASTNode& array_subscript_expression) diff --git a/src/language/node_processor/BinaryExpressionProcessor.hpp b/src/language/node_processor/BinaryExpressionProcessor.hpp index 8f5d9c3220cf3018104d5d53e2036b9bb8217214..cfbff72617d7b99e68ea9693eca96fa718c3f918 100644 --- a/src/language/node_processor/BinaryExpressionProcessor.hpp +++ b/src/language/node_processor/BinaryExpressionProcessor.hpp @@ -6,6 +6,8 @@ #include <language/node_processor/INodeProcessor.hpp> #include <language/utils/ParseError.hpp> +#include <type_traits> + template <typename Op> struct BinOp; @@ -154,68 +156,49 @@ struct BinOp<language::divide_op> } }; -template <typename BinaryOpT, typename A_DataT, typename B_DataT> -class BinaryExpressionProcessor final : public INodeProcessor +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor final : public INodeProcessor { + private: ASTNode& m_node; PUGS_INLINE DataVariant _eval(const DataVariant& a, const DataVariant& b) { - // Add 'signed' when necessary to avoid signed/unsigned comparison warnings - if constexpr ((not(std::is_same_v<A_DataT, bool> or std::is_same_v<B_DataT, bool>)) and - (std::is_same_v<BinaryOpT, language::and_op> or std::is_same_v<BinaryOpT, language::or_op> or - std::is_same_v<BinaryOpT, language::xor_op> or std::is_same_v<BinaryOpT, language::eqeq_op> or - std::is_same_v<BinaryOpT, language::not_eq_op> or std::is_same_v<BinaryOpT, language::lesser_op> or - std::is_same_v<BinaryOpT, language::lesser_or_eq_op> or - std::is_same_v<BinaryOpT, language::greater_op> or - std::is_same_v<BinaryOpT, language::greater_or_eq_op>) and - (std::is_signed_v<A_DataT> xor std::is_signed_v<B_DataT>)) { - if constexpr (std::is_unsigned_v<A_DataT>) { - using signed_A_DataT = std::make_signed_t<A_DataT>; - const signed_A_DataT signed_a = static_cast<signed_A_DataT>(std::get<A_DataT>(a)); - return BinOp<BinaryOpT>().eval(signed_a, std::get<B_DataT>(b)); + if constexpr (std::is_arithmetic_v<A_DataT> and std::is_arithmetic_v<B_DataT>) { + if constexpr (std::is_signed_v<A_DataT> and not std::is_signed_v<B_DataT>) { + if constexpr (std::is_same_v<B_DataT, bool>) { + return static_cast<ValueT>( + BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), static_cast<int64_t>(std::get<B_DataT>(b)))); + } else { + return static_cast<ValueT>( + BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), std::make_signed_t<B_DataT>(std::get<B_DataT>(b)))); + } + + } else if constexpr (not std::is_signed_v<A_DataT> and std::is_signed_v<B_DataT>) { + if constexpr (std::is_same_v<A_DataT, bool>) { + return static_cast<ValueT>( + BinOp<BinaryOpT>().eval(static_cast<int64_t>(std::get<A_DataT>(a)), std::get<B_DataT>(b))); + } else { + return static_cast<ValueT>( + BinOp<BinaryOpT>().eval(std::make_signed_t<A_DataT>(std::get<A_DataT>(a)), std::get<B_DataT>(b))); + } } else { - using signed_B_DataT = std::make_signed_t<B_DataT>; - const signed_B_DataT signed_b = static_cast<signed_B_DataT>(std::get<B_DataT>(b)); - return BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), signed_b); + return static_cast<ValueT>(BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), std::get<B_DataT>(b))); } } else { - auto result = BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), std::get<B_DataT>(b)); - if constexpr (std::is_same_v<decltype(result), int>) { - return static_cast<int64_t>(result); - } else { - return result; - } + return static_cast<ValueT>(BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), std::get<B_DataT>(b))); } } - static inline const bool m_is_defined{[] { - if constexpr (std::is_same_v<BinaryOpT, language::xor_op>) { - return std::is_same_v<std::decay_t<A_DataT>, std::decay_t<B_DataT>> and std::is_integral_v<std::decay_t<A_DataT>>; - } - return true; - }()}; - public: DataVariant execute(ExecutionPolicy& exec_policy) { - if constexpr (m_is_defined) { - return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); - } else { - return {}; // LCOV_EXCL_LINE - } + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); } - BinaryExpressionProcessor(ASTNode& node) : m_node{node} - { - if constexpr (not m_is_defined) { - // LCOV_EXCL_START - throw ParseError("invalid operands to binary expression", std::vector{m_node.begin()}); - // LCOV_EXCL_STOP - } - } + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} }; #endif // BINARY_EXPRESSION_PROCESSOR_HPP diff --git a/src/language/node_processor/BuiltinFunctionProcessor.hpp b/src/language/node_processor/BuiltinFunctionProcessor.hpp index 302cbaf2d2ce3570678fb62a57d4e463e47a2432..4e5583adb2d269e7552ff3cccf7dfcbe811b0791 100644 --- a/src/language/node_processor/BuiltinFunctionProcessor.hpp +++ b/src/language/node_processor/BuiltinFunctionProcessor.hpp @@ -64,7 +64,9 @@ class BuiltinFunctionProcessor : public INodeProcessor } if (SignalManager::pauseOnError()) { + // LCOV_EXCL_START return m_function_expression_processor->execute(context_exec_policy); + // LCOV_EXCL_STOP } else { try { return m_function_expression_processor->execute(context_exec_policy); diff --git a/src/language/node_processor/ConcatExpressionProcessor.hpp b/src/language/node_processor/ConcatExpressionProcessor.hpp index e47178494fe2bf88f8c7e5bebf9455c58732b08e..bb6357c87124110e973730ff6fb3eb9d5fdde5a3 100644 --- a/src/language/node_processor/ConcatExpressionProcessor.hpp +++ b/src/language/node_processor/ConcatExpressionProcessor.hpp @@ -16,8 +16,12 @@ class ConcatExpressionProcessor final : public INodeProcessor { if constexpr (std::is_same_v<B_DataT, std::string>) { return a + std::get<B_DataT>(b); - } else { + } else if constexpr (std::is_arithmetic_v<B_DataT>) { return a + std::to_string(std::get<B_DataT>(b)); + } else { + std::ostringstream os; + os << a << b; + return os.str(); } } diff --git a/src/language/node_processor/DoWhileProcessor.hpp b/src/language/node_processor/DoWhileProcessor.hpp index 5b03b769ffe3e376c5e8bddafa91db5d01e2f59f..6f3c14b7f8706c6d5480f792e449c1f110b74498 100644 --- a/src/language/node_processor/DoWhileProcessor.hpp +++ b/src/language/node_processor/DoWhileProcessor.hpp @@ -3,6 +3,7 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/SymbolTable.hpp> class DoWhileProcessor final : public INodeProcessor { diff --git a/src/language/node_processor/ForProcessor.hpp b/src/language/node_processor/ForProcessor.hpp index 8680aaf7ded0f7fabbc345682a9a32c772a44f08..4199e0c42867398167c2f1fa5927d87e7f4fddc6 100644 --- a/src/language/node_processor/ForProcessor.hpp +++ b/src/language/node_processor/ForProcessor.hpp @@ -3,6 +3,7 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/SymbolTable.hpp> class ForProcessor final : public INodeProcessor { diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index 4bf7e56eebfd93b244d13a774e2e69aac2834794..af08e7806a4fd5daa01a7aecc98a84e52291040f 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -31,9 +31,20 @@ class FunctionArgumentToStringConverter final : public IFunctionArgumentConverte DataVariant convert(ExecutionPolicy& exec_policy, DataVariant&& value) { - std::ostringstream sout; - sout << value; - exec_policy.currentContext()[m_argument_id] = sout.str(); + std::visit( + [&](auto&& v) { + using T = std::decay_t<decltype(v)>; + if constexpr (std::is_arithmetic_v<T>) { + exec_policy.currentContext()[m_argument_id] = std::to_string(v); + } else if constexpr (std::is_same_v<T, std::string>) { + exec_policy.currentContext()[m_argument_id] = v; + } else { + std::ostringstream sout; + sout << value; + exec_policy.currentContext()[m_argument_id] = sout.str(); + } + }, + value); return {}; } @@ -100,6 +111,7 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; } else { + static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>); exec_policy.currentContext()[m_argument_id] = std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); } @@ -109,6 +121,55 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver FunctionTinyVectorArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} }; +template <typename ExpectedValueType, typename ProvidedValueType> +class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + if constexpr (std::is_same_v<ExpectedValueType, ProvidedValueType>) { + std::visit( + [&](auto&& v) { + using ValueT = std::decay_t<decltype(v)>; + if constexpr (std::is_same_v<ValueT, ExpectedValueType>) { + exec_policy.currentContext()[m_argument_id] = std::move(value); + } else if constexpr (std::is_same_v<ValueT, AggregateDataVariant>) { + ExpectedValueType matrix_value{}; + for (size_t i = 0, l = 0; i < matrix_value.nbRows(); ++i) { + for (size_t j = 0; j < matrix_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& A_ij) { + using Vi_T = std::decay_t<decltype(A_ij)>; + if constexpr (std::is_arithmetic_v<Vi_T>) { + matrix_value(i, j) = A_ij; + } else { + throw UnexpectedError(demangle<Vi_T>() + " unexpected aggregate value type"); + } + }, + v[l]); + } + } + exec_policy.currentContext()[m_argument_id] = std::move(matrix_value); + } + }, + value); + } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; + } else { + static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); + exec_policy.currentContext()[m_argument_id] = + std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + } + return {}; + } + + FunctionTinyMatrixArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} +}; + template <typename ContentType, typename ProvidedValueType> class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter { @@ -132,19 +193,38 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter TupleType list_value; list_value.reserve(v.size()); for (size_t i = 0; i < v.size(); ++i) { - list_value.emplace_back(v[i]); + list_value.emplace_back(std::move(v[i])); + } + exec_policy.currentContext()[m_argument_id] = std::move(list_value); + } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { + TupleType list_value; + list_value.reserve(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + list_value.push_back(static_cast<ContentType>(v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } - } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType>) { + } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { - throw UnexpectedError(demangle<ValueT>() + " unexpected value type"); + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + + demangle<ContentType>() + "'"); } }, value); + } else { - throw UnexpectedError(demangle<std::decay_t<decltype(*this)>>() + ": did nothing!"); + // LCOV_EXCL_START + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } return {}; } @@ -174,31 +254,44 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter std::visit( [&](auto&& vi) { using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (is_tiny_vector_v<ContentType>) { - throw NotImplementedError("TinyVector case"); + if constexpr (std::is_same_v<Vi_T, ContentType>) { + list_value.emplace_back(vi); + } else if constexpr (is_tiny_vector_v<ContentType> or is_tiny_matrix_v<ContentType>) { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<Vi_T>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } else if constexpr (std::is_convertible_v<Vi_T, ContentType>) { list_value.emplace_back(vi); } else { + // LCOV_EXCL_START throw UnexpectedError("unexpected types"); + // LCOV_EXCL_STOP } }, (v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); - } else if constexpr (std::is_same_v<ValueT, ContentType>) { - exec_policy.currentContext()[m_argument_id] = std::move(v); } else if constexpr (is_std_vector_v<ValueT>) { using ContentT = typename ValueT::value_type; if constexpr (std::is_same_v<ContentT, ContentType>) { - TupleType list_value; - list_value.reserve(v.size()); - for (size_t i = 0; i < v.size(); ++i) { - list_value.emplace_back(v[i]); - } - exec_policy.currentContext()[m_argument_id] = std::move(list_value); + exec_policy.currentContext()[m_argument_id] = v; + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<ContentT>() + "' to '" + + demangle<ContentType>() + "'"); + // LCOV_EXCL_STOP } + } else if constexpr (std::is_same_v<ValueT, ContentType>) { + exec_policy.currentContext()[m_argument_id] = std::move(TupleType{v}); + } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ValueT> and + not is_tiny_vector_v<ContentType> and not is_tiny_matrix_v<ValueT> and + not is_tiny_matrix_v<ContentType>) { + exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { + // LCOV_EXCL_START throw UnexpectedError(demangle<ValueT>() + " unexpected value type"); + // LCOV_EXCL_STOP } }, value); diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp index b28c2ce499538818ce60b1ac36a1d158114409de..f8c86ea9e5cb51e169e48ecadc0b4770ef873cc8 100644 --- a/src/language/node_processor/FunctionProcessor.hpp +++ b/src/language/node_processor/FunctionProcessor.hpp @@ -21,18 +21,35 @@ class FunctionExpressionProcessor final : public INodeProcessor if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { return m_function_expression.execute(exec_policy); } else if constexpr (std::is_same_v<AggregateDataVariant, ExpressionValueType>) { - static_assert(is_tiny_vector_v<ReturnType>, "unexpected return type"); + static_assert(is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>, "unexpected return type"); ReturnType return_value{}; auto value = std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)); - for (size_t i = 0; i < ReturnType::Dimension; ++i) { - std::visit( - [&](auto&& vi) { - using Vi_T = std::decay_t<decltype(vi)>; - if constexpr (std::is_convertible_v<Vi_T, double>) { - return_value[i] = vi; - } - }, - value[i]); + if constexpr (is_tiny_vector_v<ReturnType>) { + for (size_t i = 0; i < ReturnType::Dimension; ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value[i] = vi; + } + }, + value[i]); + } + } else { + static_assert(is_tiny_matrix_v<ReturnType>); + + for (size_t i = 0, l = 0; i < return_value.nbRows(); ++i) { + for (size_t j = 0; j < return_value.nbColumns(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Vi_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_convertible_v<Vi_T, double>) { + return_value(i, j) = Aij; + } + }, + value[l]); + } + } } return return_value; } else if constexpr (std::is_same_v<ReturnType, std::string>) { diff --git a/src/language/node_processor/INodeProcessor.hpp b/src/language/node_processor/INodeProcessor.hpp index ce06b9484f7c7b141861bddfc88f5b5c3bf9eecc..2c35b6c49cadd69536c415f5a86e5786c953a86c 100644 --- a/src/language/node_processor/INodeProcessor.hpp +++ b/src/language/node_processor/INodeProcessor.hpp @@ -8,8 +8,9 @@ #include <string> #include <typeinfo> -struct INodeProcessor +class INodeProcessor { + public: virtual DataVariant execute(ExecutionPolicy& exec_policy) = 0; std::string diff --git a/src/language/node_processor/IfProcessor.hpp b/src/language/node_processor/IfProcessor.hpp index 1a76e3068dcb6bd29f4e4e250dc3d6cc1f9f6d41..18d87bb89f11ee0f335fe0637f4569dc5b3cab15 100644 --- a/src/language/node_processor/IfProcessor.hpp +++ b/src/language/node_processor/IfProcessor.hpp @@ -3,6 +3,7 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/SymbolTable.hpp> class IfProcessor final : public INodeProcessor { @@ -39,8 +40,7 @@ class IfProcessor final : public INodeProcessor } } - if (m_node.children[0]->m_symbol_table != m_node.m_symbol_table) - m_node.children[0]->m_symbol_table->clearValues(); + Assert(m_node.children[0]->m_symbol_table == m_node.m_symbol_table); return {}; } diff --git a/src/language/node_processor/TupleToTinyMatrixProcessor.hpp b/src/language/node_processor/TupleToTinyMatrixProcessor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..77bc45e8419f7d2e58447247aef3975c337813c5 --- /dev/null +++ b/src/language/node_processor/TupleToTinyMatrixProcessor.hpp @@ -0,0 +1,53 @@ +#ifndef TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP +#define TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP + +#include <language/ast/ASTNode.hpp> +#include <language/node_processor/INodeProcessor.hpp> + +template <typename TupleProcessorT, size_t N> +class TupleToTinyMatrixProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + std::unique_ptr<TupleProcessorT> m_tuple_processor; + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + AggregateDataVariant v = std::get<AggregateDataVariant>(m_tuple_processor->execute(exec_policy)); + + Assert(v.size() == N * N); + + TinyMatrix<N> A; + + for (size_t i = 0, l = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using ValueT = std::decay_t<decltype(Aij)>; + if constexpr (std::is_arithmetic_v<ValueT>) { + A(i, j) = Aij; + } else { + // LCOV_EXCL_START + Assert(false, "unexpected value type"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + + return DataVariant{std::move(A)}; + } + + TupleToTinyMatrixProcessor(ASTNode& node) : m_node{node}, m_tuple_processor{std::make_unique<TupleProcessorT>(node)} + {} + + TupleToTinyMatrixProcessor(ASTNode& node, std::unique_ptr<TupleProcessorT>&& tuple_processor) + : m_node{node}, m_tuple_processor{std::move(tuple_processor)} + {} +}; + +#endif // TUPLE_TO_TINY_MATRIX_PROCESSOR_HPP diff --git a/src/language/node_processor/UnaryExpressionProcessor.hpp b/src/language/node_processor/UnaryExpressionProcessor.hpp index c99c7f2dd828583c7dad8048338e68d26c807be2..055e3a028be043c8a30189b7e4cccddd99aeef4a 100644 --- a/src/language/node_processor/UnaryExpressionProcessor.hpp +++ b/src/language/node_processor/UnaryExpressionProcessor.hpp @@ -2,8 +2,8 @@ #define UNARY_EXPRESSION_PROCESSOR_HPP #include <language/PEGGrammar.hpp> +#include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> -#include <language/utils/SymbolTable.hpp> template <typename Op> struct UnaryOp; diff --git a/src/language/node_processor/WhileProcessor.hpp b/src/language/node_processor/WhileProcessor.hpp index 20b0ac985ae43bfaf3f9b165f9a346aa82a9f3cb..af7cc604bb6c2d5f74cb894a1f05ba20acfb35f6 100644 --- a/src/language/node_processor/WhileProcessor.hpp +++ b/src/language/node_processor/WhileProcessor.hpp @@ -3,6 +3,7 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/SymbolTable.hpp> class WhileProcessor final : public INodeProcessor { diff --git a/src/language/utils/ASTExecutionInfo.cpp b/src/language/utils/ASTExecutionInfo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6015b6a8f03625d436916224566725eaa7a18e2d --- /dev/null +++ b/src/language/utils/ASTExecutionInfo.cpp @@ -0,0 +1,25 @@ +#include <language/utils/ASTExecutionInfo.hpp> + +#include <language/ast/ASTNode.hpp> + +const ASTExecutionInfo* ASTExecutionInfo::m_current_execution_info = nullptr; + +ASTExecutionInfo::ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_repository) + : m_root_node{root_node}, m_module_repository{module_repository} +{ + Assert(m_current_execution_info == nullptr, "Can only define one ASTExecutionInfo"); + + m_current_execution_info = this; +} + +const ASTExecutionInfo& +ASTExecutionInfo::current() +{ + Assert(m_current_execution_info != nullptr, "ASTExecutionInfo is not defined!"); + return *m_current_execution_info; +} + +ASTExecutionInfo::~ASTExecutionInfo() +{ + m_current_execution_info = nullptr; +} diff --git a/src/language/utils/ASTExecutionInfo.hpp b/src/language/utils/ASTExecutionInfo.hpp new file mode 100644 index 0000000000000000000000000000000000000000..da4b817428c2ea00e8c089e14de2049c4faae4af --- /dev/null +++ b/src/language/utils/ASTExecutionInfo.hpp @@ -0,0 +1,44 @@ +#ifndef AST_EXECUTION_INFO_HPP +#define AST_EXECUTION_INFO_HPP + +#include <string> + +class ModuleRepository; +class ASTNode; +class ASTExecutionInfo +{ + private: + static const ASTExecutionInfo* m_current_execution_info; + + const ASTNode& m_root_node; + + const ModuleRepository& m_module_repository; + + // The only place where the ASTExecutionInfo can be built + friend void parser(const std::string& filename); + // also allowed for testing + friend void test_ASTExecutionInfo(const ASTNode&, const ModuleRepository&); + + ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_repository); + + public: + const ASTNode& + rootNode() const + { + return m_root_node; + } + + const ModuleRepository& + moduleRepository() const + { + return m_module_repository; + } + + static const ASTExecutionInfo& current(); + + ASTExecutionInfo() = delete; + + ~ASTExecutionInfo(); +}; + +#endif // AST_EXECUTION_INFO_HPP diff --git a/src/language/ast/ASTNodeDataType.cpp b/src/language/utils/ASTNodeDataType.cpp similarity index 58% rename from src/language/ast/ASTNodeDataType.cpp rename to src/language/utils/ASTNodeDataType.cpp index c77659d983c58a571ac3dd8426437be6478ceddb..81f3fc81caba045ca4ae719ddfa21c09c1c8fc98 100644 --- a/src/language/ast/ASTNodeDataType.cpp +++ b/src/language/utils/ASTNodeDataType.cpp @@ -1,4 +1,4 @@ -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> @@ -16,7 +16,40 @@ getVectorDataType(const ASTNode& type_node) throw ParseError("unexpected non integer constant dimension", dimension_node.begin()); } const size_t dimension = std::stol(dimension_node.string()); - return ASTNodeDataType{ASTNodeDataType::vector_t, dimension}; + if (not(dimension > 0 and dimension <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", dimension_node.begin()); + } + return ASTNodeDataType::build<ASTNodeDataType::vector_t>(dimension); +} + +ASTNodeDataType +getMatrixDataType(const ASTNode& type_node) +{ + if (not(type_node.is_type<language::matrix_type>() and (type_node.children.size() == 3))) { + throw ParseError("unexpected node type", type_node.begin()); + } + + ASTNode& dimension0_node = *type_node.children[1]; + if (not dimension0_node.is_type<language::integer>()) { + throw ParseError("unexpected non integer constant dimension", dimension0_node.begin()); + } + const size_t dimension0 = std::stol(dimension0_node.string()); + + ASTNode& dimension1_node = *type_node.children[2]; + if (not dimension1_node.is_type<language::integer>()) { + throw ParseError("unexpected non integer constant dimension", dimension1_node.begin()); + } + const size_t dimension1 = std::stol(dimension1_node.string()); + + if (dimension0 != dimension1) { + throw ParseError("only square matrices are supported", type_node.begin()); + } + + if (not(dimension0 > 0 and dimension0 <= 3)) { + throw ParseError("invalid dimension (must be 1, 2 or 3)", dimension0_node.begin()); + } + + return ASTNodeDataType::build<ASTNodeDataType::matrix_t>(dimension0, dimension1); } std::string @@ -42,17 +75,31 @@ dataTypeName(const ASTNodeDataType& data_type) case ASTNodeDataType::vector_t: name = "R^" + std::to_string(data_type.dimension()); break; + case ASTNodeDataType::matrix_t: + name = "R^" + std::to_string(data_type.nbRows()) + "x" + std::to_string(data_type.nbColumns()); + break; case ASTNodeDataType::tuple_t: name = "tuple(" + dataTypeName(data_type.contentType()) + ')'; break; - case ASTNodeDataType::list_t: - name = "list"; + case ASTNodeDataType::list_t: { + std::ostringstream data_type_name_list; + const auto& data_type_list = data_type.contentTypeList(); + if (data_type_list.size() > 0) { + data_type_name_list << dataTypeName(*data_type_list[0]); + for (size_t i = 1; i < data_type_list.size(); ++i) { + data_type_name_list << '*' << dataTypeName(*data_type_list[i]); + } + name = "list(" + data_type_name_list.str() + ")"; + } else { + name = "list(void)"; + } break; + } case ASTNodeDataType::string_t: name = "string"; break; case ASTNodeDataType::typename_t: - name = "typename"; + name = std::string("typename(") + dataTypeName(data_type.contentType()) + ")"; break; case ASTNodeDataType::type_name_id_t: name = "type_name_id"; @@ -73,6 +120,24 @@ dataTypeName(const ASTNodeDataType& data_type) return name; } +std::string +dataTypeName(const std::vector<ASTNodeDataType>& data_type_vector) +{ + if (data_type_vector.size() == 0) { + return dataTypeName(ASTNodeDataType::build<ASTNodeDataType::void_t>()); + } else if (data_type_vector.size() == 1) { + return dataTypeName(data_type_vector[0]); + } else { + std::ostringstream os; + os << '(' << dataTypeName(data_type_vector[0]); + for (size_t i = 1; i < data_type_vector.size(); ++i) { + os << ',' << dataTypeName(data_type_vector[i]); + } + os << ')'; + return os.str(); + } +} + ASTNodeDataType dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& data_type_2) { @@ -88,7 +153,7 @@ dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& dat (data_type_2 == ASTNodeDataType::vector_t)) { return data_type_2; } else { - return ASTNodeDataType::undefined_t; + return ASTNodeDataType{}; } } @@ -100,6 +165,9 @@ isNaturalConversion(const ASTNodeDataType& data_type, const ASTNodeDataType& tar return (data_type.nameOfTypeId() == target_data_type.nameOfTypeId()); } else if (data_type == ASTNodeDataType::vector_t) { return (data_type.dimension() == target_data_type.dimension()); + } else if (data_type == ASTNodeDataType::matrix_t) { + return ((data_type.nbRows() == target_data_type.nbRows()) and + (data_type.nbColumns() == target_data_type.nbColumns())); } else { return true; } diff --git a/src/language/utils/ASTNodeDataType.hpp b/src/language/utils/ASTNodeDataType.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c31c9e15b6e8df0055aa4854bb7ede0a73318abd --- /dev/null +++ b/src/language/utils/ASTNodeDataType.hpp @@ -0,0 +1,211 @@ +#ifndef AST_NODE_DATA_TYPE_HPP +#define AST_NODE_DATA_TYPE_HPP + +#include <utils/PugsAssert.hpp> + +#include <limits> +#include <memory> +#include <string> +#include <variant> +#include <vector> + +class ASTNode; +class ASTNodeDataType; + +ASTNodeDataType getVectorDataType(const ASTNode& type_node); + +ASTNodeDataType getMatrixDataType(const ASTNode& type_node); + +std::string dataTypeName(const std::vector<ASTNodeDataType>& data_type_vector); + +std::string dataTypeName(const ASTNodeDataType& data_type); + +ASTNodeDataType dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& data_type_2); + +bool isNaturalConversion(const ASTNodeDataType& data_type, const ASTNodeDataType& target_data_type); + +class ASTNodeDataType +{ + public: + enum DataType : int32_t + { + undefined_t = -1, + bool_t = 0, + int_t = 1, + unsigned_int_t = 2, + double_t = 3, + vector_t = 4, + matrix_t = 5, + tuple_t = 6, + list_t = 7, + string_t = 8, + typename_t = 10, + type_name_id_t = 11, + type_id_t = 21, + function_t = 22, + builtin_function_t = 23, + void_t = std::numeric_limits<int32_t>::max() + }; + + private: + DataType m_data_type; + + using DataTypeDetails = std::variant<std::monostate, + size_t, + std::array<size_t, 2>, + std::string, + std::shared_ptr<const ASTNodeDataType>, + std::vector<std::shared_ptr<const ASTNodeDataType>>>; + + DataTypeDetails m_details; + + public: + PUGS_INLINE + size_t + dimension() const + { + Assert(std::holds_alternative<size_t>(m_details)); + return std::get<size_t>(m_details); + } + + PUGS_INLINE + size_t + nbRows() const + { + Assert(std::holds_alternative<std::array<size_t, 2>>(m_details)); + return std::get<std::array<size_t, 2>>(m_details)[0]; + } + + PUGS_INLINE + size_t + nbColumns() const + { + Assert(std::holds_alternative<std::array<size_t, 2>>(m_details)); + return std::get<std::array<size_t, 2>>(m_details)[1]; + } + + PUGS_INLINE + const std::string& + nameOfTypeId() const + { + Assert(std::holds_alternative<std::string>(m_details)); + return std::get<std::string>(m_details); + } + + PUGS_INLINE + const ASTNodeDataType& + contentType() const + { + Assert(std::holds_alternative<std::shared_ptr<const ASTNodeDataType>>(m_details)); + return *std::get<std::shared_ptr<const ASTNodeDataType>>(m_details); + } + + PUGS_INLINE + const std::vector<std::shared_ptr<const ASTNodeDataType>>& + contentTypeList() const + { + Assert(std::holds_alternative<std::vector<std::shared_ptr<const ASTNodeDataType>>>(m_details)); + return std::get<std::vector<std::shared_ptr<const ASTNodeDataType>>>(m_details); + } + + PUGS_INLINE + operator const DataType&() const + { + return m_data_type; + } + + ASTNodeDataType& operator=(const ASTNodeDataType&) = default; + ASTNodeDataType& operator=(ASTNodeDataType&&) = default; + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build() + { + static_assert(data_type != tuple_t, "tuple_t requires sub_type"); + static_assert(data_type != typename_t, "typename_t requires sub_type"); + static_assert(data_type != vector_t, "vector_t requires dimension"); + static_assert(data_type != type_id_t, "type_id_t requires name"); + static_assert(data_type != list_t, "list_t requires list of types"); + + return ASTNodeDataType{data_type}; + } + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const ASTNodeDataType& content_type) + { + static_assert((data_type == tuple_t) or (data_type == typename_t), + "incorrect data_type construction: cannot have content"); + Assert(content_type != ASTNodeDataType::undefined_t); + + return ASTNodeDataType{data_type, content_type}; + } + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const size_t dimension) + { + static_assert((data_type == vector_t), "incorrect data_type construction: cannot have dimension"); + return ASTNodeDataType{data_type, dimension}; + } + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const size_t nb_rows, const size_t nb_columns) + { + static_assert((data_type == matrix_t), "incorrect data_type construction: cannot have dimension"); + return ASTNodeDataType{data_type, nb_rows, nb_columns}; + } + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const std::string& type_name) + { + static_assert((data_type == type_id_t), "incorrect data_type construction: cannot provide name of type"); + return ASTNodeDataType{data_type, type_name}; + } + + template <DataType data_type> + [[nodiscard]] static ASTNodeDataType + build(const std::vector<std::shared_ptr<const ASTNodeDataType>>& list_of_types) + { + static_assert((data_type == list_t), "incorrect data_type construction: cannot provide a list of data types"); + + for (auto i : list_of_types) { + Assert(i->m_data_type != ASTNodeDataType::undefined_t, "cannot build a type list containing undefined types"); + } + + return ASTNodeDataType{data_type, list_of_types}; + } + + ASTNodeDataType() : m_data_type{undefined_t} {} + + ASTNodeDataType(const ASTNodeDataType&) = default; + + ASTNodeDataType(ASTNodeDataType&&) = default; + + ~ASTNodeDataType() = default; + + private: + explicit ASTNodeDataType(DataType data_type) : m_data_type{data_type} {} + + explicit ASTNodeDataType(DataType data_type, const ASTNodeDataType& content_type) + : m_data_type{data_type}, m_details{std::make_shared<const ASTNodeDataType>(content_type)} + {} + + explicit ASTNodeDataType(DataType data_type, const std::vector<std::shared_ptr<const ASTNodeDataType>>& list_of_types) + : m_data_type{data_type}, m_details{list_of_types} + {} + + explicit ASTNodeDataType(DataType data_type, const size_t dimension) : m_data_type{data_type}, m_details{dimension} {} + + explicit ASTNodeDataType(DataType data_type, const size_t nb_rows, const size_t nb_columns) + : m_data_type{data_type}, m_details{std::array{nb_rows, nb_columns}} + {} + + explicit ASTNodeDataType(DataType data_type, const std::string& type_name) + : m_data_type{data_type}, m_details{type_name} + {} +}; + +#endif // AST_NODE_DATA_TYPE_HPP diff --git a/src/language/utils/ASTNodeDataTypeTraits.hpp b/src/language/utils/ASTNodeDataTypeTraits.hpp index 4ee3ba31bf60d3ca4f006fd0510fae0c1fc14056..2ba019646dac53291dc73169fadc47e737f68bad 100644 --- a/src/language/utils/ASTNodeDataTypeTraits.hpp +++ b/src/language/utils/ASTNodeDataTypeTraits.hpp @@ -1,34 +1,38 @@ -#ifndef AST_NODE_DATA_TYPE_TRAITS_H -#define AST_NODE_DATA_TYPE_TRAITS_H +#ifndef AST_NODE_DATA_TYPE_TRAITS_HPP +#define AST_NODE_DATA_TYPE_TRAITS_HPP +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/FunctionSymbolId.hpp> #include <vector> template <typename T> -inline ASTNodeDataType ast_node_data_type_from = ASTNodeDataType::undefined_t; +inline ASTNodeDataType ast_node_data_type_from = ASTNodeDataType{}; template <> -inline ASTNodeDataType ast_node_data_type_from<void> = ASTNodeDataType::void_t; +inline ASTNodeDataType ast_node_data_type_from<void> = ASTNodeDataType::build<ASTNodeDataType::void_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<bool> = ASTNodeDataType::bool_t; +inline ASTNodeDataType ast_node_data_type_from<bool> = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<int64_t> = ASTNodeDataType::int_t; +inline ASTNodeDataType ast_node_data_type_from<int64_t> = ASTNodeDataType::build<ASTNodeDataType::int_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<uint64_t> = ASTNodeDataType::unsigned_int_t; +inline ASTNodeDataType ast_node_data_type_from<uint64_t> = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<double> = ASTNodeDataType::double_t; +inline ASTNodeDataType ast_node_data_type_from<double> = ASTNodeDataType::build<ASTNodeDataType::double_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<std::string> = ASTNodeDataType::string_t; +inline ASTNodeDataType ast_node_data_type_from<std::string> = ASTNodeDataType::build<ASTNodeDataType::string_t>(); template <> -inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = ASTNodeDataType::function_t; +inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = + ASTNodeDataType::build<ASTNodeDataType::function_t>(); template <size_t N> -inline ASTNodeDataType ast_node_data_type_from<TinyVector<N>> = {ASTNodeDataType::vector_t, N}; +inline ASTNodeDataType ast_node_data_type_from<TinyVector<N>> = ASTNodeDataType::build<ASTNodeDataType::vector_t>(N); +template <size_t N> +inline ASTNodeDataType ast_node_data_type_from<TinyMatrix<N>> = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(N, N); template <typename T> inline ASTNodeDataType ast_node_data_type_from<std::vector<T>> = - ASTNodeDataType{ASTNodeDataType::tuple_t, ast_node_data_type_from<T>}; + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ast_node_data_type_from<T>); -#endif // AST_NODE_DATA_TYPE_TRAITS_H +#endif // AST_NODE_DATA_TYPE_TRAITS_HPP diff --git a/src/language/utils/ASTNodeNaturalConversionChecker.cpp b/src/language/utils/ASTNodeNaturalConversionChecker.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81a277574753dc3e8cc75871a07d8d8a3b1a0cc6 --- /dev/null +++ b/src/language/utils/ASTNodeNaturalConversionChecker.cpp @@ -0,0 +1,171 @@ +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> + +#include <language/PEGGrammar.hpp> +#include <language/utils/ParseError.hpp> +#include <utils/Exceptions.hpp> + +template <typename RToR1Conversion> +void +ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalTypeConversion( + const ASTNode& node, + const ASTNodeDataType& data_type, + const ASTNodeDataType& target_data_type) const +{ + if (not isNaturalConversion(data_type, target_data_type)) { + if constexpr (std::is_same_v<RToR1ConversionStrategy, AllowRToR1Conversion>) { + if (((target_data_type == ASTNodeDataType::vector_t) and (target_data_type.dimension() == 1)) or + ((target_data_type == ASTNodeDataType::matrix_t) and (target_data_type.nbRows() == 1) and + (target_data_type.nbColumns() == 1))) { + if (isNaturalConversion(data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>())) { + return; + } + } + } + std::ostringstream error_message; + error_message << "invalid implicit conversion: "; + error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type) + << rang::fg::reset; + + if ((data_type == ASTNodeDataType::undefined_t) or (target_data_type == ASTNodeDataType::undefined_t)) { + // LCOV_EXCL_START + throw UnexpectedError(error_message.str()); + // LCOV_EXCL_STOP + } else { + throw ParseError(error_message.str(), node.begin()); + } + } +} + +template <typename RToR1Conversion> +void +ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalExpressionConversion( + const ASTNode& node, + const ASTNodeDataType& data_type, + const ASTNodeDataType& target_data_type) const +{ + if (target_data_type == ASTNodeDataType::typename_t) { + this->_checkIsNaturalExpressionConversion(node, data_type, target_data_type.contentType()); + } else if (target_data_type == ASTNodeDataType::vector_t) { + switch (data_type) { + case ASTNodeDataType::list_t: { + const auto& content_type_list = data_type.contentTypeList(); + if (content_type_list.size() != target_data_type.dimension()) { + std::ostringstream os; + os << "incompatible dimensions in affectation: expecting " << target_data_type.dimension() << ", but provided " + << content_type_list.size(); + throw ParseError(os.str(), std::vector{node.begin()}); + } + + Assert(content_type_list.size() == node.children.size()); + for (size_t i = 0; i < content_type_list.size(); ++i) { + const auto& child_type = *content_type_list[i]; + const auto& child_node = *node.children[i]; + Assert(child_type == child_node.m_data_type); + this->_checkIsNaturalExpressionConversion(child_node, child_type, + ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } + + break; + } + case ASTNodeDataType::vector_t: { + if (data_type.dimension() != target_data_type.dimension()) { + std::ostringstream error_message; + error_message << "invalid implicit conversion: "; + error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type) + << rang::fg::reset; + throw ParseError(error_message.str(), std::vector{node.begin()}); + } + break; + } + case ASTNodeDataType::int_t: { + if (node.is_type<language::integer>()) { + if (std::stoi(node.string()) == 0) { + break; + } + } + [[fallthrough]]; + } + default: { + this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); + } + } + } else if (target_data_type == ASTNodeDataType::matrix_t) { + switch (data_type) { + case ASTNodeDataType::list_t: { + const auto& content_type_list = data_type.contentTypeList(); + if (content_type_list.size() != (target_data_type.nbRows() * target_data_type.nbColumns())) { + std::ostringstream os; + os << "incompatible dimensions in affectation: expecting " + << target_data_type.nbRows() * target_data_type.nbColumns() << ", but provided " << content_type_list.size(); + throw ParseError(os.str(), std::vector{node.begin()}); + } + + Assert(content_type_list.size() == node.children.size()); + for (size_t i = 0; i < content_type_list.size(); ++i) { + const auto& child_type = *content_type_list[i]; + const auto& child_node = *node.children[i]; + Assert(child_type == child_node.m_data_type); + this->_checkIsNaturalExpressionConversion(child_node, child_type, + ASTNodeDataType::build<ASTNodeDataType::double_t>()); + } + + break; + } + case ASTNodeDataType::matrix_t: { + if ((data_type.nbRows() != target_data_type.nbRows()) or + (data_type.nbColumns() != target_data_type.nbColumns())) { + std::ostringstream error_message; + error_message << "invalid implicit conversion: "; + error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type) + << rang::fg::reset; + throw ParseError(error_message.str(), std::vector{node.begin()}); + } + break; + } + case ASTNodeDataType::int_t: { + if (node.is_type<language::integer>()) { + if (std::stoi(node.string()) == 0) { + break; + } + } + [[fallthrough]]; + } + default: { + this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); + } + } + } else if (target_data_type == ASTNodeDataType::tuple_t) { + const ASTNodeDataType& target_content_type = target_data_type.contentType(); + if (node.m_data_type == ASTNodeDataType::tuple_t) { + this->_checkIsNaturalExpressionConversion(node, data_type.contentType(), target_content_type); + } else if (node.m_data_type == ASTNodeDataType::list_t) { + for (const auto& child : node.children) { + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(*child, target_data_type.contentType()); + } + } else { + this->_checkIsNaturalExpressionConversion(node, data_type, target_content_type); + } + } else { + this->_checkIsNaturalTypeConversion(node, data_type, target_data_type); + } +} + +template <typename RToR1Conversion> +ASTNodeNaturalConversionChecker<RToR1Conversion>::ASTNodeNaturalConversionChecker( + const ASTNode& data_node, + const ASTNodeDataType& target_data_type) +{ + this->_checkIsNaturalExpressionConversion(data_node, data_node.m_data_type, target_data_type); +} + +template <typename RToR1Conversion> +ASTNodeNaturalConversionChecker<RToR1Conversion>::ASTNodeNaturalConversionChecker( + const ASTNodeSubDataType& data_node_sub_data_type, + const ASTNodeDataType& target_data_type) +{ + this->_checkIsNaturalExpressionConversion(data_node_sub_data_type.m_parent_node, data_node_sub_data_type.m_data_type, + target_data_type); +} + +template class ASTNodeNaturalConversionChecker<AllowRToR1Conversion>; +template class ASTNodeNaturalConversionChecker<DisallowRToR1Conversion>; diff --git a/src/language/ast/ASTNodeNaturalConversionChecker.hpp b/src/language/utils/ASTNodeNaturalConversionChecker.hpp similarity index 81% rename from src/language/ast/ASTNodeNaturalConversionChecker.hpp rename to src/language/utils/ASTNodeNaturalConversionChecker.hpp index 9d89e3b4925b6aaebfbc06c92fe30164e9b061fd..9e96e49b90d5f7931aa4796a10f67dee0cfff6db 100644 --- a/src/language/ast/ASTNodeNaturalConversionChecker.hpp +++ b/src/language/utils/ASTNodeNaturalConversionChecker.hpp @@ -2,12 +2,23 @@ #define AST_NODE_NATURAL_CONVERSION_CHECKER_HPP #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeDataType.hpp> #include <language/ast/ASTNodeSubDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> +struct AllowRToR1Conversion +{ +}; + +struct DisallowRToR1Conversion +{ +}; + +template <typename RToR1Conversion = DisallowRToR1Conversion> class ASTNodeNaturalConversionChecker { private: + using RToR1ConversionStrategy = RToR1Conversion; + void _checkIsNaturalTypeConversion(const ASTNode& ast_node, const ASTNodeDataType& data_type, const ASTNodeDataType& target_data_type) const; diff --git a/src/language/utils/ASTPrinter.cpp b/src/language/utils/ASTPrinter.cpp index 9bf18bdcf88f6399283750dac9ab0b38d8bf0400..8cc88b530997eb3cc9cccb176c3812c41dceea1f 100644 --- a/src/language/utils/ASTPrinter.cpp +++ b/src/language/utils/ASTPrinter.cpp @@ -14,9 +14,10 @@ ASTPrinter::_print(std::ostream& os, const ASTNode& node) const } os << rang::fg::reset; - if (node.is_type<language::name>() or node.is_type<language::literal>() or node.is_type<language::integer>() or - node.is_type<language::real>()) { + if (node.is_type<language::name>() or node.is_type<language::integer>() or node.is_type<language::real>()) { os << ':' << rang::fgB::green << node.string() << rang::fg::reset; + } else if (node.is_type<language::literal>()) { + os << ":\"" << rang::fgB::green << node.string() << rang::fg::reset << '"'; } if (m_info & static_cast<InfoBaseType>(Info::data_type)) { diff --git a/src/language/utils/AffectationMangler.hpp b/src/language/utils/AffectationMangler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..85aaab35bfbabde5b9bdbc9b250f8440e68829e6 --- /dev/null +++ b/src/language/utils/AffectationMangler.hpp @@ -0,0 +1,52 @@ +#ifndef AFFECTATION_MANGLER_HPP +#define AFFECTATION_MANGLER_HPP + +#include <language/utils/ASTNodeDataType.hpp> +#include <utils/Exceptions.hpp> + +#include <string> + +namespace language +{ +struct eq_op; +struct multiplyeq_op; +struct divideeq_op; +struct pluseq_op; +struct minuseq_op; +} // namespace language + +template <typename AffectationOperatorT> +std::string +affectationMangler(const ASTNodeDataType& lhs, const ASTNodeDataType& rhs) +{ + const std::string lhs_name = dataTypeName(lhs); + + const std::string operator_name = [] { + if constexpr (std::is_same_v<language::eq_op, AffectationOperatorT>) { + return "="; + } else if constexpr (std::is_same_v<language::multiplyeq_op, AffectationOperatorT>) { + return "*="; + } else if constexpr (std::is_same_v<language::divideeq_op, AffectationOperatorT>) { + return "/="; + } else if constexpr (std::is_same_v<language::pluseq_op, AffectationOperatorT>) { + return "+="; + } else if constexpr (std::is_same_v<language::minuseq_op, AffectationOperatorT>) { + return "-="; + } else { + static_assert(std::is_same_v<language::eq_op, AffectationOperatorT>, "undefined affectation operator"); + } + }(); + + const std::string rhs_name = [&]() -> std::string { + if (rhs == ASTNodeDataType::list_t) { + return "list"; + } else if (rhs == ASTNodeDataType::tuple_t) { + return "tuple"; + } else { + return dataTypeName(rhs); + } + }(); + return lhs_name + " " + operator_name + " " + rhs_name; +} + +#endif // AFFECTATION_MANGLER_HPP diff --git a/src/language/utils/AffectationProcessorBuilder.hpp b/src/language/utils/AffectationProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8d11599a01c67566e0f88fcb34cad581bc76de45 --- /dev/null +++ b/src/language/utils/AffectationProcessorBuilder.hpp @@ -0,0 +1,99 @@ +#ifndef AFFECTATION_PROCESSOR_BUILDER_HPP +#define AFFECTATION_PROCESSOR_BUILDER_HPP + +#include <algebra/TinyVector.hpp> +#include <language/PEGGrammar.hpp> +#include <language/node_processor/AffectationProcessor.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/IAffectationProcessorBuilder.hpp> + +#include <type_traits> + +template <typename OperatorT, typename ValueT, typename DataT> +class AffectationProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_same_v<DataT, int64_t> and + std::is_same_v<OperatorT, language::eq_op>) { + // Special treatment for the case 0 -> R^1 + if ((node.children[1]->is_type<language::integer>()) and (std::stoi(node.children[1]->string()) == 0)) { + return std::make_unique<AffectationFromZeroProcessor<ValueT>>(node); + } else { + return std::make_unique<AffectationProcessor<OperatorT, ValueT, DataT>>(node); + } + } else { + return std::make_unique<AffectationProcessor<OperatorT, ValueT, DataT>>(node); + } + } +}; + +template <typename ValueT> +class AffectationToTupleProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationToTupleProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<AffectationToTupleProcessor<ValueT>>(node); + } +}; + +template <typename ValueT> +class AffectationToTupleFromListProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationToTupleFromListProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + ASTNodeNaturalConversionChecker(*node.children[1], node.children[0]->m_data_type); + return std::make_unique<AffectationToTupleFromListProcessor<ValueT>>(node); + } +}; + +template <typename OperatorT, typename ValueT> +class AffectationToTinyVectorFromListProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationToTinyVectorFromListProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<AffectationToTinyVectorFromListProcessor<OperatorT, ValueT>>(node); + } +}; + +template <typename OperatorT, typename ValueT> +class AffectationToTinyMatrixFromListProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationToTinyMatrixFromListProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<AffectationToTinyMatrixFromListProcessor<OperatorT, ValueT>>(node); + } +}; + +template <typename OperatorT, typename ValueT> +class AffectationFromZeroProcessorBuilder final : public IAffectationProcessorBuilder +{ + public: + AffectationFromZeroProcessorBuilder() = default; + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + if (std::stoi(node.children[1]->string()) == 0) { + return std::make_unique<AffectationFromZeroProcessor<ValueT>>(node); + } else { + throw ParseError("invalid integral value (0 is the solely valid value)", std::vector{node.children[1]->begin()}); + } + } +}; + +#endif // AFFECTATION_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/AffectationRegisterForB.cpp b/src/language/utils/AffectationRegisterForB.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69feaee7066a7d95f17bf539df75e46cf3fe27c5 --- /dev/null +++ b/src/language/utils/AffectationRegisterForB.cpp @@ -0,0 +1,10 @@ +#include <language/utils/AffectationRegisterForB.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +AffectationRegisterForB::AffectationRegisterForB() +{ + BasicAffectationRegisterFor<bool>{}; +} diff --git a/src/language/utils/AffectationRegisterForB.hpp b/src/language/utils/AffectationRegisterForB.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d6507ff7a1edf0d48839924a8b3aa051ff441c09 --- /dev/null +++ b/src/language/utils/AffectationRegisterForB.hpp @@ -0,0 +1,10 @@ +#ifndef AFFECTATION_REGISTER_FOR_B_HPP +#define AFFECTATION_REGISTER_FOR_B_HPP + +class AffectationRegisterForB +{ + public: + AffectationRegisterForB(); +}; + +#endif // AFFECTATION_REGISTER_FOR_B_HPP diff --git a/src/language/utils/AffectationRegisterForN.cpp b/src/language/utils/AffectationRegisterForN.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7d54008f22845f8233b8475091c8cec774c6756 --- /dev/null +++ b/src/language/utils/AffectationRegisterForN.cpp @@ -0,0 +1,120 @@ +#include <language/utils/AffectationRegisterForN.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +AffectationRegisterForN::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository + .addAffectation<language::eq_op>(N, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, uint64_t, bool>>()); + + repository.addAffectation< + language::eq_op>(N, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, uint64_t, int64_t>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(N), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<uint64_t>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(N), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<uint64_t>>()); +} + +void +AffectationRegisterForN::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addAffectation< + language::pluseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, uint64_t, bool>>()); + + repository.addAffectation< + language::pluseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, uint64_t, uint64_t>>()); + + repository.addAffectation< + language::pluseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, uint64_t, int64_t>>()); +} + +void +AffectationRegisterForN::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addAffectation< + language::minuseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, uint64_t, bool>>()); + + repository.addAffectation< + language::minuseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, uint64_t, uint64_t>>()); + + repository.addAffectation< + language::minuseq_op>(N, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, uint64_t, int64_t>>()); +} + +void +AffectationRegisterForN::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addAffectation< + language::multiplyeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::multiplyeq_op, uint64_t, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, uint64_t, uint64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, uint64_t, int64_t>>()); +} + +void +AffectationRegisterForN::_register_divideeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addAffectation< + language::divideeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, uint64_t, bool>>()); + + repository.addAffectation< + language::divideeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, uint64_t, uint64_t>>()); + + repository.addAffectation< + language::divideeq_op>(N, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, uint64_t, int64_t>>()); +} + +AffectationRegisterForN::AffectationRegisterForN() +{ + BasicAffectationRegisterFor<uint64_t>{}; + + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); + this->_register_divideeq_op(); +} diff --git a/src/language/utils/AffectationRegisterForN.hpp b/src/language/utils/AffectationRegisterForN.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e97a652c094635fbaa54b9b267c37f3a169367a0 --- /dev/null +++ b/src/language/utils/AffectationRegisterForN.hpp @@ -0,0 +1,17 @@ +#ifndef AFFECTATION_REGISTER_FOR_N_HPP +#define AFFECTATION_REGISTER_FOR_N_HPP + +class AffectationRegisterForN +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + void _register_divideeq_op(); + + public: + AffectationRegisterForN(); +}; + +#endif // AFFECTATION_REGISTER_FOR_N_HPP diff --git a/src/language/utils/AffectationRegisterForR.cpp b/src/language/utils/AffectationRegisterForR.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae7c63c4031f39871f64713c39cc7dcba8e0bae4 --- /dev/null +++ b/src/language/utils/AffectationRegisterForR.cpp @@ -0,0 +1,141 @@ +#include <language/utils/AffectationRegisterForR.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +AffectationRegisterForR::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository + .addAffectation<language::eq_op>(R, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, double, bool>>()); + + repository.addAffectation< + language::eq_op>(R, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, double, uint64_t>>()); + + repository + .addAffectation<language::eq_op>(R, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, double, int64_t>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<double>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<double>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<double>>()); +} + +void +AffectationRegisterForR::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addAffectation< + language::pluseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, double, bool>>()); + + repository.addAffectation< + language::pluseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, double, uint64_t>>()); + + repository.addAffectation< + language::pluseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, double, int64_t>>()); + + repository.addAffectation< + language::pluseq_op>(R, R, std::make_shared<AffectationProcessorBuilder<language::pluseq_op, double, double>>()); +} + +void +AffectationRegisterForR::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addAffectation< + language::minuseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, double, bool>>()); + + repository.addAffectation< + language::minuseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, double, uint64_t>>()); + + repository.addAffectation< + language::minuseq_op>(R, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, double, int64_t>>()); + + repository.addAffectation< + language::minuseq_op>(R, R, std::make_shared<AffectationProcessorBuilder<language::minuseq_op, double, double>>()); +} + +void +AffectationRegisterForR::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addAffectation< + language::multiplyeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::multiplyeq_op, double, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, double, uint64_t>>()); + + repository.addAffectation< + language::multiplyeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::multiplyeq_op, double, int64_t>>()); + + repository.addAffectation< + language::multiplyeq_op>(R, R, + std::make_shared<AffectationProcessorBuilder<language::multiplyeq_op, double, double>>()); +} + +void +AffectationRegisterForR::_register_divideeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addAffectation< + language::divideeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, double, bool>>()); + + repository.addAffectation< + language::divideeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, double, uint64_t>>()); + + repository.addAffectation< + language::divideeq_op>(R, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, double, int64_t>>()); + + repository.addAffectation< + language::divideeq_op>(R, R, + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, double, double>>()); +} + +AffectationRegisterForR::AffectationRegisterForR() +{ + BasicAffectationRegisterFor<double>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); + this->_register_divideeq_op(); +} diff --git a/src/language/utils/AffectationRegisterForR.hpp b/src/language/utils/AffectationRegisterForR.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a004a92a276edad8dc52b015c2802aa4055aa11 --- /dev/null +++ b/src/language/utils/AffectationRegisterForR.hpp @@ -0,0 +1,17 @@ +#ifndef AFFECTATION_REGISTER_FOR_R_HPP +#define AFFECTATION_REGISTER_FOR_R_HPP + +class AffectationRegisterForR +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + void _register_divideeq_op(); + + public: + AffectationRegisterForR(); +}; + +#endif // AFFECTATION_REGISTER_FOR_R_HPP diff --git a/src/language/utils/AffectationRegisterForRn.cpp b/src/language/utils/AffectationRegisterForRn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d38dc64c1d245d4d6c530e891ffd6eaecd63372 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRn.cpp @@ -0,0 +1,143 @@ +#include <language/utils/AffectationRegisterForRn.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +AffectationRegisterForRn<Dimension>::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository.addAffectation< + language::eq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationFromZeroProcessorBuilder<language::eq_op, TinyVector<Dimension>>>()); + + repository.addAffectation<language::eq_op>(Rn, + ASTNodeDataType::build<ASTNodeDataType::list_t>( + std::vector<std::shared_ptr<const ASTNodeDataType>>{}), + std::make_shared<AffectationToTinyVectorFromListProcessorBuilder< + language::eq_op, TinyVector<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyVector<Dimension>>>()); +} + +template <> +void +AffectationRegisterForRn<1>::_register_eq_op() +{ + constexpr size_t Dimension = 1; + + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository.addAffectation< + language::eq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyVector<Dimension>, bool>>()); + + repository.addAffectation< + language::eq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyVector<Dimension>, uint64_t>>()); + + repository.addAffectation< + language::eq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyVector<Dimension>, int64_t>>()); + + repository.addAffectation< + language::eq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyVector<Dimension>, double>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rn), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyVector<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rn), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyVector<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyVector<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rn), + ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyVector<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRn<Dimension>::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository + .addAffectation<language::pluseq_op>(Rn, Rn, + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, TinyVector<Dimension>, TinyVector<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRn<Dimension>::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository + .addAffectation<language::minuseq_op>(Rn, Rn, + std::make_shared<AffectationProcessorBuilder< + language::minuseq_op, TinyVector<Dimension>, TinyVector<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRn<Dimension>::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository.addAffectation<language::multiplyeq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyVector<Dimension>, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyVector<Dimension>, uint64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyVector<Dimension>, int64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyVector<Dimension>, double>>()); +} + +template <size_t Dimension> +AffectationRegisterForRn<Dimension>::AffectationRegisterForRn() +{ + BasicAffectationRegisterFor<TinyVector<Dimension>>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); +} + +template class AffectationRegisterForRn<1>; +template class AffectationRegisterForRn<2>; +template class AffectationRegisterForRn<3>; diff --git a/src/language/utils/AffectationRegisterForRn.hpp b/src/language/utils/AffectationRegisterForRn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..337a340b81b7a59d35ee11dd5946d2313560ebe8 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRn.hpp @@ -0,0 +1,19 @@ +#ifndef AFFECTATION_REGISTER_FOR_RN_HPP +#define AFFECTATION_REGISTER_FOR_RN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class AffectationRegisterForRn +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + + public: + AffectationRegisterForRn(); +}; + +#endif // AFFECTATION_REGISTER_FOR_RN_HPP diff --git a/src/language/utils/AffectationRegisterForRnxn.cpp b/src/language/utils/AffectationRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efa95d92bbcfd32af37f6c3f69477d7ee985a560 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRnxn.cpp @@ -0,0 +1,143 @@ +#include <language/utils/AffectationRegisterForRnxn.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationFromZeroProcessorBuilder<language::eq_op, TinyMatrix<Dimension>>>()); + + repository.addAffectation<language::eq_op>(Rnxn, + ASTNodeDataType::build<ASTNodeDataType::list_t>( + std::vector<std::shared_ptr<const ASTNodeDataType>>{}), + std::make_shared<AffectationToTinyMatrixFromListProcessorBuilder< + language::eq_op, TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); +} + +template <> +void +AffectationRegisterForRnxn<1>::_register_eq_op() +{ + constexpr size_t Dimension = 1; + + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, bool>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, uint64_t>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, int64_t>>()); + + repository.addAffectation< + language::eq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, TinyMatrix<Dimension>, double>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); + + repository + .addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Rnxn), + ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addAffectation<language::pluseq_op>(Rnxn, Rnxn, + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addAffectation<language::minuseq_op>(Rnxn, Rnxn, + std::make_shared<AffectationProcessorBuilder< + language::minuseq_op, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +void +AffectationRegisterForRnxn<Dimension>::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, uint64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, int64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Rnxn, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, TinyMatrix<Dimension>, double>>()); +} + +template <size_t Dimension> +AffectationRegisterForRnxn<Dimension>::AffectationRegisterForRnxn() +{ + BasicAffectationRegisterFor<TinyMatrix<Dimension>>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); +} + +template class AffectationRegisterForRnxn<1>; +template class AffectationRegisterForRnxn<2>; +template class AffectationRegisterForRnxn<3>; diff --git a/src/language/utils/AffectationRegisterForRnxn.hpp b/src/language/utils/AffectationRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33f2ad07d3911f2b7488cd10b321be5e4f59f5e4 --- /dev/null +++ b/src/language/utils/AffectationRegisterForRnxn.hpp @@ -0,0 +1,19 @@ +#ifndef AFFECTATION_REGISTER_FOR_RNXN_HPP +#define AFFECTATION_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class AffectationRegisterForRnxn +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + + public: + AffectationRegisterForRnxn(); +}; + +#endif // AFFECTATION_REGISTER_FOR_RNXN_HPP diff --git a/src/language/utils/AffectationRegisterForString.cpp b/src/language/utils/AffectationRegisterForString.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7946aeb45fc72e7569e5d8299e184c220baa7106 --- /dev/null +++ b/src/language/utils/AffectationRegisterForString.cpp @@ -0,0 +1,152 @@ +#include <language/utils/AffectationRegisterForString.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +AffectationRegisterForString::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto string_t = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, bool>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, uint64_t>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, int64_t>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, double_t>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyVector<1>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyVector<2>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyVector<3>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<1>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<2>>>()); + + repository.addAffectation< + language::eq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationProcessorBuilder<language::eq_op, std::string, TinyMatrix<3>>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_t), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationToTupleProcessorBuilder<std::string>>()); +} + +void +AffectationRegisterForString::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto string_t = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + + repository.addAffectation<language::pluseq_op>(string_t, string_t, + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, std::string>>()); + + repository.addAffectation< + language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, bool>>()); + + repository.addAffectation< + language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, uint64_t>>()); + + repository.addAffectation< + language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, int64_t>>()); + + repository.addAffectation< + language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::double_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, std::string, double>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyVector<1>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyVector<2>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyVector<3>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<1>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<2>>>()); + + repository.addAffectation<language::pluseq_op>(string_t, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + std::make_shared<AffectationProcessorBuilder< + language::pluseq_op, std::string, TinyMatrix<3>>>()); +} + +AffectationRegisterForString::AffectationRegisterForString() +{ + BasicAffectationRegisterFor<std::string>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); +} diff --git a/src/language/utils/AffectationRegisterForString.hpp b/src/language/utils/AffectationRegisterForString.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6b6a47d4fcfe913f486d1069e73920795e1d451 --- /dev/null +++ b/src/language/utils/AffectationRegisterForString.hpp @@ -0,0 +1,14 @@ +#ifndef AFFECTATION_REGISTER_FOR_STRING_HPP +#define AFFECTATION_REGISTER_FOR_STRING_HPP + +class AffectationRegisterForString +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + + public: + AffectationRegisterForString(); +}; + +#endif // AFFECTATION_REGISTER_FOR_STRING_HPP diff --git a/src/language/utils/AffectationRegisterForZ.cpp b/src/language/utils/AffectationRegisterForZ.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9937a0edbaedce2ae054d02fc13d7c4e1168c698 --- /dev/null +++ b/src/language/utils/AffectationRegisterForZ.cpp @@ -0,0 +1,119 @@ +#include <language/utils/AffectationRegisterForZ.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +AffectationRegisterForZ::_register_eq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository + .addAffectation<language::eq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, int64_t, bool>>()); + + repository.addAffectation< + language::eq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::eq_op, int64_t, uint64_t>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Z), + ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<int64_t>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(Z), + ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationToTupleProcessorBuilder<int64_t>>()); +} + +void +AffectationRegisterForZ::_register_pluseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addAffectation< + language::pluseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, int64_t, bool>>()); + + repository.addAffectation< + language::pluseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, int64_t, uint64_t>>()); + + repository.addAffectation< + language::pluseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::pluseq_op, int64_t, int64_t>>()); +} + +void +AffectationRegisterForZ::_register_minuseq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addAffectation< + language::minuseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, int64_t, bool>>()); + + repository.addAffectation< + language::minuseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, int64_t, uint64_t>>()); + + repository.addAffectation< + language::minuseq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::minuseq_op, int64_t, int64_t>>()); +} + +void +AffectationRegisterForZ::_register_multiplyeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addAffectation< + language::multiplyeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::multiplyeq_op, int64_t, bool>>()); + + repository.addAffectation<language::multiplyeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, int64_t, uint64_t>>()); + + repository.addAffectation<language::multiplyeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder< + language::multiplyeq_op, int64_t, int64_t>>()); +} + +void +AffectationRegisterForZ::_register_divideeq_op() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addAffectation< + language::divideeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::bool_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, int64_t, bool>>()); + + repository.addAffectation< + language::divideeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, int64_t, uint64_t>>()); + + repository.addAffectation< + language::divideeq_op>(Z, ASTNodeDataType::build<ASTNodeDataType::int_t>(), + std::make_shared<AffectationProcessorBuilder<language::divideeq_op, int64_t, int64_t>>()); +} + +AffectationRegisterForZ::AffectationRegisterForZ() +{ + BasicAffectationRegisterFor<int64_t>{}; + this->_register_eq_op(); + this->_register_pluseq_op(); + this->_register_minuseq_op(); + this->_register_multiplyeq_op(); + this->_register_divideeq_op(); +} diff --git a/src/language/utils/AffectationRegisterForZ.hpp b/src/language/utils/AffectationRegisterForZ.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a559c94fda9ef5a089d85af9f9773693ce9c9fb4 --- /dev/null +++ b/src/language/utils/AffectationRegisterForZ.hpp @@ -0,0 +1,17 @@ +#ifndef AFFECTATION_REGISTER_FOR_Z_HPP +#define AFFECTATION_REGISTER_FOR_Z_HPP + +class AffectationRegisterForZ +{ + private: + void _register_eq_op(); + void _register_pluseq_op(); + void _register_minuseq_op(); + void _register_multiplyeq_op(); + void _register_divideeq_op(); + + public: + AffectationRegisterForZ(); +}; + +#endif // AFFECTATION_REGISTER_FOR_Z_HPP diff --git a/src/language/utils/BasicAffectationRegistrerFor.hpp b/src/language/utils/BasicAffectationRegistrerFor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ccae4c32b596b39ce7b663da38e31e70c8dfbebc --- /dev/null +++ b/src/language/utils/BasicAffectationRegistrerFor.hpp @@ -0,0 +1,36 @@ +#ifndef BASIC_AFFECTATION_REGISTRER_FOR_HPP +#define BASIC_AFFECTATION_REGISTRER_FOR_HPP + +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <typename T> +class BasicAffectationRegisterFor +{ + public: + BasicAffectationRegisterFor() : BasicAffectationRegisterFor(ast_node_data_type_from<T>) {} + + BasicAffectationRegisterFor(const ASTNodeDataType& ast_node_data_type) + { + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addAffectation<language::eq_op>(ast_node_data_type, ast_node_data_type, + std::make_shared<AffectationProcessorBuilder<language::eq_op, T, T>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ast_node_data_type), + ast_node_data_type, + std::make_shared<AffectationToTupleProcessorBuilder<T>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ast_node_data_type), + ASTNodeDataType::build<ASTNodeDataType::list_t>( + std::vector<std::shared_ptr<const ASTNodeDataType>>{}), + std::make_shared<AffectationToTupleFromListProcessorBuilder<T>>()); + + repository.addAffectation<language::eq_op>(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ast_node_data_type), + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ast_node_data_type), + std::make_shared<AffectationToTupleFromListProcessorBuilder<T>>()); + } +}; + +#endif // BASIC_AFFECTATION_REGISTRER_FOR_HPP diff --git a/src/language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp b/src/language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp new file mode 100644 index 0000000000000000000000000000000000000000..758117182df0c170c1c00c00ea38938668eae2c6 --- /dev/null +++ b/src/language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp @@ -0,0 +1,34 @@ +#ifndef BASIC_BINARY_OPERATOR_REGISTER_COMPARISON_OF_HPP +#define BASIC_BINARY_OPERATOR_REGISTER_COMPARISON_OF_HPP + +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <typename A_DataT, typename B_DataT> +struct BasicBinaryOperatorRegisterComparisonOf +{ + BasicBinaryOperatorRegisterComparisonOf() + { + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::eqeq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::eqeq_op, bool, A_DataT, B_DataT>>()); + + repository.addBinaryOperator<language::not_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::not_eq_op, bool, A_DataT, B_DataT>>()); + + repository.addBinaryOperator<language::greater_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::greater_op, bool, A_DataT, B_DataT>>()); + + repository.addBinaryOperator<language::greater_or_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::greater_or_eq_op, bool, A_DataT, B_DataT>>()); + + repository.addBinaryOperator<language::lesser_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::lesser_op, bool, A_DataT, B_DataT>>()); + + repository.addBinaryOperator<language::lesser_or_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::lesser_or_eq_op, bool, A_DataT, B_DataT>>()); + } +}; + +#endif // BASIC_BINARY_OPERATOR_REGISTER_COMPARISON_OF_HPP diff --git a/src/language/utils/BinaryOperatorMangler.hpp b/src/language/utils/BinaryOperatorMangler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3648f3fb860a6207b70ad712aee9935b669e493 --- /dev/null +++ b/src/language/utils/BinaryOperatorMangler.hpp @@ -0,0 +1,67 @@ +#ifndef BINARY_OPERATOR_MANGLER_HPP +#define BINARY_OPERATOR_MANGLER_HPP + +#include <language/utils/ASTNodeDataType.hpp> +#include <utils/Exceptions.hpp> + +#include <string> + +namespace language +{ +struct multiply_op; +struct divide_op; +struct plus_op; +struct minus_op; + +struct or_op; +struct and_op; +struct xor_op; + +struct greater_op; +struct greater_or_eq_op; +struct lesser_op; +struct lesser_or_eq_op; +struct eqeq_op; +struct not_eq_op; +} // namespace language + +template <typename BinaryOperatorT> +std::string +binaryOperatorMangler(const ASTNodeDataType& lhs_data_type, const ASTNodeDataType& rhs_data_type) +{ + const std::string operator_name = [] { + if constexpr (std::is_same_v<BinaryOperatorT, language::multiply_op>) { + return "*"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::divide_op>) { + return "/"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::plus_op>) { + return "+"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::minus_op>) { + return "-"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::or_op>) { + return "or"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::and_op>) { + return "and"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::xor_op>) { + return "xor"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::greater_op>) { + return ">"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::greater_or_eq_op>) { + return ">="; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::lesser_op>) { + return "<"; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::lesser_or_eq_op>) { + return "<="; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::eqeq_op>) { + return "=="; + } else if constexpr (std::is_same_v<BinaryOperatorT, language::not_eq_op>) { + return "!="; + } else { + static_assert(std::is_same_v<language::multiply_op, BinaryOperatorT>, "undefined binary operator"); + } + }(); + + return dataTypeName(lhs_data_type) + " " + operator_name + " " + dataTypeName(rhs_data_type); +} + +#endif // BINARY_OPERATOR_MANGLER_HPP diff --git a/src/language/utils/BinaryOperatorProcessorBuilder.hpp b/src/language/utils/BinaryOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b8cb572d1e79c1882ed0288d79a4b73c6c130e5c --- /dev/null +++ b/src/language/utils/BinaryOperatorProcessorBuilder.hpp @@ -0,0 +1,43 @@ +#ifndef BINARY_OPERATOR_PROCESSOR_BUILDER_HPP +#define BINARY_OPERATOR_PROCESSOR_BUILDER_HPP + +#include <algebra/TinyVector.hpp> +#include <language/PEGGrammar.hpp> +#include <language/node_processor/BinaryExpressionProcessor.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/IBinaryOperatorProcessorBuilder.hpp> + +#include <type_traits> + +template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> +class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder +{ + public: + BinaryOperatorProcessorBuilder() = default; + + ASTNodeDataType + getDataTypeOfA() const + { + return ast_node_data_type_from<A_DataT>; + } + + ASTNodeDataType + getDataTypeOfB() const + { + return ast_node_data_type_from<B_DataT>; + } + + ASTNodeDataType + getReturnValueType() const + { + return ast_node_data_type_from<ValueT>; + } + + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<BinaryExpressionProcessor<OperatorT, ValueT, A_DataT, B_DataT>>(node); + } +}; + +#endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForB.cpp b/src/language/utils/BinaryOperatorRegisterForB.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cba803aaf3e3d7d4c72be7b69ca1a835203a49b2 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForB.cpp @@ -0,0 +1,70 @@ +#include <language/utils/BinaryOperatorRegisterForB.hpp> + +#include <language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp> + +void +BinaryOperatorRegisterForB::_register_comparisons() +{ + BasicBinaryOperatorRegisterComparisonOf<bool, bool>{}; +} + +void +BinaryOperatorRegisterForB::_register_logical_operators() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::and_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::and_op, bool, bool, bool>>()); + + repository.addBinaryOperator<language::or_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::or_op, bool, bool, bool>>()); + + repository.addBinaryOperator<language::xor_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::xor_op, bool, bool, bool>>()); +} + +void +BinaryOperatorRegisterForB::_register_plus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::plus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, uint64_t, bool, bool>>()); +} + +void +BinaryOperatorRegisterForB::_register_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, int64_t, bool, bool>>()); +} + +void +BinaryOperatorRegisterForB::_register_multiply() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, uint64_t, bool, bool>>()); +} + +void +BinaryOperatorRegisterForB::_register_divide() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::divide_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, uint64_t, bool, bool>>()); +} + +BinaryOperatorRegisterForB::BinaryOperatorRegisterForB() +{ + this->_register_comparisons(); + this->_register_logical_operators(); + this->_register_plus(); + this->_register_minus(); + this->_register_multiply(); + this->_register_divide(); +} diff --git a/src/language/utils/BinaryOperatorRegisterForB.hpp b/src/language/utils/BinaryOperatorRegisterForB.hpp new file mode 100644 index 0000000000000000000000000000000000000000..651411bf7f2cd8e8386d26c071220d36cb79a691 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForB.hpp @@ -0,0 +1,18 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_B_HPP +#define BINARY_OPERATOR_REGISTER_FOR_B_HPP + +class BinaryOperatorRegisterForB +{ + private: + void _register_comparisons(); + void _register_logical_operators(); + void _register_plus(); + void _register_minus(); + void _register_multiply(); + void _register_divide(); + + public: + BinaryOperatorRegisterForB(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_B_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForN.cpp b/src/language/utils/BinaryOperatorRegisterForN.cpp new file mode 100644 index 0000000000000000000000000000000000000000..66d589405545924af9e2de137f7738c9b40b8430 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForN.cpp @@ -0,0 +1,50 @@ +#include <language/utils/BinaryOperatorRegisterForN.hpp> + +#include <language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp> + +void +BinaryOperatorRegisterForN::_register_comparisons() +{ + BasicBinaryOperatorRegisterComparisonOf<bool, uint64_t>{}; + BasicBinaryOperatorRegisterComparisonOf<uint64_t, bool>{}; + + BasicBinaryOperatorRegisterComparisonOf<uint64_t, uint64_t>{}; +} + +template <typename OperatorT> +void +BinaryOperatorRegisterForN::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, uint64_t, uint64_t, bool>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, uint64_t, bool, uint64_t>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, uint64_t, uint64_t, uint64_t>>()); +} + +void +BinaryOperatorRegisterForN::_register_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, int64_t, uint64_t, bool>>()); + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, int64_t, bool, uint64_t>>()); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, int64_t, uint64_t, uint64_t>>()); +} + +BinaryOperatorRegisterForN::BinaryOperatorRegisterForN() +{ + this->_register_comparisons(); + this->_register_arithmetic<language::plus_op>(); + this->_register_minus(); + this->_register_arithmetic<language::multiply_op>(); + this->_register_arithmetic<language::divide_op>(); +} diff --git a/src/language/utils/BinaryOperatorRegisterForN.hpp b/src/language/utils/BinaryOperatorRegisterForN.hpp new file mode 100644 index 0000000000000000000000000000000000000000..80dc52b013f36e05e08f7bf561f2ac3b41ccbfbf --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForN.hpp @@ -0,0 +1,16 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_N_HPP +#define BINARY_OPERATOR_REGISTER_FOR_N_HPP + +class BinaryOperatorRegisterForN +{ + private: + template <typename OperatorT> + void _register_arithmetic(); + void _register_comparisons(); + void _register_minus(); + + public: + BinaryOperatorRegisterForN(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_N_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForR.cpp b/src/language/utils/BinaryOperatorRegisterForR.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f69988925e373fbd7369bff3d0bb87860b1aa3d --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForR.cpp @@ -0,0 +1,52 @@ +#include <language/utils/BinaryOperatorRegisterForR.hpp> + +#include <language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp> + +void +BinaryOperatorRegisterForR::_register_comparisons() +{ + BasicBinaryOperatorRegisterComparisonOf<bool, double>{}; + BasicBinaryOperatorRegisterComparisonOf<double, bool>{}; + + BasicBinaryOperatorRegisterComparisonOf<uint64_t, double>{}; + BasicBinaryOperatorRegisterComparisonOf<double, uint64_t>{}; + + BasicBinaryOperatorRegisterComparisonOf<int64_t, double>{}; + BasicBinaryOperatorRegisterComparisonOf<double, int64_t>{}; + + BasicBinaryOperatorRegisterComparisonOf<double, double>{}; +} + +template <typename OperatorT> +void +BinaryOperatorRegisterForR::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, double, bool>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, bool, double>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, double, uint64_t>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, uint64_t, double_t>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, double, int64_t>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, int64_t, double_t>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, double, double, double>>()); +} + +BinaryOperatorRegisterForR::BinaryOperatorRegisterForR() +{ + this->_register_comparisons(); + this->_register_arithmetic<language::plus_op>(); + this->_register_arithmetic<language::minus_op>(); + this->_register_arithmetic<language::multiply_op>(); + this->_register_arithmetic<language::divide_op>(); +} diff --git a/src/language/utils/BinaryOperatorRegisterForR.hpp b/src/language/utils/BinaryOperatorRegisterForR.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbde82fb3d36e565ef84d99b95b58c6c0556eb01 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForR.hpp @@ -0,0 +1,15 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_R_HPP +#define BINARY_OPERATOR_REGISTER_FOR_R_HPP + +class BinaryOperatorRegisterForR +{ + private: + template <typename OperatorT> + void _register_arithmetic(); + void _register_comparisons(); + + public: + BinaryOperatorRegisterForR(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_R_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForRn.cpp b/src/language/utils/BinaryOperatorRegisterForRn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41e5b031f7e3db02d8e59a9962e028c046f813a8 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRn.cpp @@ -0,0 +1,67 @@ +#include <language/utils/BinaryOperatorRegisterForRn.hpp> + +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +BinaryOperatorRegisterForRn<Dimension>::_register_comparisons() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<language::eqeq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::eqeq_op, bool, Rn, Rn>>()); + + repository.addBinaryOperator<language::not_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::not_eq_op, bool, Rn, Rn>>()); +} + +template <size_t Dimension> +void +BinaryOperatorRegisterForRn<Dimension>::_register_product_by_a_scalar() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, bool, Rn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, uint64_t, Rn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, int64_t, Rn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, double, Rn>>()); +} + +template <size_t Dimension> +template <typename OperatorT> +void +BinaryOperatorRegisterForRn<Dimension>::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<OperatorT>(std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, Rn, Rn, Rn>>()); +} + +template <size_t Dimension> +BinaryOperatorRegisterForRn<Dimension>::BinaryOperatorRegisterForRn() +{ + this->_register_comparisons(); + + this->_register_product_by_a_scalar(); + + this->_register_arithmetic<language::plus_op>(); + this->_register_arithmetic<language::minus_op>(); +} + +template class BinaryOperatorRegisterForRn<1>; +template class BinaryOperatorRegisterForRn<2>; +template class BinaryOperatorRegisterForRn<3>; diff --git a/src/language/utils/BinaryOperatorRegisterForRn.hpp b/src/language/utils/BinaryOperatorRegisterForRn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b9f27d4165d236b913e3393e342a144019fe6015 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRn.hpp @@ -0,0 +1,21 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_RN_HPP +#define BINARY_OPERATOR_REGISTER_FOR_RN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class BinaryOperatorRegisterForRn +{ + private: + void _register_comparisons(); + + void _register_product_by_a_scalar(); + + template <typename OperatorT> + void _register_arithmetic(); + + public: + BinaryOperatorRegisterForRn(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_RN_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.cpp b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3d8850e79a4129c9a2098baa14b869f48bb8d1ce --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.cpp @@ -0,0 +1,83 @@ +#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> + +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_comparisons() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<language::eqeq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::eqeq_op, bool, Rnxn, Rnxn>>()); + + repository.addBinaryOperator<language::not_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::not_eq_op, bool, Rnxn, Rnxn>>()); +} + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_scalar() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, bool, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, uint64_t, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, int64_t, Rnxn>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rnxn, double, Rnxn>>()); +} + +template <size_t Dimension> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_product_by_a_vector() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + using Rn = TinyVector<Dimension>; + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, Rn, Rnxn, Rn>>()); +} + +template <size_t Dimension> +template <typename OperatorT> +void +BinaryOperatorRegisterForRnxn<Dimension>::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + using Rnxn = TinyMatrix<Dimension>; + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, Rnxn, Rnxn, Rnxn>>()); +} + +template <size_t Dimension> +BinaryOperatorRegisterForRnxn<Dimension>::BinaryOperatorRegisterForRnxn() +{ + this->_register_comparisons(); + + this->_register_product_by_a_scalar(); + this->_register_product_by_a_vector(); + + this->_register_arithmetic<language::plus_op>(); + this->_register_arithmetic<language::minus_op>(); + this->_register_arithmetic<language::multiply_op>(); +} + +template class BinaryOperatorRegisterForRnxn<1>; +template class BinaryOperatorRegisterForRnxn<2>; +template class BinaryOperatorRegisterForRnxn<3>; diff --git a/src/language/utils/BinaryOperatorRegisterForRnxn.hpp b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..594740b629b0a7201d64139fd8e2ce4a12b38cd2 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForRnxn.hpp @@ -0,0 +1,22 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP +#define BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class BinaryOperatorRegisterForRnxn +{ + private: + void _register_comparisons(); + + void _register_product_by_a_scalar(); + void _register_product_by_a_vector(); + + template <typename OperatorT> + void _register_arithmetic(); + + public: + BinaryOperatorRegisterForRnxn(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_RNXN_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForString.cpp b/src/language/utils/BinaryOperatorRegisterForString.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e89cbe00502938d817405c3f3f3d6d991f0433e --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForString.cpp @@ -0,0 +1,43 @@ +#include <language/utils/BinaryOperatorRegisterForString.hpp> + +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> +#include <language/utils/ConcatExpressionProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +BinaryOperatorRegisterForString::_register_comparisons() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::eqeq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::eqeq_op, bool, std::string, std::string>>()); + + repository.addBinaryOperator<language::not_eq_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::not_eq_op, bool, std::string, std::string>>()); +} + +template <typename RHS_T> +void +BinaryOperatorRegisterForString::_register_concat() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::plus_op>(std::make_shared<ConcatExpressionProcessorBuilder<RHS_T>>()); +} + +BinaryOperatorRegisterForString::BinaryOperatorRegisterForString() +{ + this->_register_comparisons(); + + this->_register_concat<bool>(); + this->_register_concat<unsigned long>(); + this->_register_concat<long>(); + this->_register_concat<double>(); + this->_register_concat<TinyVector<1>>(); + this->_register_concat<TinyVector<2>>(); + this->_register_concat<TinyVector<3>>(); + this->_register_concat<TinyMatrix<1>>(); + this->_register_concat<TinyMatrix<2>>(); + this->_register_concat<TinyMatrix<3>>(); + this->_register_concat<std::string>(); +} diff --git a/src/language/utils/BinaryOperatorRegisterForString.hpp b/src/language/utils/BinaryOperatorRegisterForString.hpp new file mode 100644 index 0000000000000000000000000000000000000000..36af13ee6442086244ee5234ad496f3580f79c22 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForString.hpp @@ -0,0 +1,16 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_STRING_HPP +#define BINARY_OPERATOR_REGISTER_FOR_STRING_HPP + +class BinaryOperatorRegisterForString +{ + private: + void _register_comparisons(); + + template <typename RHS_T> + void _register_concat(); + + public: + BinaryOperatorRegisterForString(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_STRING_HPP diff --git a/src/language/utils/BinaryOperatorRegisterForZ.cpp b/src/language/utils/BinaryOperatorRegisterForZ.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc8b22d117a8bebc557efff84083b8e45dcf8602 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForZ.cpp @@ -0,0 +1,44 @@ +#include <language/utils/BinaryOperatorRegisterForZ.hpp> + +#include <language/utils/BasicBinaryOperatorRegisterComparisonOf.hpp> + +void +BinaryOperatorRegisterForZ::_register_comparisons() +{ + BasicBinaryOperatorRegisterComparisonOf<bool, int64_t>{}; + BasicBinaryOperatorRegisterComparisonOf<int64_t, bool>{}; + + BasicBinaryOperatorRegisterComparisonOf<uint64_t, int64_t>{}; + BasicBinaryOperatorRegisterComparisonOf<int64_t, uint64_t>{}; + + BasicBinaryOperatorRegisterComparisonOf<int64_t, int64_t>{}; +} + +template <typename OperatorT> +void +BinaryOperatorRegisterForZ::_register_arithmetic() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, int64_t, int64_t, bool>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, int64_t, bool, int64_t>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, int64_t, int64_t, uint64_t>>()); + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, int64_t, uint64_t, int64_t>>()); + + repository.addBinaryOperator<OperatorT>( + std::make_shared<BinaryOperatorProcessorBuilder<OperatorT, int64_t, int64_t, int64_t>>()); +} + +BinaryOperatorRegisterForZ::BinaryOperatorRegisterForZ() +{ + this->_register_comparisons(); + this->_register_arithmetic<language::plus_op>(); + this->_register_arithmetic<language::minus_op>(); + this->_register_arithmetic<language::multiply_op>(); + this->_register_arithmetic<language::divide_op>(); +} diff --git a/src/language/utils/BinaryOperatorRegisterForZ.hpp b/src/language/utils/BinaryOperatorRegisterForZ.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f9a7c30c2f8e3674e6ebda7e9069153a6162f13 --- /dev/null +++ b/src/language/utils/BinaryOperatorRegisterForZ.hpp @@ -0,0 +1,15 @@ +#ifndef BINARY_OPERATOR_REGISTER_FOR_Z_HPP +#define BINARY_OPERATOR_REGISTER_FOR_Z_HPP + +class BinaryOperatorRegisterForZ +{ + private: + template <typename OperatorT> + void _register_arithmetic(); + void _register_comparisons(); + + public: + BinaryOperatorRegisterForZ(); +}; + +#endif // BINARY_OPERATOR_REGISTER_FOR_Z_HPP diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index 3974398e1803856585f91f09565b29da3d877533..afb991acadcb36236bcd55a34c7840ca399ab6f0 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -1,7 +1,7 @@ #ifndef BUILTIN_FUNCTION_EMBEDDER_HPP #define BUILTIN_FUNCTION_EMBEDDER_HPP -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/DataHandler.hpp> #include <language/utils/DataVariant.hpp> @@ -126,14 +126,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder PUGS_INLINE ASTNodeDataType _getDataType() const { - if constexpr (is_data_variant_v<T>) { - return ast_node_data_type_from<T>; - } else if constexpr (std::is_same_v<void, T>) { - return ASTNodeDataType::void_t; - } else { - Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); - return ast_node_data_type_from<T>; - } + Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); + return ast_node_data_type_from<T>; } template <size_t I> @@ -262,14 +256,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder PUGS_INLINE ASTNodeDataType _getDataType() const { - if constexpr (is_data_variant_v<T>) { - return ast_node_data_type_from<T>; - } else if constexpr (std::is_same_v<void, T>) { - return ASTNodeDataType::void_t; - } else { - Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); - return ast_node_data_type_from<T>; - } + Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t); + return ast_node_data_type_from<T>; } public: diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 6a8216a9ca62a440c7eb84111cbb87d26ae6b1d3..22c0ae52f184a7f18d852057cbe9b0616f1905e3 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -1,10 +1,37 @@ # ------------------- Source files -------------------- add_library(PugsLanguageUtils + AffectationRegisterForB.cpp + AffectationRegisterForN.cpp + AffectationRegisterForR.cpp + AffectationRegisterForRn.cpp + AffectationRegisterForRnxn.cpp + AffectationRegisterForString.cpp + AffectationRegisterForZ.cpp ASTDotPrinter.cpp + ASTExecutionInfo.cpp + ASTNodeDataType.cpp + ASTNodeNaturalConversionChecker.cpp ASTPrinter.cpp + BinaryOperatorRegisterForB.cpp + BinaryOperatorRegisterForN.cpp + BinaryOperatorRegisterForR.cpp + BinaryOperatorRegisterForRn.cpp + BinaryOperatorRegisterForRnxn.cpp + BinaryOperatorRegisterForString.cpp + BinaryOperatorRegisterForZ.cpp DataVariant.cpp EmbeddedData.cpp + IncDecOperatorRegisterForN.cpp + IncDecOperatorRegisterForR.cpp + IncDecOperatorRegisterForZ.cpp + OperatorRepository.cpp + UnaryOperatorRegisterForB.cpp + UnaryOperatorRegisterForN.cpp + UnaryOperatorRegisterForR.cpp + UnaryOperatorRegisterForRn.cpp + UnaryOperatorRegisterForRnxn.cpp + UnaryOperatorRegisterForZ.cpp ) diff --git a/src/language/utils/ConcatExpressionProcessorBuilder.hpp b/src/language/utils/ConcatExpressionProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8fed359622883368f8bdc3cbdfbebf524e314c50 --- /dev/null +++ b/src/language/utils/ConcatExpressionProcessorBuilder.hpp @@ -0,0 +1,42 @@ +#ifndef CONCAT_EXPRESSION_PROCESSOR_BUILDER_HPP +#define CONCAT_EXPRESSION_PROCESSOR_BUILDER_HPP + +#include <language/PEGGrammar.hpp> +#include <language/node_processor/ConcatExpressionProcessor.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/IBinaryOperatorProcessorBuilder.hpp> + +#include <type_traits> + +template <typename B_DataT> +class ConcatExpressionProcessorBuilder final : public IBinaryOperatorProcessorBuilder +{ + public: + ConcatExpressionProcessorBuilder() = default; + + ASTNodeDataType + getDataTypeOfA() const + { + return ast_node_data_type_from<std::string>; + } + + ASTNodeDataType + getDataTypeOfB() const + { + return ast_node_data_type_from<B_DataT>; + } + + ASTNodeDataType + getReturnValueType() const + { + return ast_node_data_type_from<std::string>; + } + + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<ConcatExpressionProcessor<B_DataT>>(node); + } +}; + +#endif // CONCAT_EXPRESSION_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/DataVariant.hpp b/src/language/utils/DataVariant.hpp index 13744172a73c2e3c54d97412cd654ae7211ad42c..6add9e52853bcbf58ee42e4f2869649dab43d37d 100644 --- a/src/language/utils/DataVariant.hpp +++ b/src/language/utils/DataVariant.hpp @@ -1,6 +1,7 @@ #ifndef DATA_VARIANT_HPP #define DATA_VARIANT_HPP +#include <algebra/TinyMatrix.hpp> #include <algebra/TinyVector.hpp> #include <language/utils/EmbeddedData.hpp> #include <language/utils/FunctionSymbolId.hpp> @@ -22,6 +23,9 @@ using DataVariant = std::variant<std::monostate, TinyVector<1>, TinyVector<2>, TinyVector<3>, + TinyMatrix<1>, + TinyMatrix<2>, + TinyMatrix<3>, EmbeddedData, std::vector<bool>, std::vector<uint64_t>, @@ -31,6 +35,9 @@ using DataVariant = std::variant<std::monostate, std::vector<TinyVector<1>>, std::vector<TinyVector<2>>, std::vector<TinyVector<3>>, + std::vector<TinyMatrix<1>>, + std::vector<TinyMatrix<2>>, + std::vector<TinyMatrix<3>>, std::vector<EmbeddedData>, AggregateDataVariant, FunctionSymbolId>; @@ -71,14 +78,16 @@ class AggregateDataVariant // LCOV_EXCL_LINE } PUGS_INLINE - DataVariant& operator[](size_t i) + DataVariant& + operator[](size_t i) { Assert(i < m_data_vector.size()); return m_data_vector[i]; } PUGS_INLINE - const DataVariant& operator[](size_t i) const + const DataVariant& + operator[](size_t i) const { Assert(i < m_data_vector.size()); return m_data_vector[i]; diff --git a/src/language/utils/FunctionTable.hpp b/src/language/utils/FunctionTable.hpp index 43232fd60e4482a73dca1a4b257c7a6c0b6ef28a..6c6450c22ff8ff845ddd0c90dbcead1eec52203c 100644 --- a/src/language/utils/FunctionTable.hpp +++ b/src/language/utils/FunctionTable.hpp @@ -2,7 +2,7 @@ #define FUNCTION_TABLE_HPP #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/DataVariant.hpp> #include <utils/PugsAssert.hpp> @@ -63,14 +63,16 @@ class FunctionTable } PUGS_INLINE - FunctionDescriptor& operator[](size_t function_id) + FunctionDescriptor& + operator[](size_t function_id) { Assert(function_id < m_function_descriptor_list.size()); return m_function_descriptor_list[function_id]; } PUGS_INLINE - const FunctionDescriptor& operator[](size_t function_id) const + const FunctionDescriptor& + operator[](size_t function_id) const { Assert(function_id < m_function_descriptor_list.size()); return m_function_descriptor_list[function_id]; diff --git a/src/language/utils/IAffectationProcessorBuilder.hpp b/src/language/utils/IAffectationProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6d5f8540df9435bba4b86e6bee22f980b5727d7b --- /dev/null +++ b/src/language/utils/IAffectationProcessorBuilder.hpp @@ -0,0 +1,17 @@ +#ifndef I_AFFECTATION_PROCESSOR_BUILDER_HPP +#define I_AFFECTATION_PROCESSOR_BUILDER_HPP + +class ASTNode; +class INodeProcessor; + +#include <memory> + +class IAffectationProcessorBuilder +{ + public: + virtual std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const = 0; + + virtual ~IAffectationProcessorBuilder() = default; +}; + +#endif // I_AFFECTATION_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/IBinaryOperatorProcessorBuilder.hpp b/src/language/utils/IBinaryOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46a3a23010ed6e8bd0d6373f80f53d8bfd45e8c7 --- /dev/null +++ b/src/language/utils/IBinaryOperatorProcessorBuilder.hpp @@ -0,0 +1,23 @@ +#ifndef I_BINARY_OPERATOR_PROCESSOR_BUILDER_HPP +#define I_BINARY_OPERATOR_PROCESSOR_BUILDER_HPP + +class ASTNode; +class INodeProcessor; + +#include <language/utils/ASTNodeDataType.hpp> + +#include <memory> + +class IBinaryOperatorProcessorBuilder +{ + public: + [[nodiscard]] virtual std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const = 0; + + [[nodiscard]] virtual ASTNodeDataType getReturnValueType() const = 0; + [[nodiscard]] virtual ASTNodeDataType getDataTypeOfA() const = 0; + [[nodiscard]] virtual ASTNodeDataType getDataTypeOfB() const = 0; + + virtual ~IBinaryOperatorProcessorBuilder() = default; +}; + +#endif // I_BINARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/IIncDecOperatorProcessorBuilder.hpp b/src/language/utils/IIncDecOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b84974a332326afa293757576efd6e2bdf40795e --- /dev/null +++ b/src/language/utils/IIncDecOperatorProcessorBuilder.hpp @@ -0,0 +1,20 @@ +#ifndef I_INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP +#define I_INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP + +class ASTNode; +class INodeProcessor; +#include <language/utils/ASTNodeDataType.hpp> + +#include <memory> + +class IIncDecOperatorProcessorBuilder +{ + public: + [[nodiscard]] virtual std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const = 0; + + [[nodiscard]] virtual ASTNodeDataType getReturnValueType() const = 0; + + virtual ~IIncDecOperatorProcessorBuilder() = default; +}; + +#endif // I_INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/IUnaryOperatorProcessorBuilder.hpp b/src/language/utils/IUnaryOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19b36e2feb42b9efc558ff5e554532c0f3d32824 --- /dev/null +++ b/src/language/utils/IUnaryOperatorProcessorBuilder.hpp @@ -0,0 +1,21 @@ +#ifndef I_UNARY_OPERATOR_PROCESSOR_BUILDER_HPP +#define I_UNARY_OPERATOR_PROCESSOR_BUILDER_HPP + +class ASTNode; +class INodeProcessor; + +#include <language/utils/ASTNodeDataType.hpp> + +#include <memory> + +class IUnaryOperatorProcessorBuilder +{ + public: + [[nodiscard]] virtual std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const = 0; + + [[nodiscard]] virtual ASTNodeDataType getReturnValueType() const = 0; + + virtual ~IUnaryOperatorProcessorBuilder() = default; +}; + +#endif // I_UNARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/IncDecOperatorMangler.hpp b/src/language/utils/IncDecOperatorMangler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed5a84ae44c327116dbddef71880bd7b601fd97f --- /dev/null +++ b/src/language/utils/IncDecOperatorMangler.hpp @@ -0,0 +1,34 @@ +#ifndef INC_DEC_OPERATOR_MANGLER_HPP +#define INC_DEC_OPERATOR_MANGLER_HPP + +#include <language/utils/ASTNodeDataType.hpp> +#include <utils/Exceptions.hpp> + +#include <string> + +namespace language +{ +struct unary_minusminus; +struct unary_plusplus; +struct post_minusminus; +struct post_plusplus; +} // namespace language + +template <typename IncDecOperatorT> +std::string +incDecOperatorMangler(const ASTNodeDataType& operand) +{ + if constexpr (std::is_same_v<language::unary_minusminus, IncDecOperatorT>) { + return std::string{"-- "} + dataTypeName(operand); + } else if constexpr (std::is_same_v<language::post_minusminus, IncDecOperatorT>) { + return dataTypeName(operand) + " --"; + } else if constexpr (std::is_same_v<language::unary_plusplus, IncDecOperatorT>) { + return std::string{"++ "} + dataTypeName(operand); + } else if constexpr (std::is_same_v<language::post_plusplus, IncDecOperatorT>) { + return dataTypeName(operand) + " ++"; + } else { + static_assert(std::is_same_v<language::unary_minusminus, IncDecOperatorT>, "undefined inc/dec operator"); + } +} + +#endif // INC_DEC_OPERATOR_MANGLER_HPP diff --git a/src/language/utils/IncDecOperatorProcessorBuilder.hpp b/src/language/utils/IncDecOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40da865ddf9c6ed480a859e9a158d21cb5562848 --- /dev/null +++ b/src/language/utils/IncDecOperatorProcessorBuilder.hpp @@ -0,0 +1,31 @@ +#ifndef INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP +#define INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP + +#include <algebra/TinyVector.hpp> +#include <language/PEGGrammar.hpp> +#include <language/node_processor/IncDecExpressionProcessor.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/IIncDecOperatorProcessorBuilder.hpp> + +#include <type_traits> + +template <typename OperatorT, typename DataT> +class IncDecOperatorProcessorBuilder final : public IIncDecOperatorProcessorBuilder +{ + public: + IncDecOperatorProcessorBuilder() = default; + + ASTNodeDataType + getReturnValueType() const + { + return ast_node_data_type_from<DataT>; + } + + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<IncDecExpressionProcessor<OperatorT, DataT>>(node); + } +}; + +#endif // INC_DEC_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/IncDecOperatorRegisterForN.cpp b/src/language/utils/IncDecOperatorRegisterForN.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e6a6369fcef13242abde08ffb59e75778b4793a --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForN.cpp @@ -0,0 +1,58 @@ +#include <language/utils/IncDecOperatorRegisterForN.hpp> + +#include <language/utils/IncDecOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +IncDecOperatorRegisterForN::_register_unary_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addIncDecOperator<language::unary_minusminus>(N, std::make_shared<IncDecOperatorProcessorBuilder< + language::unary_minusminus, uint64_t>>()); +} + +void +IncDecOperatorRegisterForN::_register_unary_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addIncDecOperator< + language::unary_plusplus>(N, + std::make_shared<IncDecOperatorProcessorBuilder<language::unary_plusplus, uint64_t>>()); +} + +void +IncDecOperatorRegisterForN::_register_post_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addIncDecOperator< + language::post_minusminus>(N, + std::make_shared<IncDecOperatorProcessorBuilder<language::post_minusminus, uint64_t>>()); +} + +void +IncDecOperatorRegisterForN::_register_post_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addIncDecOperator< + language::post_plusplus>(N, std::make_shared<IncDecOperatorProcessorBuilder<language::post_plusplus, uint64_t>>()); +} + +IncDecOperatorRegisterForN::IncDecOperatorRegisterForN() +{ + this->_register_unary_minusminus(); + this->_register_unary_plusplus(); + this->_register_post_minusminus(); + this->_register_post_plusplus(); +} diff --git a/src/language/utils/IncDecOperatorRegisterForN.hpp b/src/language/utils/IncDecOperatorRegisterForN.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3e2682d8a0fb0122f813ec633bb08a37e997922 --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForN.hpp @@ -0,0 +1,16 @@ +#ifndef INC_DEC_OPERATOR_REGISTER_FOR_N_HPP +#define INC_DEC_OPERATOR_REGISTER_FOR_N_HPP + +class IncDecOperatorRegisterForN +{ + private: + void _register_unary_minusminus(); + void _register_unary_plusplus(); + void _register_post_minusminus(); + void _register_post_plusplus(); + + public: + IncDecOperatorRegisterForN(); +}; + +#endif // INC_DEC_OPERATOR_REGISTER_FOR_N_HPP diff --git a/src/language/utils/IncDecOperatorRegisterForR.cpp b/src/language/utils/IncDecOperatorRegisterForR.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5031751468f6593d5eab1839cdf1f778c56837e7 --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForR.cpp @@ -0,0 +1,58 @@ +#include <language/utils/IncDecOperatorRegisterForR.hpp> + +#include <language/utils/IncDecOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +IncDecOperatorRegisterForR::_register_unary_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addIncDecOperator<language::unary_minusminus>(R, std::make_shared<IncDecOperatorProcessorBuilder< + language::unary_minusminus, double_t>>()); +} + +void +IncDecOperatorRegisterForR::_register_unary_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addIncDecOperator< + language::unary_plusplus>(R, + std::make_shared<IncDecOperatorProcessorBuilder<language::unary_plusplus, double_t>>()); +} + +void +IncDecOperatorRegisterForR::_register_post_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addIncDecOperator< + language::post_minusminus>(R, + std::make_shared<IncDecOperatorProcessorBuilder<language::post_minusminus, double_t>>()); +} + +void +IncDecOperatorRegisterForR::_register_post_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addIncDecOperator< + language::post_plusplus>(R, std::make_shared<IncDecOperatorProcessorBuilder<language::post_plusplus, double_t>>()); +} + +IncDecOperatorRegisterForR::IncDecOperatorRegisterForR() +{ + this->_register_unary_minusminus(); + this->_register_unary_plusplus(); + this->_register_post_minusminus(); + this->_register_post_plusplus(); +} diff --git a/src/language/utils/IncDecOperatorRegisterForR.hpp b/src/language/utils/IncDecOperatorRegisterForR.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ede6b9999347788f9447d7b5eeb9fe28b8e4da7 --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForR.hpp @@ -0,0 +1,16 @@ +#ifndef INC_DEC_OPERATOR_REGISTER_FOR_R_HPP +#define INC_DEC_OPERATOR_REGISTER_FOR_R_HPP + +class IncDecOperatorRegisterForR +{ + private: + void _register_unary_minusminus(); + void _register_unary_plusplus(); + void _register_post_minusminus(); + void _register_post_plusplus(); + + public: + IncDecOperatorRegisterForR(); +}; + +#endif // INC_DEC_OPERATOR_REGISTER_FOR_Z_HPP diff --git a/src/language/utils/IncDecOperatorRegisterForZ.cpp b/src/language/utils/IncDecOperatorRegisterForZ.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a0bc0faa8bcf3ef0bcbed4837c44f5c9f147aa3 --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForZ.cpp @@ -0,0 +1,57 @@ +#include <language/utils/IncDecOperatorRegisterForZ.hpp> + +#include <language/utils/IncDecOperatorProcessorBuilder.hpp> +#include <language/utils/OperatorRepository.hpp> + +void +IncDecOperatorRegisterForZ::_register_unary_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addIncDecOperator<language::unary_minusminus>(Z, std::make_shared<IncDecOperatorProcessorBuilder< + language::unary_minusminus, int64_t>>()); +} + +void +IncDecOperatorRegisterForZ::_register_unary_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addIncDecOperator< + language::unary_plusplus>(Z, std::make_shared<IncDecOperatorProcessorBuilder<language::unary_plusplus, int64_t>>()); +} + +void +IncDecOperatorRegisterForZ::_register_post_minusminus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addIncDecOperator< + language::post_minusminus>(N, + std::make_shared<IncDecOperatorProcessorBuilder<language::post_minusminus, int64_t>>()); +} + +void +IncDecOperatorRegisterForZ::_register_post_plusplus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addIncDecOperator< + language::post_plusplus>(Z, std::make_shared<IncDecOperatorProcessorBuilder<language::post_plusplus, int64_t>>()); +} + +IncDecOperatorRegisterForZ::IncDecOperatorRegisterForZ() +{ + this->_register_unary_minusminus(); + this->_register_unary_plusplus(); + this->_register_post_minusminus(); + this->_register_post_plusplus(); +} diff --git a/src/language/utils/IncDecOperatorRegisterForZ.hpp b/src/language/utils/IncDecOperatorRegisterForZ.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b4ab27770b5a144245a297751a48d0b00f8322ba --- /dev/null +++ b/src/language/utils/IncDecOperatorRegisterForZ.hpp @@ -0,0 +1,16 @@ +#ifndef INC_DEC_OPERATOR_REGISTER_FOR_Z_HPP +#define INC_DEC_OPERATOR_REGISTER_FOR_Z_HPP + +class IncDecOperatorRegisterForZ +{ + private: + void _register_unary_minusminus(); + void _register_unary_plusplus(); + void _register_post_minusminus(); + void _register_post_plusplus(); + + public: + IncDecOperatorRegisterForZ(); +}; + +#endif // INC_DEC_OPERATOR_REGISTER_FOR_Z_HPP diff --git a/src/language/utils/OperatorRepository.cpp b/src/language/utils/OperatorRepository.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d34344b967f41c72ae09849a554958b8ebfcc1f3 --- /dev/null +++ b/src/language/utils/OperatorRepository.cpp @@ -0,0 +1,102 @@ +#include <language/utils/OperatorRepository.hpp> + +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/AffectationRegisterForB.hpp> +#include <language/utils/AffectationRegisterForN.hpp> +#include <language/utils/AffectationRegisterForR.hpp> +#include <language/utils/AffectationRegisterForRn.hpp> +#include <language/utils/AffectationRegisterForRnxn.hpp> +#include <language/utils/AffectationRegisterForString.hpp> +#include <language/utils/AffectationRegisterForZ.hpp> + +#include <language/utils/BinaryOperatorRegisterForB.hpp> +#include <language/utils/BinaryOperatorRegisterForN.hpp> +#include <language/utils/BinaryOperatorRegisterForR.hpp> +#include <language/utils/BinaryOperatorRegisterForRn.hpp> +#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> +#include <language/utils/BinaryOperatorRegisterForString.hpp> +#include <language/utils/BinaryOperatorRegisterForZ.hpp> + +#include <language/utils/IncDecOperatorRegisterForN.hpp> +#include <language/utils/IncDecOperatorRegisterForR.hpp> +#include <language/utils/IncDecOperatorRegisterForZ.hpp> + +#include <language/utils/UnaryOperatorRegisterForB.hpp> +#include <language/utils/UnaryOperatorRegisterForN.hpp> +#include <language/utils/UnaryOperatorRegisterForR.hpp> +#include <language/utils/UnaryOperatorRegisterForRn.hpp> +#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> +#include <language/utils/UnaryOperatorRegisterForZ.hpp> + +#include <utils/PugsAssert.hpp> + +OperatorRepository* OperatorRepository::m_instance = nullptr; + +void +OperatorRepository::reset() +{ + m_affectation_builder_list.clear(); + m_binary_operator_builder_list.clear(); + m_inc_dec_operator_builder_list.clear(); + m_unary_operator_builder_list.clear(); + this->_initialize(); +} + +void +OperatorRepository::create() +{ + Assert(m_instance == nullptr, "AffectationRepository was already created"); + m_instance = new OperatorRepository; + m_instance->_initialize(); +} + +void +OperatorRepository::destroy() +{ + Assert(m_instance != nullptr, "AffectationRepository was not created"); + delete m_instance; + m_instance = nullptr; +} + +void +OperatorRepository::_initialize() +{ + AffectationRegisterForB{}; + AffectationRegisterForN{}; + AffectationRegisterForZ{}; + AffectationRegisterForR{}; + AffectationRegisterForRn<1>{}; + AffectationRegisterForRn<2>{}; + AffectationRegisterForRn<3>{}; + AffectationRegisterForRnxn<1>{}; + AffectationRegisterForRnxn<2>{}; + AffectationRegisterForRnxn<3>{}; + AffectationRegisterForString{}; + + BinaryOperatorRegisterForB{}; + BinaryOperatorRegisterForN{}; + BinaryOperatorRegisterForZ{}; + BinaryOperatorRegisterForR{}; + BinaryOperatorRegisterForRn<1>{}; + BinaryOperatorRegisterForRn<2>{}; + BinaryOperatorRegisterForRn<3>{}; + BinaryOperatorRegisterForRnxn<1>{}; + BinaryOperatorRegisterForRnxn<2>{}; + BinaryOperatorRegisterForRnxn<3>{}; + BinaryOperatorRegisterForString{}; + + IncDecOperatorRegisterForN{}; + IncDecOperatorRegisterForR{}; + IncDecOperatorRegisterForZ{}; + + UnaryOperatorRegisterForB{}; + UnaryOperatorRegisterForN{}; + UnaryOperatorRegisterForZ{}; + UnaryOperatorRegisterForR{}; + UnaryOperatorRegisterForRn<1>{}; + UnaryOperatorRegisterForRn<2>{}; + UnaryOperatorRegisterForRn<3>{}; + UnaryOperatorRegisterForRnxn<1>{}; + UnaryOperatorRegisterForRnxn<2>{}; + UnaryOperatorRegisterForRnxn<3>{}; +} diff --git a/src/language/utils/OperatorRepository.hpp b/src/language/utils/OperatorRepository.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6ac61e6e245f8f545e4794ffb4bd4703d4806b82 --- /dev/null +++ b/src/language/utils/OperatorRepository.hpp @@ -0,0 +1,224 @@ +#ifndef OPERATOR_REPOSITORY_HPP +#define OPERATOR_REPOSITORY_HPP + +#include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTNodeDataType.hpp> +#include <language/utils/AffectationMangler.hpp> +#include <language/utils/BinaryOperatorMangler.hpp> +#include <language/utils/IAffectationProcessorBuilder.hpp> +#include <language/utils/IBinaryOperatorProcessorBuilder.hpp> +#include <language/utils/IIncDecOperatorProcessorBuilder.hpp> +#include <language/utils/IUnaryOperatorProcessorBuilder.hpp> +#include <language/utils/IncDecOperatorMangler.hpp> +#include <language/utils/UnaryOperatorMangler.hpp> + +#include <utils/Exceptions.hpp> + +#include <optional> +#include <unordered_map> + +class OperatorRepository +{ + private: + template <typename ProcessorBuilderT> + class Descriptor + { + private: + ASTNodeDataType m_value_type; + std::shared_ptr<const ProcessorBuilderT> m_processor_builder; + + public: + const ASTNodeDataType& + valueType() const + { + return m_value_type; + } + + const std::shared_ptr<const ProcessorBuilderT>& + processorBuilder() const + { + return m_processor_builder; + } + + Descriptor(const ASTNodeDataType& value_type, const std::shared_ptr<const ProcessorBuilderT>& processor_builder) + : m_value_type{value_type}, m_processor_builder{processor_builder} + {} + + Descriptor(const Descriptor&) = default; + Descriptor(Descriptor&&) = default; + Descriptor() = default; + ~Descriptor() = default; + }; + + std::unordered_map<std::string, Descriptor<IAffectationProcessorBuilder>> m_affectation_builder_list; + + std::unordered_map<std::string, Descriptor<IIncDecOperatorProcessorBuilder>> m_inc_dec_operator_builder_list; + + std::unordered_map<std::string, Descriptor<IBinaryOperatorProcessorBuilder>> m_binary_operator_builder_list; + + std::unordered_map<std::string, Descriptor<IUnaryOperatorProcessorBuilder>> m_unary_operator_builder_list; + + void _initialize(); + + public: + void reset(); + + template <typename BinaryOperatorTypeT> + void + addBinaryOperator(const std::shared_ptr<const IBinaryOperatorProcessorBuilder>& processor_builder) + { + const std::string binary_operator_type_name = + binaryOperatorMangler<BinaryOperatorTypeT>(processor_builder->getDataTypeOfA(), + processor_builder->getDataTypeOfB()); + if (not m_binary_operator_builder_list + .try_emplace(binary_operator_type_name, + Descriptor{processor_builder->getReturnValueType(), processor_builder}) + .second) { + // LCOV_EXCL_START + throw UnexpectedError(binary_operator_type_name + " has already an entry"); + // LCOV_EXCL_STOP + } + } + + template <typename OperatorTypeT> + void + addAffectation(const ASTNodeDataType& lhs_type, + const ASTNodeDataType& rhs_type, + const std::shared_ptr<const IAffectationProcessorBuilder>& processor_builder) + { + const std::string affectation_type_name = affectationMangler<OperatorTypeT>(lhs_type, rhs_type); + if (not m_affectation_builder_list + .try_emplace(affectation_type_name, + Descriptor{ASTNodeDataType::build<ASTNodeDataType::void_t>(), processor_builder}) + .second) { + // LCOV_EXCL_START + throw UnexpectedError(affectation_type_name + " has already an entry"); + // LCOV_EXCL_STOP + } + } + + template <typename OperatorTypeT> + void + addIncDecOperator(const ASTNodeDataType& operand_type, + const std::shared_ptr<const IIncDecOperatorProcessorBuilder>& processor_builder) + { + const std::string inc_dec_operator_type_name = incDecOperatorMangler<OperatorTypeT>(operand_type); + if (auto [i_descriptor, success] = + m_inc_dec_operator_builder_list.try_emplace(inc_dec_operator_type_name, + Descriptor{processor_builder->getReturnValueType(), + processor_builder}); + not success) { + // LCOV_EXCL_START + throw UnexpectedError(inc_dec_operator_type_name + " has already an entry"); + // LCOV_EXCL_STOP + } + } + + template <typename OperatorTypeT> + void + addUnaryOperator(const ASTNodeDataType& operand_type, + const std::shared_ptr<const IUnaryOperatorProcessorBuilder>& processor_builder) + { + const std::string unary_operator_type_name = unaryOperatorMangler<OperatorTypeT>(operand_type); + if (auto [i_descriptor, success] = + m_unary_operator_builder_list.try_emplace(unary_operator_type_name, + Descriptor{processor_builder->getReturnValueType(), + processor_builder}); + not success) { + // LCOV_EXCL_START + throw UnexpectedError(unary_operator_type_name + " has already an entry"); + // LCOV_EXCL_STOP + } + } + + [[nodiscard]] std::optional<std::shared_ptr<const IAffectationProcessorBuilder>> + getAffectationProcessorBuilder(const std::string& name) const + { + auto&& processor_builder = m_affectation_builder_list.find(name); + if (processor_builder != m_affectation_builder_list.end()) { + return processor_builder->second.processorBuilder(); + } + return {}; + } + + [[nodiscard]] std::optional<std::shared_ptr<const IIncDecOperatorProcessorBuilder>> + getIncDecProcessorBuilder(const std::string& name) const + { + auto&& processor_builder = m_inc_dec_operator_builder_list.find(name); + if (processor_builder != m_inc_dec_operator_builder_list.end()) { + return processor_builder->second.processorBuilder(); + } + return {}; + } + + [[nodiscard]] std::optional<std::shared_ptr<const IBinaryOperatorProcessorBuilder>> + getBinaryProcessorBuilder(const std::string& name) const + { + auto&& processor_builder = m_binary_operator_builder_list.find(name); + if (processor_builder != m_binary_operator_builder_list.end()) { + return processor_builder->second.processorBuilder(); + } + return {}; + } + + [[nodiscard]] std::optional<std::shared_ptr<const IUnaryOperatorProcessorBuilder>> + getUnaryProcessorBuilder(const std::string& name) const + { + auto&& processor_builder = m_unary_operator_builder_list.find(name); + if (processor_builder != m_unary_operator_builder_list.end()) { + return processor_builder->second.processorBuilder(); + } + return {}; + } + + [[nodiscard]] std::optional<ASTNodeDataType> + getIncDecOperatorValueType(const std::string& name) const + { + auto&& processor_builder = m_inc_dec_operator_builder_list.find(name); + if (processor_builder != m_inc_dec_operator_builder_list.end()) { + return processor_builder->second.valueType(); + } + return {}; + } + + [[nodiscard]] std::optional<ASTNodeDataType> + getBinaryOperatorValueType(const std::string& name) const + { + auto&& processor_builder = m_binary_operator_builder_list.find(name); + if (processor_builder != m_binary_operator_builder_list.end()) { + return processor_builder->second.valueType(); + } + return {}; + } + + [[nodiscard]] std::optional<ASTNodeDataType> + getUnaryOperatorValueType(const std::string& name) const + { + auto&& processor_builder = m_unary_operator_builder_list.find(name); + if (processor_builder != m_unary_operator_builder_list.end()) { + return processor_builder->second.valueType(); + } + return {}; + } + + static void create(); + + PUGS_INLINE + static OperatorRepository& + instance() + { + Assert(m_instance != nullptr); + return *m_instance; + } + + static void destroy(); + + private: + static OperatorRepository* m_instance; + + OperatorRepository() = default; + + ~OperatorRepository() = default; +}; + +#endif // OPERATOR_REPOSITORY_HPP diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 660cae99e206b3320a9d28861fbc2ee0236d6022..9b57a908e6341291d14e5e8f9c80aa8a87275dcc 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -2,8 +2,8 @@ #define PUGS_FUNCTION_ADAPTER_HPP #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeDataType.hpp> #include <language/node_processor/ExecutionPolicy.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/SymbolTable.hpp> #include <utils/Array.hpp> @@ -36,12 +36,14 @@ class PugsFunctionAdapter<OutputType(InputType...)> template <size_t I> [[nodiscard]] PUGS_INLINE static bool - _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept + _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT) { using Arg = std::tuple_element_t<I, InputTuple>; constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>; - const ASTNodeDataType& arg_data_type = arg_expression.m_data_type; + + Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t); + const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType(); return isNaturalConversion(expected_input_data_type, arg_data_type); } @@ -55,53 +57,28 @@ class PugsFunctionAdapter<OutputType(InputType...)> } [[nodiscard]] PUGS_INLINE static bool - _checkValidInputDataType(const ASTNode& input_expression) noexcept + _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept { if constexpr (NArgs == 1) { - return _checkValidArgumentDataType<0>(input_expression); + return _checkValidArgumentDataType<0>(input_domain_expression); } else { - if (input_expression.children.size() != NArgs) { + if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or + (input_domain_expression.children.size() != NArgs)) { return false; } using IndexSequence = std::make_index_sequence<NArgs>; - return _checkAllInputDataType(input_expression, IndexSequence{}); + return _checkAllInputDataType(input_domain_expression, IndexSequence{}); } } [[nodiscard]] PUGS_INLINE static bool - _checkValidOutputDataType(const ASTNode& return_expression) noexcept + _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT) { constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>; - const ASTNodeDataType& return_data_type = return_expression.m_data_type; + const ASTNodeDataType& return_data_type = output_domain_expression.m_data_type.contentType(); - if (not isNaturalConversion(return_data_type, expected_return_data_type)) { - if (expected_return_data_type == ASTNodeDataType::vector_t) { - if (return_data_type == ASTNodeDataType::list_t) { - if (expected_return_data_type.dimension() != return_expression.children.size()) { - return false; - } else { - for (const auto& child : return_expression.children) { - const ASTNodeDataType& data_type = child->m_data_type; - if (not isNaturalConversion(data_type, ast_node_data_type_from<double>)) { - return false; - } - } - } - } else if ((expected_return_data_type.dimension() == 1) and - isNaturalConversion(return_data_type, ast_node_data_type_from<double>)) { - return true; - } else if (return_data_type == ast_node_data_type_from<int64_t>) { - // 0 is the only valid value for vectors - return (return_expression.string() == "0"); - } else { - return false; - } - } else { - return false; - } - } - return true; + return isNaturalConversion(return_data_type, expected_return_data_type); } template <typename Arg, typename... RemainingArgs> @@ -124,16 +101,16 @@ class PugsFunctionAdapter<OutputType(InputType...)> PUGS_INLINE static void _checkFunction(const FunctionDescriptor& function) { - bool has_valid_input = _checkValidInputDataType(*function.definitionNode().children[0]); - bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]); + bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]); + bool has_valid_output = _checkValidOutputDomain(*function.domainMappingNode().children[1]); - if (not(has_valid_input and has_valid_output)) { + if (not(has_valid_input_domain and has_valid_output)) { std::ostringstream error_message; error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>) << rang::style::reset << '\n' << "note: provided function " << rang::fgB::magenta << function.name() << ": " - << function.domainMappingNode().string() << rang::style::reset << std::ends; + << function.domainMappingNode().string() << rang::style::reset; throw NormalError(error_message.str()); } } @@ -252,6 +229,85 @@ class PugsFunctionAdapter<OutputType(InputType...)> } // LCOV_EXCL_STOP } + } else if constexpr (is_tiny_matrix_v<OutputType>) { + switch (data_type) { + case ASTNodeDataType::list_t: { + return [](DataVariant&& result) -> OutputType { + AggregateDataVariant& v = std::get<AggregateDataVariant>(result); + OutputType x; + + for (size_t i = 0, l = 0; i < x.dimension(); ++i) { + for (size_t j = 0; j < x.dimension(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_arithmetic_v<Aij_T>) { + x(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting arithmetic value"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + return x; + }; + } + case ASTNodeDataType::matrix_t: { + return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; + } + case ASTNodeDataType::bool_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return + [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::unsigned_int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType(static_cast<double>(std::get<uint64_t>(result))); + }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType{static_cast<double>(std::get<int64_t>(result))}; + }; + } else { + // If this point is reached must be a 0 matrix + return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; }; + } + } + case ASTNodeDataType::double_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + } + // LCOV_EXCL_STOP + } } else if constexpr (std::is_arithmetic_v<OutputType>) { switch (data_type) { case ASTNodeDataType::bool_t: { diff --git a/src/language/utils/SymbolTable.hpp b/src/language/utils/SymbolTable.hpp index 3eba77ec9c495a50183acda709abf030cbfc522b..6aa089bd8de7aa5a3a6a6cf2a8115e77de25dbe8 100644 --- a/src/language/utils/SymbolTable.hpp +++ b/src/language/utils/SymbolTable.hpp @@ -1,7 +1,7 @@ #ifndef SYMBOL_TABLE_HPP #define SYMBOL_TABLE_HPP -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> #include <language/utils/DataVariant.hpp> #include <language/utils/EmbedderTable.hpp> #include <language/utils/FunctionTable.hpp> @@ -25,7 +25,7 @@ class SymbolTable bool m_is_initialized{false}; - ASTNodeDataType m_data_type{ASTNodeDataType::undefined_t}; + ASTNodeDataType m_data_type; DataVariant m_value; public: @@ -267,7 +267,12 @@ class SymbolTable clearValues() { for (auto& symbol : m_symbol_list) { - symbol.attributes().value() = DataVariant{}; + std::visit( + [](auto&& value) { + using T = std::decay_t<decltype(value)>; + value = T{}; + }, + symbol.attributes().value()); } } diff --git a/src/language/utils/UnaryOperatorMangler.hpp b/src/language/utils/UnaryOperatorMangler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..380eb6fa0c5e260340eb2ab463f3a7567c6a301a --- /dev/null +++ b/src/language/utils/UnaryOperatorMangler.hpp @@ -0,0 +1,32 @@ +#ifndef UNARY_OPERATOR_MANGLER_HPP +#define UNARY_OPERATOR_MANGLER_HPP + +#include <language/utils/ASTNodeDataType.hpp> +#include <utils/Exceptions.hpp> + +#include <string> + +namespace language +{ +struct unary_minus; +struct unary_not; +} // namespace language + +template <typename UnaryOperatorT> +std::string +unaryOperatorMangler(const ASTNodeDataType& operand) +{ + const std::string operator_name = [] { + if constexpr (std::is_same_v<language::unary_minus, UnaryOperatorT>) { + return "-"; + } else if constexpr (std::is_same_v<language::unary_not, UnaryOperatorT>) { + return "not"; + } else { + static_assert(std::is_same_v<language::unary_minus, UnaryOperatorT>, "undefined unary operator"); + } + }(); + + return operator_name + " " + dataTypeName(operand); +} + +#endif // UNARY_OPERATOR_MANGLER_HPP diff --git a/src/language/utils/UnaryOperatorProcessorBuilder.hpp b/src/language/utils/UnaryOperatorProcessorBuilder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d99d2e17536de2b31a3ac2c48f39800a2461f17d --- /dev/null +++ b/src/language/utils/UnaryOperatorProcessorBuilder.hpp @@ -0,0 +1,31 @@ +#ifndef UNARY_OPERATOR_PROCESSOR_BUILDER_HPP +#define UNARY_OPERATOR_PROCESSOR_BUILDER_HPP + +#include <algebra/TinyVector.hpp> +#include <language/PEGGrammar.hpp> +#include <language/node_processor/UnaryExpressionProcessor.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/IUnaryOperatorProcessorBuilder.hpp> + +#include <type_traits> + +template <typename OperatorT, typename ValueT, typename DataT> +class UnaryOperatorProcessorBuilder final : public IUnaryOperatorProcessorBuilder +{ + public: + UnaryOperatorProcessorBuilder() = default; + + ASTNodeDataType + getReturnValueType() const + { + return ast_node_data_type_from<ValueT>; + } + + std::unique_ptr<INodeProcessor> + getNodeProcessor(ASTNode& node) const + { + return std::make_unique<UnaryExpressionProcessor<OperatorT, ValueT, DataT>>(node); + } +}; + +#endif // UNARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForB.cpp b/src/language/utils/UnaryOperatorRegisterForB.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33173e82b896e8d013f125a3b47ce42682ac132d --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForB.cpp @@ -0,0 +1,32 @@ +#include <language/utils/UnaryOperatorRegisterForB.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +void +UnaryOperatorRegisterForB::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto B = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + + repository.addUnaryOperator< + language::unary_minus>(B, std::make_shared<UnaryOperatorProcessorBuilder<language::unary_minus, int64_t, bool>>()); +} + +void +UnaryOperatorRegisterForB::_register_unary_not() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto B = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + + repository.addUnaryOperator< + language::unary_not>(B, std::make_shared<UnaryOperatorProcessorBuilder<language::unary_not, bool, bool>>()); +} + +UnaryOperatorRegisterForB::UnaryOperatorRegisterForB() +{ + this->_register_unary_minus(); + this->_register_unary_not(); +} diff --git a/src/language/utils/UnaryOperatorRegisterForB.hpp b/src/language/utils/UnaryOperatorRegisterForB.hpp new file mode 100644 index 0000000000000000000000000000000000000000..49f1f8f4e8cf02dab86d1e0b8c7008d03a137546 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForB.hpp @@ -0,0 +1,14 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_B_HPP +#define UNARY_OPERATOR_REGISTER_FOR_B_HPP + +class UnaryOperatorRegisterForB +{ + private: + void _register_unary_minus(); + void _register_unary_not(); + + public: + UnaryOperatorRegisterForB(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_B_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForN.cpp b/src/language/utils/UnaryOperatorRegisterForN.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baed0e43eef169229cfbff58fb0592f894b83163 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForN.cpp @@ -0,0 +1,21 @@ +#include <language/utils/UnaryOperatorRegisterForN.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +void +UnaryOperatorRegisterForN::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto N = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + + repository.addUnaryOperator< + language::unary_minus>(N, + std::make_shared<UnaryOperatorProcessorBuilder<language::unary_minus, int64_t, uint64_t>>()); +} + +UnaryOperatorRegisterForN::UnaryOperatorRegisterForN() +{ + this->_register_unary_minus(); +} diff --git a/src/language/utils/UnaryOperatorRegisterForN.hpp b/src/language/utils/UnaryOperatorRegisterForN.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d6a4ff62ed93e0db85b53f503de096f3dcc527ba --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForN.hpp @@ -0,0 +1,13 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_N_HPP +#define UNARY_OPERATOR_REGISTER_FOR_N_HPP + +class UnaryOperatorRegisterForN +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForN(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_N_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForR.cpp b/src/language/utils/UnaryOperatorRegisterForR.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c0c60a31f5f6c222232999c3b392a18f30a767c1 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForR.cpp @@ -0,0 +1,20 @@ +#include <language/utils/UnaryOperatorRegisterForR.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +void +UnaryOperatorRegisterForR::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto R = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + + repository.addUnaryOperator< + language::unary_minus>(R, std::make_shared<UnaryOperatorProcessorBuilder<language::unary_minus, double, double>>()); +} + +UnaryOperatorRegisterForR::UnaryOperatorRegisterForR() +{ + this->_register_unary_minus(); +} diff --git a/src/language/utils/UnaryOperatorRegisterForR.hpp b/src/language/utils/UnaryOperatorRegisterForR.hpp new file mode 100644 index 0000000000000000000000000000000000000000..308fe639039da194c2419f3902a1f98868482716 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForR.hpp @@ -0,0 +1,13 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_R_HPP +#define UNARY_OPERATOR_REGISTER_FOR_R_HPP + +class UnaryOperatorRegisterForR +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForR(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_R_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForRn.cpp b/src/language/utils/UnaryOperatorRegisterForRn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85cc03fa476a34a6bec65f2c7f88f1d76fbad809 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRn.cpp @@ -0,0 +1,28 @@ +#include <language/utils/UnaryOperatorRegisterForRn.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +template <size_t Dimension> +void +UnaryOperatorRegisterForRn<Dimension>::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rn = ASTNodeDataType::build<ASTNodeDataType::vector_t>(Dimension); + + repository + .addUnaryOperator<language::unary_minus>(Rn, + std::make_shared<UnaryOperatorProcessorBuilder< + language::unary_minus, TinyVector<Dimension>, TinyVector<Dimension>>>()); +} + +template <size_t Dimension> +UnaryOperatorRegisterForRn<Dimension>::UnaryOperatorRegisterForRn() +{ + this->_register_unary_minus(); +} + +template class UnaryOperatorRegisterForRn<1>; +template class UnaryOperatorRegisterForRn<2>; +template class UnaryOperatorRegisterForRn<3>; diff --git a/src/language/utils/UnaryOperatorRegisterForRn.hpp b/src/language/utils/UnaryOperatorRegisterForRn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0000f2f588bdc84ccc77e3c03515a92126aac14 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRn.hpp @@ -0,0 +1,16 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_RN_HPP +#define UNARY_OPERATOR_REGISTER_FOR_RN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class UnaryOperatorRegisterForRn +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForRn(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_RN_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForRnxn.cpp b/src/language/utils/UnaryOperatorRegisterForRnxn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..798a60086b30d008901f0d9b5fa2502e745d6fe9 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRnxn.cpp @@ -0,0 +1,28 @@ +#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +template <size_t Dimension> +void +UnaryOperatorRegisterForRnxn<Dimension>::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Rnxn = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(Dimension, Dimension); + + repository + .addUnaryOperator<language::unary_minus>(Rnxn, + std::make_shared<UnaryOperatorProcessorBuilder< + language::unary_minus, TinyMatrix<Dimension>, TinyMatrix<Dimension>>>()); +} + +template <size_t Dimension> +UnaryOperatorRegisterForRnxn<Dimension>::UnaryOperatorRegisterForRnxn() +{ + this->_register_unary_minus(); +} + +template class UnaryOperatorRegisterForRnxn<1>; +template class UnaryOperatorRegisterForRnxn<2>; +template class UnaryOperatorRegisterForRnxn<3>; diff --git a/src/language/utils/UnaryOperatorRegisterForRnxn.hpp b/src/language/utils/UnaryOperatorRegisterForRnxn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..42084f2961504130189f40967b50757401b7b939 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForRnxn.hpp @@ -0,0 +1,16 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP +#define UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP + +#include <cstdlib> + +template <size_t Dimension> +class UnaryOperatorRegisterForRnxn +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForRnxn(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_RNXN_HPP diff --git a/src/language/utils/UnaryOperatorRegisterForZ.cpp b/src/language/utils/UnaryOperatorRegisterForZ.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89e0b4972a0a1e2a2556d2210ab4658b35f7b976 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForZ.cpp @@ -0,0 +1,21 @@ +#include <language/utils/UnaryOperatorRegisterForZ.hpp> + +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> + +void +UnaryOperatorRegisterForZ::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + auto Z = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + + repository.addUnaryOperator< + language::unary_minus>(Z, + std::make_shared<UnaryOperatorProcessorBuilder<language::unary_minus, int64_t, int64_t>>()); +} + +UnaryOperatorRegisterForZ::UnaryOperatorRegisterForZ() +{ + this->_register_unary_minus(); +} diff --git a/src/language/utils/UnaryOperatorRegisterForZ.hpp b/src/language/utils/UnaryOperatorRegisterForZ.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0e3065b139a23e1898bc4954b3b76a1df3f7e44 --- /dev/null +++ b/src/language/utils/UnaryOperatorRegisterForZ.hpp @@ -0,0 +1,13 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_Z_HPP +#define UNARY_OPERATOR_REGISTER_FOR_Z_HPP + +class UnaryOperatorRegisterForZ +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForZ(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_Z_HPP diff --git a/src/mesh/Connectivity.hpp b/src/mesh/Connectivity.hpp index 3bf5cd4201c574749e5b6a67a2b6c4d667bd88f1..f8e4dafdb7a46473434645c26e46ee1c00f86117 100644 --- a/src/mesh/Connectivity.hpp +++ b/src/mesh/Connectivity.hpp @@ -12,7 +12,7 @@ #include <mesh/RefId.hpp> #include <mesh/RefItemList.hpp> #include <mesh/SubItemValuePerItem.hpp> -#include <utils/CSRGraph.hpp> +#include <utils/CRSGraph.hpp> #include <utils/Exceptions.hpp> #include <utils/PugsAssert.hpp> #include <utils/PugsMacros.hpp> @@ -542,36 +542,20 @@ class Connectivity final : public IConnectivity } PUGS_INLINE - CSRGraph + CRSGraph cellToCellGraph() const { std::vector<std::set<int>> cell_cells(this->numberOfCells()); - if constexpr (true) { - const auto& face_to_cell_matrix = this->faceToCellMatrix(); + const auto& face_to_cell_matrix = this->faceToCellMatrix(); - for (FaceId l = 0; l < this->numberOfFaces(); ++l) { - const auto& face_to_cell = face_to_cell_matrix[l]; - if (face_to_cell.size() > 1) { - const CellId cell_0 = face_to_cell[0]; - const CellId cell_1 = face_to_cell[1]; + for (FaceId l = 0; l < this->numberOfFaces(); ++l) { + const auto& face_to_cell = face_to_cell_matrix[l]; + if (face_to_cell.size() > 1) { + const CellId cell_0 = face_to_cell[0]; + const CellId cell_1 = face_to_cell[1]; - cell_cells[cell_0].insert(cell_1); - cell_cells[cell_1].insert(cell_0); - } - } - } else { - const auto& node_to_cell_matrix = this->nodeToCellMatrix(); - - for (NodeId l = 0; l < this->numberOfNodes(); ++l) { - const auto& node_to_cell = node_to_cell_matrix[l]; - for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) { - const CellId cell_0 = node_to_cell[i_cell]; - for (size_t j_cell = 0; j_cell < i_cell; ++j_cell) { - const CellId cell_1 = node_to_cell[j_cell]; - cell_cells[cell_0].insert(cell_1); - cell_cells[cell_1].insert(cell_0); - } - } + cell_cells[cell_0].insert(cell_1); + cell_cells[cell_1].insert(cell_0); } } @@ -590,7 +574,7 @@ class Connectivity final : public IConnectivity } } } - return CSRGraph(entries, neighbors); + return CRSGraph(entries, neighbors); } PUGS_INLINE diff --git a/src/mesh/ConnectivityDispatcher.cpp b/src/mesh/ConnectivityDispatcher.cpp index 82bf2a2e006a13d0806b2d88b1dc54715a0cc116..6b73d967d4d4fcb37e1376bf1d9306826d25fb1f 100644 --- a/src/mesh/ConnectivityDispatcher.cpp +++ b/src/mesh/ConnectivityDispatcher.cpp @@ -1,6 +1,7 @@ #include <mesh/ConnectivityDispatcher.hpp> -#include <mesh/ItemOfItemType.hpp> +#include <mesh/ItemOfItemType.hpp> +#include <utils/CRSGraph.hpp> #include <utils/Partitioner.hpp> #include <iostream> @@ -12,7 +13,7 @@ void ConnectivityDispatcher<Dimension>::_buildNewOwner() { if constexpr (item_type == ItemType::cell) { - CSRGraph connectivity_graph = m_connectivity.cellToCellGraph(); + CRSGraph connectivity_graph = m_connectivity.cellToCellGraph(); Partitioner P; CellValue<int> cell_new_owner(m_connectivity); diff --git a/src/mesh/Mesh.hpp b/src/mesh/Mesh.hpp index d04807650cd08c0b96a7259cbe5756596d54aad1..e78383978469292d6529eb931b432406612f7220 100644 --- a/src/mesh/Mesh.hpp +++ b/src/mesh/Mesh.hpp @@ -4,7 +4,6 @@ #include <algebra/TinyVector.hpp> #include <mesh/IMesh.hpp> #include <mesh/ItemValue.hpp> -#include <utils/CSRGraph.hpp> #include <memory> diff --git a/src/utils/Array.hpp b/src/utils/Array.hpp index 42ae7f61c375a9cb544020063f827c43dcf5e143..560f60ce7015c06fa68c80121b43f86179a05e83 100644 --- a/src/utils/Array.hpp +++ b/src/utils/Array.hpp @@ -10,7 +10,7 @@ #include <algorithm> template <typename DataType> -class Array +class [[nodiscard]] Array { public: using data_type = DataType; @@ -23,15 +23,12 @@ class Array friend Array<std::add_const_t<DataType>>; public: - PUGS_INLINE - size_t - size() const noexcept + PUGS_INLINE size_t size() const noexcept { return m_values.extent(0); } - friend PUGS_INLINE Array<std::remove_const_t<DataType>> - copy(const Array<DataType>& source) + friend PUGS_INLINE Array<std::remove_const_t<DataType>> copy(const Array<DataType>& source) { Array<std::remove_const_t<DataType>> image(source.size()); Kokkos::deep_copy(image.m_values, source.m_values); @@ -42,16 +39,14 @@ class Array template <typename DataType2, typename... RT> friend PUGS_INLINE Array<DataType2> encapsulate(const Kokkos::View<DataType2*, RT...>& values); - PUGS_INLINE - DataType& operator[](index_type i) const noexcept(NO_ASSERT) + PUGS_INLINE DataType& operator[](index_type i) const noexcept(NO_ASSERT) { Assert(i < m_values.extent(0)); return m_values[i]; } PUGS_INLINE - void - fill(const DataType& data) const + void fill(const DataType& data) const { static_assert(not std::is_const<DataType>(), "Cannot modify Array of const"); @@ -61,8 +56,7 @@ class Array } template <typename DataType2> - PUGS_INLINE Array& - operator=(const Array<DataType2>& array) noexcept + PUGS_INLINE Array& operator=(const Array<DataType2>& array) noexcept { // ensures that DataType is the same as source DataType2 static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), @@ -81,7 +75,7 @@ class Array Array& operator=(Array&&) = default; PUGS_INLINE - Array(size_t size) : m_values("anonymous", size) + explicit Array(size_t size) : m_values("anonymous", size) { static_assert(not std::is_const<DataType>(), "Cannot allocate Array of const data: only view is " "supported"); @@ -94,14 +88,13 @@ class Array Array(const Array&) = default; template <typename DataType2> - PUGS_INLINE - Array(const Array<DataType2>& array) noexcept + PUGS_INLINE Array(const Array<DataType2>& array) noexcept { this->operator=(array); } PUGS_INLINE - Array(Array&&) = default; + Array(Array &&) = default; PUGS_INLINE ~Array() = default; diff --git a/src/utils/BuildInfo.cpp b/src/utils/BuildInfo.cpp index 13b8633c2664add44d91b11b0a6d55afe79d2955..38d46007ef6cfcf470051e8546714db7041a7852 100644 --- a/src/utils/BuildInfo.cpp +++ b/src/utils/BuildInfo.cpp @@ -8,6 +8,10 @@ #include <mpi.h> #endif // PUGS_HAS_MPI +#ifdef PUGS_HAS_PETSC +#include <petsc.h> +#endif // PUGS_HAS_PETSC + std::string BuildInfo::type() { @@ -18,7 +22,7 @@ std::string BuildInfo::compiler() { std::stringstream compiler_info; - compiler_info << PUGS_BUILD_COMPILER << " (" << PUGS_BUILD_COMPILER_VERSION << ")" << std::ends; + compiler_info << PUGS_BUILD_COMPILER << " (" << PUGS_BUILD_COMPILER_VERSION << ")"; return compiler_info.str(); } @@ -42,3 +46,14 @@ BuildInfo::mpiLibrary() return "none"; #endif // PUGS_HAS_MPI } + +std::string +BuildInfo::petscLibrary() +{ +#ifdef PUGS_HAS_PETSC + return std::to_string(PETSC_VERSION_MAJOR) + "." + std::to_string(PETSC_VERSION_MINOR) + "." + + std::to_string(PETSC_VERSION_SUBMINOR); +#else + return "none"; +#endif // PUGS_HAS_PETSC +} diff --git a/src/utils/BuildInfo.hpp b/src/utils/BuildInfo.hpp index 81d2b4960085fbeffff667fb92e8a4069848fb8d..67134a782f595ccf68dc89cd29a6f273c7ce6140 100644 --- a/src/utils/BuildInfo.hpp +++ b/src/utils/BuildInfo.hpp @@ -9,6 +9,7 @@ struct BuildInfo static std::string compiler(); static std::string kokkosDevices(); static std::string mpiLibrary(); + static std::string petscLibrary(); }; #endif // BUILD_INFO_HPP diff --git a/src/utils/CRSGraph.hpp b/src/utils/CRSGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..59b74ab57b16db757e95d1b74e72b4e296bfc026 --- /dev/null +++ b/src/utils/CRSGraph.hpp @@ -0,0 +1,47 @@ +#ifndef CRS_GRAPH_HPP +#define CRS_GRAPH_HPP + +#include <utils/Array.hpp> + +class CRSGraph +{ + private: + Array<const int> m_entries; + Array<const int> m_neighbors; + + public: + size_t + numberOfNodes() const + { + Assert(m_entries.size() > 0); + return m_entries.size() - 1; + } + + const Array<const int>& + entries() const + { + return m_entries; + } + + const Array<const int>& + neighbors() const + { + return m_neighbors; + } + + CRSGraph& operator=(CRSGraph&&) = delete; + CRSGraph& operator=(const CRSGraph&) = delete; + + CRSGraph(const Array<int>& entries, const Array<int>& neighbors) : m_entries(entries), m_neighbors(neighbors) + { + Assert(m_entries.size() > 0); + Assert(static_cast<size_t>(m_entries[m_entries.size() - 1]) == m_neighbors.size()); + } + + CRSGraph() = delete; + CRSGraph(CRSGraph&&) = delete; + CRSGraph(const CRSGraph&) = delete; + ~CRSGraph() = default; +}; + +#endif // CRS_GRAPH_HPP diff --git a/src/utils/CSRGraph.hpp b/src/utils/CSRGraph.hpp deleted file mode 100644 index a3ae85dcced1897fc54901903c919f0009761d74..0000000000000000000000000000000000000000 --- a/src/utils/CSRGraph.hpp +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CSR_GRAPH_HPP -#define CSR_GRAPH_HPP - -#include <utils/Array.hpp> - -class CSRGraph -{ - private: - Array<int> m_entries; - Array<int> m_neighbors; - - public: - size_t - numberOfNodes() const - { - Assert(m_entries.size() > 0); - return m_entries.size() - 1; - } - - const Array<int>& - entries() const - { - return m_entries; - } - - const Array<int>& - neighbors() const - { - return m_neighbors; - } - - CSRGraph& operator=(CSRGraph&&) = default; - CSRGraph& operator=(const CSRGraph&) = default; - - CSRGraph(const Array<int>& entries, const Array<int>& neighbors) : m_entries(entries), m_neighbors(neighbors) - { - Assert(m_entries.size() > 0); - } - - CSRGraph() = default; - CSRGraph(CSRGraph&&) = default; - CSRGraph(const CSRGraph&) = default; - ~CSRGraph() = default; -}; - -#endif // CSR_GRAPH_HPP diff --git a/src/utils/CastArray.hpp b/src/utils/CastArray.hpp index 120bcc25c1f5dd9da75a387016bad8d2b5175688..d329b4ef35f70f3a612ca430c45d0ec23bb7ed9d 100644 --- a/src/utils/CastArray.hpp +++ b/src/utils/CastArray.hpp @@ -8,7 +8,7 @@ #include <iostream> template <typename DataType, typename CastDataType> -class CastArray +class [[nodiscard]] CastArray { public: using data_type = CastDataType; @@ -20,8 +20,7 @@ class CastArray public: PUGS_INLINE - const size_t& - size() const + const size_t& size() const { return m_size; } @@ -39,17 +38,10 @@ class CastArray PUGS_INLINE CastArray& operator=(CastArray&&) = default; - PUGS_INLINE - CastArray() : m_size(0), m_values(nullptr) - { - ; - } - - PUGS_INLINE - CastArray(const Array<DataType>& array) + explicit CastArray(const Array<DataType>& array) : m_array(array), m_size(sizeof(DataType) * array.size() / sizeof(CastDataType)), - m_values((array.size() == 0) ? nullptr : reinterpret_cast<CastDataType*>(&(array[0]))) + m_values((array.size() == 0) ? nullptr : reinterpret_cast<CastDataType*>(&(static_cast<DataType&>(array[0])))) { static_assert((std::is_const_v<CastDataType> and std::is_const_v<DataType>) or (not std::is_const_v<DataType>), "CastArray cannot remove const attribute"); @@ -59,9 +51,8 @@ class CastArray } } - PUGS_INLINE - CastArray(DataType& value) - : m_size(sizeof(DataType) / sizeof(CastDataType)), m_values(reinterpret_cast<CastDataType*>(&(value))) + explicit CastArray(DataType & value) + : m_size(sizeof(DataType) / sizeof(CastDataType)), m_values(reinterpret_cast<CastDataType*>(&value)) { static_assert((std::is_const_v<CastDataType> and std::is_const_v<DataType>) or (not std::is_const_v<DataType>), "CastArray cannot remove const attribute"); @@ -70,13 +61,13 @@ class CastArray } PUGS_INLINE - CastArray(DataType&& value) = delete; + CastArray(DataType && value) = delete; PUGS_INLINE CastArray(const CastArray&) = default; PUGS_INLINE - CastArray(CastArray&&) = default; + CastArray(CastArray &&) = default; PUGS_INLINE ~CastArray() = default; diff --git a/src/utils/ConsoleManager.cpp b/src/utils/ConsoleManager.cpp index 666caa1b17c5bb397e592132febeabab4846732f..f55004ac7dd2181256404b486ad1eaa0c05abd85 100644 --- a/src/utils/ConsoleManager.cpp +++ b/src/utils/ConsoleManager.cpp @@ -11,12 +11,9 @@ ConsoleManager::isTerminal(std::ostream& os) void ConsoleManager::init(bool colorize) { - std::cout << "Console management: color "; if (colorize) { rang::setControlMode(rang::control::Force); - std::cout << rang::style::bold << rang::fgB::green << "enabled" << rang::fg::reset << rang::style::reset << '\n'; } else { rang::setControlMode(rang::control::Off); - std::cout << "disabled\n"; } } diff --git a/src/utils/EscapedString.hpp b/src/utils/EscapedString.hpp index ed525abd0368614fb756fd8c7795bb1c7ea178c6..cf0f0c676ebd2998673c62870b07e52bf04e0269 100644 --- a/src/utils/EscapedString.hpp +++ b/src/utils/EscapedString.hpp @@ -1,6 +1,7 @@ #ifndef ESCAPED_STRING_HPP #define ESCAPED_STRING_HPP +#include <utils/Exceptions.hpp> #include <utils/PugsMacros.hpp> #include <sstream> @@ -11,7 +12,7 @@ PUGS_INLINE std::string unescapeString(std::string_view input_string) { std::stringstream ss; - for (size_t i = 1; i < input_string.size() - 1; ++i) { + for (size_t i = 0; i < input_string.size(); ++i) { char c = input_string[i]; if (c == '\\') { ++i; @@ -81,11 +82,15 @@ escapeString(std::string_view input_string) ss << R"(\\)"; break; } + case '\'': { + ss << R"(\')"; + break; + } case '\"': { ss << R"(\")"; break; } - case '?': { + case '\?': { ss << R"(\?)"; break; } diff --git a/src/utils/FPEManager.cpp b/src/utils/FPEManager.cpp index 7430b5bbcc8c17d31b3a9508a170bfb3ed525632..89a75d7c77e8649b6710e32f6d89ed5e9f85c1c0 100644 --- a/src/utils/FPEManager.cpp +++ b/src/utils/FPEManager.cpp @@ -62,16 +62,12 @@ fedisableexcept(unsigned int excepts) void FPEManager::enable() { - std::cout << "FE management: " << rang::style::bold << rang::fgB::green << "enabled" << rang::fg::reset - << rang::style::reset << '\n'; ::feenableexcept(MANAGED_FPE); } void FPEManager::disable() { - std::cout << "FE management: " << rang::style::bold << rang::fgB::red << "disabled" << rang::fg::reset - << rang::style::reset << '\n'; ::fedisableexcept(MANAGED_FPE); } @@ -79,15 +75,11 @@ FPEManager::disable() void FPEManager::enable() -{ - std::cout << "FE management: enabled " << rang::fg::red << "[not supported]" << rang::fg::reset << '\n'; -} +{} void FPEManager::disable() -{ - std::cout << "FE management: disable " << rang::fg::red << "[not supported]" << rang::fg::reset << '\n'; -} +{} #endif // PUGS_HAS_FENV_H diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 437ea62a0d734274900c8af1b474749813e1aeaa..cfe312a0884ae0718ee76c41855ce8b063fecffd 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -168,7 +168,7 @@ class Messenger typename RecvArrayType, typename... SendT, typename... RecvT> - RecvArrayType<RecvT...> + void _allToAll(const SendArrayType<SendT...>& sent_array, RecvArrayType<RecvT...>& recv_array) const { #ifdef PUGS_HAS_MPI @@ -189,7 +189,6 @@ class Messenger #else // PUGS_HAS_MPI value_copy(sent_array, recv_array); #endif // PUGS_HAS_MPI - return recv_array; } template <template <typename... SendT> typename SendArrayType, diff --git a/src/utils/Partitioner.cpp b/src/utils/Partitioner.cpp index 43ee64851b8d63e80afaf4d6960fe8015c1fa675..347b07da6bffdf62d0218ca5231ed9e03a9c698f 100644 --- a/src/utils/Partitioner.cpp +++ b/src/utils/Partitioner.cpp @@ -15,7 +15,7 @@ #include <utils/Exceptions.hpp> Array<int> -Partitioner::partition(const CSRGraph& graph) +Partitioner::partition(const CRSGraph& graph) { std::cout << "Partitioning graph into " << rang::style::bold << parallel::size() << rang::style::reset << " parts\n"; @@ -26,7 +26,7 @@ Partitioner::partition(const CSRGraph& graph) std::vector<float> tpwgts(npart, 1. / npart); std::vector<float> ubvec{1.05}; - std::vector<int> options{1, 1, 0}; + std::vector<int> options{1, 0, 0}; int edgecut = 0; Array<int> part(0); @@ -57,15 +57,20 @@ Partitioner::partition(const CSRGraph& graph) part = Array<int>(local_number_of_nodes); std::vector<int> vtxdist{0, local_number_of_nodes}; - const Array<int>& entries = graph.entries(); - const Array<int>& neighbors = graph.neighbors(); + const Array<const int>& entries = graph.entries(); + const Array<const int>& neighbors = graph.neighbors(); + + int* entries_ptr = const_cast<int*>(&(entries[0])); + int* neighbors_ptr = const_cast<int*>(&(neighbors[0])); int result = - ParMETIS_V3_PartKway(&(vtxdist[0]), &(entries[0]), &(neighbors[0]), NULL, NULL, &wgtflag, &numflag, &ncon, &npart, + ParMETIS_V3_PartKway(&(vtxdist[0]), entries_ptr, neighbors_ptr, NULL, NULL, &wgtflag, &numflag, &ncon, &npart, &(tpwgts[0]), &(ubvec[0]), &(options[0]), &edgecut, &(part[0]), &parmetis_comm); + // LCOV_EXCL_START if (result == METIS_ERROR) { throw UnexpectedError("Metis Error"); } + // LCOV_EXCL_STOP MPI_Comm_free(&parmetis_comm); } @@ -78,9 +83,11 @@ Partitioner::partition(const CSRGraph& graph) #else // PUGS_HAS_MPI Array<int> -Partitioner::partition(const CSRGraph&) +Partitioner::partition(const CRSGraph& graph) { - return Array<int>(0); + Array<int> partition{graph.entries().size() - 1}; + partition.fill(0); + return partition; } #endif // PUGS_HAS_MPI diff --git a/src/utils/Partitioner.hpp b/src/utils/Partitioner.hpp index 30f50723acff089cd1a21fa4f289030f5bbdf7a3..2c720bfd87e7853e12cd0736a46c5eda246f085f 100644 --- a/src/utils/Partitioner.hpp +++ b/src/utils/Partitioner.hpp @@ -1,7 +1,7 @@ #ifndef PARTITIONER_HPP #define PARTITIONER_HPP -#include <utils/CSRGraph.hpp> +#include <utils/CRSGraph.hpp> class Partitioner { @@ -10,7 +10,7 @@ class Partitioner Partitioner(const Partitioner&) = default; ~Partitioner() = default; - Array<int> partition(const CSRGraph& graph); + Array<int> partition(const CRSGraph& graph); }; #endif // PARTITIONER_HPP diff --git a/src/utils/PugsAssert.hpp b/src/utils/PugsAssert.hpp index def5604de9cd870e0e2abcd67f228b7ce0143f2e..7ac521a51f73df6c5129ddd6ba346bd06453c07f 100644 --- a/src/utils/PugsAssert.hpp +++ b/src/utils/PugsAssert.hpp @@ -102,16 +102,15 @@ struct AssertChecker<std::tuple<Args...> > #else // NDEBUG -#define Assert(...) \ - { \ - using TupleArgs = decltype(std::make_tuple(__VA_ARGS__)); \ - AssertChecker<TupleArgs>::check_args_type(); \ - constexpr int tuple_size = std::tuple_size_v<TupleArgs>; \ - static_assert(tuple_size >= 1 and tuple_size <= 2); \ - auto args = std::forward_as_tuple(__VA_ARGS__); \ - if (not static_cast<bool>(std::get<0>(args))) { \ - throw AssertError(__FILE__, __LINE__, __PRETTY_FUNCTION__, args, #__VA_ARGS__); \ - } \ +#define Assert(...) \ + { \ + using TupleArgs = decltype(std::make_tuple(__VA_ARGS__)); \ + AssertChecker<TupleArgs>::check_args_type(); \ + constexpr int tuple_size = std::tuple_size_v<TupleArgs>; \ + static_assert(tuple_size >= 1 and tuple_size <= 2); \ + if (not static_cast<bool>(std::get<0>(std::forward_as_tuple(__VA_ARGS__)))) { \ + throw AssertError(__FILE__, __LINE__, __PRETTY_FUNCTION__, std::forward_as_tuple(__VA_ARGS__), #__VA_ARGS__); \ + } \ } #endif // NDEBUG diff --git a/src/utils/PugsTraits.hpp b/src/utils/PugsTraits.hpp index e34e0433bb69316af695fec1c8d571ef23e03155..f51b455cb5ea4efa678dd980d1fec2721ab1fd2a 100644 --- a/src/utils/PugsTraits.hpp +++ b/src/utils/PugsTraits.hpp @@ -84,7 +84,14 @@ inline constexpr bool is_tiny_vector_v = false; template <size_t N, typename T> inline constexpr bool is_tiny_vector_v<TinyVector<N, T>> = true; -// Traits is_tiny_vector +// Traits is_tiny_matrix + +template <typename T> +inline constexpr bool is_tiny_matrix_v = false; + +template <size_t N, typename T> +inline constexpr bool is_tiny_matrix_v<TinyMatrix<N, T>> = true; + // helper to check if a type is part of a variant template <typename T, typename V> diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp index 7ea50b0ed95845d2cee2633477e33080453f0f30..ad9e12e7ec4cdd01305cf8701c327894c7942304 100644 --- a/src/utils/PugsUtils.cpp +++ b/src/utils/PugsUtils.cpp @@ -1,13 +1,13 @@ #include <utils/PugsUtils.hpp> +#include <algebra/PETScWrapper.hpp> #include <utils/BuildInfo.hpp> -#include <utils/RevisionInfo.hpp> - -#include <utils/Messenger.hpp> - #include <utils/ConsoleManager.hpp> #include <utils/FPEManager.hpp> +#include <utils/Messenger.hpp> +#include <utils/RevisionInfo.hpp> #include <utils/SignalManager.hpp> +#include <utils/pugs_build_info.hpp> #include <rang.hpp> @@ -18,37 +18,78 @@ #include <iostream> std::string -initialize(int& argc, char* argv[]) +pugsVersion() { - parallel::Messenger::create(argc, argv); + std::stringstream os; - std::cout << "Pugs version: " << rang::style::bold << RevisionInfo::version() << rang::style::reset << '\n'; + os << "pugs version: " << rang::style::bold << RevisionInfo::version() << rang::style::reset << '\n'; - std::cout << "-------------------- " << rang::fg::green << "git info" << rang::fg::reset - << " -------------------------" << '\n'; - std::cout << "tag: " << rang::style::bold << RevisionInfo::gitTag() << rang::style::reset << '\n'; - std::cout << "HEAD: " << rang::style::bold << RevisionInfo::gitHead() << rang::style::reset << '\n'; - std::cout << "hash: " << rang::style::bold << RevisionInfo::gitHash() << rang::style::reset << " ("; + os << "-------------------- " << rang::fg::green << "git info" << rang::fg::reset << " -------------------------" + << '\n'; + os << "tag: " << rang::style::bold << RevisionInfo::gitTag() << rang::style::reset << '\n'; + os << "HEAD: " << rang::style::bold << RevisionInfo::gitHead() << rang::style::reset << '\n'; + os << "hash: " << rang::style::bold << RevisionInfo::gitHash() << rang::style::reset << " ("; + // LCOV_EXCL_START Cannot cover both situations at same time if (RevisionInfo::gitIsClean()) { - std::cout << rang::fgB::green << "clean" << rang::fg::reset; + os << rang::fgB::green << "clean" << rang::fg::reset; } else { - std::cout << rang::fgB::red << "dirty" << rang::fg::reset; + os << rang::fgB::red << "dirty" << rang::fg::reset; } - std::cout << ")\n"; - std::cout << "-------------------- " << rang::fg::green << "build info" << rang::fg::reset - << " -----------------------" << '\n'; - std::cout << "type: " << rang::style::bold << BuildInfo::type() << rang::style::reset << '\n'; - std::cout << "compiler: " << rang::style::bold << BuildInfo::compiler() << rang::style::reset << '\n'; - std::cout << "kokkos: " << rang::style::bold << BuildInfo::kokkosDevices() << rang::style::reset << '\n'; - std::cout << "mpi: " << rang::style::bold << BuildInfo::mpiLibrary() << rang::style::reset << '\n'; - std::cout << "-------------------------------------------------------\n"; + // LCOV_EXCL_STOP + os << ")\n"; + os << "-------------------------------------------------------"; + + return os.str(); +} + +std::string +pugsBuildInfo() +{ + std::ostringstream os; + + os << "-------------------- " << rang::fg::green << "build info" << rang::fg::reset << " -----------------------" + << '\n'; + os << "type: " << rang::style::bold << BuildInfo::type() << rang::style::reset << '\n'; + os << "compiler: " << rang::style::bold << BuildInfo::compiler() << rang::style::reset << '\n'; + os << "kokkos: " << rang::style::bold << BuildInfo::kokkosDevices() << rang::style::reset << '\n'; + os << "MPI: " << rang::style::bold << BuildInfo::mpiLibrary() << rang::style::reset << '\n'; + os << "PETSc: " << rang::style::bold << BuildInfo::petscLibrary() << rang::style::reset << '\n'; + os << "-------------------------------------------------------"; + + return os.str(); +} + +void +setDefaultOMPEnvironment() +{ + if constexpr (std::string_view{PUGS_BUILD_KOKKOS_DEVICES} == std::string_view{"OpenMP"}) { + setenv("OMP_PROC_BIND", "spread", 0); + setenv("OMP_PLACES", "threads", 0); + } +} + +// LCOV_EXCL_START + +// This function cannot be unit-tested: run once when pugs starts + +std::string +initialize(int& argc, char* argv[]) +{ + parallel::Messenger::create(argc, argv); std::string filename; { - CLI::App app{"Pugs help"}; + CLI::App app{"pugs help"}; - app.add_option("filename", filename, "pugs script file")->required()->check(CLI::ExistingFile); + app.add_option("filename", filename, "pugs script file")->check(CLI::ExistingFile)->required(); + + app.set_version_flag("-v,--version", []() { + ConsoleManager::init(true); + std::stringstream os; + os << pugsVersion() << '\n' << pugsBuildInfo(); + return os.str(); + }); int threads = -1; app.add_option("--threads", threads, "Number of Kokkos threads") @@ -81,11 +122,19 @@ initialize(int& argc, char* argv[]) SignalManager::init(enable_signals); } + PETScWrapper::initialize(argc, argv); + + setDefaultOMPEnvironment(); Kokkos::initialize(argc, argv); - std::cout << "-------------------- " << rang::fg::green << "exec info" << rang::fg::reset - << " ------------------------" << '\n'; + std::cout << "----------------- " << rang::fg::green << "pugs exec info" << rang::fg::reset + << " ----------------------" << '\n'; std::cout << rang::style::bold; +#ifdef PUGS_HAS_MPI + std::cout << "MPI number of ranks " << parallel::size() << '\n'; +#else // PUGS_HAS_MPI + std::cout << "Sequential build\n"; +#endif // PUGS_HAS_MPI Kokkos::DefaultExecutionSpace::print_configuration(std::cout); std::cout << rang::style::reset; std::cout << "-------------------------------------------------------\n"; @@ -93,9 +142,18 @@ initialize(int& argc, char* argv[]) return filename; } +// LCOV_EXCL_STOP + +// LCOV_EXCL_START + +// This function cannot be unit-tested: run once when pugs stops + void finalize() { Kokkos::finalize(); + PETScWrapper::finalize(); parallel::Messenger::destroy(); } + +// LCOV_EXCL_STOP diff --git a/src/utils/PugsUtils.hpp b/src/utils/PugsUtils.hpp index ef4eeb72c340b8a6ca654d001c32390973b63c55..7429a75b3be20ec3d62028cc9180a2017f7b4f15 100644 --- a/src/utils/PugsUtils.hpp +++ b/src/utils/PugsUtils.hpp @@ -21,6 +21,12 @@ parallel_reduce(size_t size, const ArrayType& array, ReturnType& value, const st Kokkos::parallel_reduce(label, size, array, value); } +void setDefaultOMPEnvironment(); + +std::string pugsBuildInfo(); + +std::string pugsVersion(); + std::string initialize(int& argc, char* argv[]); void finalize(); diff --git a/src/utils/SignalManager.cpp b/src/utils/SignalManager.cpp index ab08424fa2731bc62d84b461cea850878da9514b..399f3164c2046f4ec25a009160a6f01b11da85ef 100644 --- a/src/utils/SignalManager.cpp +++ b/src/utils/SignalManager.cpp @@ -114,11 +114,5 @@ SignalManager::init(bool enable) std::signal(SIGINT, SignalManager::handler); std::signal(SIGABRT, SignalManager::handler); std::signal(SIGPIPE, SignalManager::handler); - - std::cout << "Signal management: " << rang::style::bold << rang::fgB::green << "enabled" << rang::fg::reset - << rang::style::reset << '\n'; - } else { - std::cout << "Signal management: " << rang::style::bold << rang::fgB::red << "disabled" << rang::fg::reset - << rang::style::reset << '\n'; } } diff --git a/src/utils/Timer.hpp b/src/utils/Timer.hpp index c26ea40aa12526b858d6e3a998102c9cd71d7ecd..786d7b789fb02993e2279f5e557c22423663a4e9 100644 --- a/src/utils/Timer.hpp +++ b/src/utils/Timer.hpp @@ -1,8 +1,113 @@ #ifndef TIMER_HPP #define TIMER_HPP -#include <Kokkos_Timer.hpp> +#include <chrono> +#include <iostream> -using Timer = Kokkos::Timer; +class Timer +{ + public: + enum class Status + { + running, + paused, + stopped + }; + + private: + std::chrono::time_point<std::chrono::high_resolution_clock> m_start; + std::chrono::duration<double> m_elapsed_sum; + + Status m_status; + + public: + Status + status() const + { + return m_status; + } + + double + seconds() const + { + switch (m_status) { + case Status::running: { + return (m_elapsed_sum + std::chrono::duration<double>{std::chrono::system_clock::now().time_since_epoch() - + m_start.time_since_epoch()}) + .count(); + } + case Status::paused: + case Status::stopped: { + return m_elapsed_sum.count(); + } + // LCOV_EXCL_START + default: { + return 0; + } + // LCOV_EXCL_STOP + } + } + + friend std::ostream& + operator<<(std::ostream& os, const Timer& timer) + { + os << timer.seconds() << 's'; + return os; + } + + void + reset() + { + m_start = std::chrono::high_resolution_clock::now(); + m_elapsed_sum = std::chrono::duration<double>::zero(); + } + + void + stop() + { + m_start = std::chrono::high_resolution_clock::now(); + m_elapsed_sum = std::chrono::duration<double>::zero(); + m_status = Status::stopped; + } + + void + pause() + { + if (m_status == Status::running) { + m_elapsed_sum += std::chrono::high_resolution_clock::now() - m_start; + m_start = std::chrono::high_resolution_clock::now(); + m_status = Status::paused; + } + } + + void + start() + { + switch (m_status) { + case Status::running: { + return; + } + case Status::paused: + case Status::stopped: { + m_start = std::chrono::high_resolution_clock::now(); + m_status = Status::running; + } + } + } + + Timer& operator=(const Timer&) = default; + Timer& operator=(Timer&&) = default; + + Timer(const Timer&) = default; + Timer(Timer&&) = default; + + Timer() + : m_start{std::chrono::high_resolution_clock::now()}, + m_elapsed_sum{std::chrono::duration<double>::zero()}, + m_status{Status::running} + {} + + ~Timer() = default; +}; #endif // TIMER_HPP diff --git a/src/utils/Types.hpp b/src/utils/Types.hpp index 482190cb6f4c34c614a1237c4c025bbb37812e20..2d53c59f17d00fbd44d92ebc5db06e529a5d47db 100644 --- a/src/utils/Types.hpp +++ b/src/utils/Types.hpp @@ -5,12 +5,12 @@ enum class ZeroType { zero }; -constexpr ZeroType zero = ZeroType::zero; +constexpr inline ZeroType zero = ZeroType::zero; enum class IdentityType { identity }; -constexpr IdentityType identity = IdentityType::identity; +constexpr inline IdentityType identity = IdentityType::identity; #endif // TYPES_HPP diff --git a/src/utils/pugs_config.hpp.in b/src/utils/pugs_config.hpp.in index 45878e64327797cd6549b18f4d6922f8cde8a29d..de267d26e10fb41994bbc72438b57d577e4b447a 100644 --- a/src/utils/pugs_config.hpp.in +++ b/src/utils/pugs_config.hpp.in @@ -3,11 +3,13 @@ #cmakedefine PUGS_HAS_FENV_H #cmakedefine PUGS_HAS_MPI +#cmakedefine PUGS_HAS_PETSC #cmakedefine SYSTEM_IS_LINUX #cmakedefine SYSTEM_IS_DARWIN #cmakedefine SYSTEM_IS_WINDOWS #define PUGS_BUILD_TYPE "@CMAKE_BUILD_TYPE@" +#define PUGS_BINARY_DIR "@PUGS_BINARY_DIR@" -#endif // PUGS_CONFIG_HPP +#endif // PUGS_CONFIG_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 428031d6839a14e74c9fde0b65ab72d98d0500c6..faa721f355752c4a85fe36026b1be85d1225714f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,6 +8,7 @@ add_executable (unit_tests test_main.cpp test_AffectationProcessor.cpp test_AffectationToStringProcessor.cpp + test_AffectationToTupleProcessor.cpp test_Array.cpp test_ArraySubscriptProcessor.cpp test_ArrayUtils.cpp @@ -44,20 +45,28 @@ add_executable (unit_tests test_BinaryExpressionProcessor_equality.cpp test_BinaryExpressionProcessor_logic.cpp test_BiCGStab.cpp + test_BuildInfo.cpp test_BuiltinFunctionEmbedder.cpp test_BuiltinFunctionEmbedderTable.cpp test_BuiltinFunctionProcessor.cpp - test_MathModule.cpp + test_CastArray.cpp + test_ConsoleManager.cpp + test_CG.cpp test_ContinueProcessor.cpp test_ConcatExpressionProcessor.cpp + test_CRSGraph.cpp test_CRSMatrix.cpp test_DataVariant.cpp + test_Demangle.cpp test_DiscontinuousGalerkin1D.cpp test_DoWhileProcessor.cpp test_EmbeddedData.cpp + test_EscapedString.cpp + test_Exceptions.cpp test_ExecutionPolicy.cpp test_FakeProcessor.cpp test_ForProcessor.cpp + test_FunctionArgumentConverter.cpp test_FunctionProcessor.cpp test_FunctionSymbolId.cpp test_FunctionTable.cpp @@ -65,17 +74,22 @@ add_executable (unit_tests test_IncDecExpressionProcessor.cpp test_INodeProcessor.cpp test_ItemType.cpp + test_LinearSolver.cpp + test_LinearSolverOptions.cpp test_ListAffectationProcessor.cpp + test_MathModule.cpp test_NameProcessor.cpp test_OStreamProcessor.cpp - test_PCG.cpp + test_ParseError.cpp test_Polynomial.cpp test_PolynomialBasis.cpp - test_PugsFunctionAdapter.cpp test_PugsAssert.cpp + test_PugsFunctionAdapter.cpp + test_PugsUtils.cpp test_RevisionInfo.cpp test_SparseMatrixDescriptor.cpp test_SymbolTable.cpp + test_Timer.cpp test_TinyMatrix.cpp test_TinyVector.cpp test_TupleToVectorProcessor.cpp @@ -86,30 +100,53 @@ add_executable (unit_tests add_executable (mpi_unit_tests mpi_test_main.cpp - mpi_test_Messenger.cpp + test_Messenger.cpp + test_Partitioner.cpp ) +add_library(test_Pugs_MeshDataBase + MeshDataBaseForTests.cpp) + target_link_libraries (unit_tests - PugsLanguage + test_Pugs_MeshDataBase PugsLanguageAST PugsLanguageModules PugsLanguageAlgorithms PugsLanguageUtils + PugsLanguage PugsMesh + PugsAlgebra PugsUtils kokkos ${PARMETIS_LIBRARIES} ${MPI_CXX_LINK_FLAGS} ${MPI_CXX_LIBRARIES} + ${PETSC_LIBRARIES} Catch2 + ${PUGS_STD_LINK_FLAGS} + stdc++fs ) target_link_libraries (mpi_unit_tests + test_Pugs_MeshDataBase + PugsAlgebra + PugsUtils + PugsLanguage + PugsLanguageAST + PugsLanguageModules + PugsLanguageAlgorithms + PugsMesh + PugsAlgebra PugsUtils + PugsLanguageUtils PugsUtils + PugsAlgebra PugsMesh kokkos ${PARMETIS_LIBRARIES} ${MPI_CXX_LINK_FLAGS} ${MPI_CXX_LIBRARIES} + ${PETSC_LIBRARIES} Catch2 + ${PUGS_STD_LINK_FLAGS} + stdc++fs ) enable_testing() diff --git a/tests/MeshDataBaseForTests.cpp b/tests/MeshDataBaseForTests.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8debe252a86fac67463ffc74ab17e1f9ccf141ea --- /dev/null +++ b/tests/MeshDataBaseForTests.cpp @@ -0,0 +1,55 @@ +#include <MeshDataBaseForTests.hpp> +#include <mesh/CartesianMeshBuilder.hpp> +#include <mesh/Connectivity.hpp> +#include <utils/PugsAssert.hpp> + +const MeshDataBaseForTests* MeshDataBaseForTests::m_instance = nullptr; + +MeshDataBaseForTests::MeshDataBaseForTests() +{ + m_cartesian_1d_mesh = CartesianMeshBuilder{TinyVector<1>{-1}, TinyVector<1>{3}, TinyVector<1, size_t>{23}}.mesh(); + + m_cartesian_2d_mesh = + CartesianMeshBuilder{TinyVector<2>{0, -1}, TinyVector<2>{3, 2}, TinyVector<2, size_t>{6, 7}}.mesh(); + + m_cartesian_3d_mesh = + CartesianMeshBuilder{TinyVector<3>{0, 1, 0}, TinyVector<3>{2, -1, 3}, TinyVector<3, size_t>{6, 7, 4}}.mesh(); +} + +const MeshDataBaseForTests& +MeshDataBaseForTests::get() +{ + return *m_instance; +} + +void +MeshDataBaseForTests::create() +{ + Assert(m_instance == nullptr); + m_instance = new MeshDataBaseForTests(); +} + +void +MeshDataBaseForTests::destroy() +{ + Assert(m_instance != nullptr); + delete m_instance; + m_instance = nullptr; +} + +template <size_t Dimension> +const Mesh<Connectivity<Dimension>>& +MeshDataBaseForTests::cartesianMesh() const +{ + if constexpr (Dimension == 1) { + return dynamic_cast<const Mesh<Connectivity<Dimension>>&>(*m_cartesian_1d_mesh); + } else if constexpr (Dimension == 2) { + return dynamic_cast<const Mesh<Connectivity<Dimension>>&>(*m_cartesian_2d_mesh); + } else if constexpr (Dimension == 3) { + return dynamic_cast<const Mesh<Connectivity<Dimension>>&>(*m_cartesian_3d_mesh); + } +} + +template const Mesh<Connectivity<1>>& MeshDataBaseForTests::cartesianMesh<1>() const; +template const Mesh<Connectivity<2>>& MeshDataBaseForTests::cartesianMesh<2>() const; +template const Mesh<Connectivity<3>>& MeshDataBaseForTests::cartesianMesh<3>() const; diff --git a/tests/MeshDataBaseForTests.hpp b/tests/MeshDataBaseForTests.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2a1b1c06c864f37981fd275067e0994a589afaac --- /dev/null +++ b/tests/MeshDataBaseForTests.hpp @@ -0,0 +1,36 @@ +#ifndef MESH_DATA_BASE_FOR_TESTS_HPP +#define MESH_DATA_BASE_FOR_TESTS_HPP + +#include <mesh/IMesh.hpp> + +template <size_t Dimension> +class Connectivity; + +template <typename ConnectivityT> +class Mesh; + +#include <memory> + +class MeshDataBaseForTests +{ + private: + explicit MeshDataBaseForTests(); + + static const MeshDataBaseForTests* m_instance; + + std::shared_ptr<const IMesh> m_cartesian_1d_mesh; + std::shared_ptr<const IMesh> m_cartesian_2d_mesh; + std::shared_ptr<const IMesh> m_cartesian_3d_mesh; + + public: + template <size_t Dimension> + const Mesh<Connectivity<Dimension>>& cartesianMesh() const; + + static const MeshDataBaseForTests& get(); + static void create(); + static void destroy(); + + ~MeshDataBaseForTests() = default; +}; + +#endif // MESH_DATA_BASE_FOR_TESTS_HPP diff --git a/tests/mpi_test_main.cpp b/tests/mpi_test_main.cpp index 565e46690910c05d6800ca1931bd811db790cc12..44162aba3edce0e992b7f37e6df6cca1040b3291 100644 --- a/tests/mpi_test_main.cpp +++ b/tests/mpi_test_main.cpp @@ -3,9 +3,19 @@ #include <Kokkos_Core.hpp> +#include <algebra/PETScWrapper.hpp> +#include <language/utils/OperatorRepository.hpp> +#include <mesh/DiamondDualConnectivityManager.hpp> +#include <mesh/DiamondDualMeshManager.hpp> +#include <mesh/MeshDataManager.hpp> +#include <mesh/SynchronizerManager.hpp> #include <utils/Messenger.hpp> +#include <utils/pugs_config.hpp> + +#include <MeshDataBaseForTests.hpp> #include <cstdlib> +#include <filesystem> int main(int argc, char* argv[]) @@ -13,21 +23,79 @@ main(int argc, char* argv[]) parallel::Messenger::create(argc, argv); Kokkos::initialize({4, -1, -1, true}); - // Disable outputs from tested classes to the standard output - std::cout.setstate(std::ios::badbit); + PETScWrapper::initialize(argc, argv); + + const std::string output_base_name{"mpi_test_rank_"}; + + std::filesystem::path parallel_output(std::string{PUGS_BINARY_DIR}); + + std::filesystem::path gcov_prefix = [&]() -> std::filesystem::path { + std::string template_temp_dir = std::filesystem::temp_directory_path() / "pugs_gcov_XXXXXX"; + return std::filesystem::path{mkdtemp(&template_temp_dir[0])}; + }(); - if (parallel::rank() != 0) { - setenv("GCOV_PREFIX", "/dev/null", 1); - } Catch::Session session; int result = session.applyCommandLine(argc, argv); if (result == 0) { - // Disable outputs from tested classes to the standard output - std::cout.setstate(std::ios::badbit); - result = session.run(); + const auto& config = session.config(); + if (config.listReporters() or config.listTags() or config.listTestNamesOnly() or config.listTests()) { + if (parallel::rank() == 0) { + session.run(); + } + } else { + if (parallel::rank() != 0) { + // Disable outputs for ranks != 0 + setenv("GCOV_PREFIX", gcov_prefix.string().c_str(), 1); + parallel_output /= output_base_name + std::to_string(parallel::rank()); + + Catch::ConfigData data{session.configData()}; + data.outputFilename = parallel_output.string(); + session.useConfigData(data); + } + + // Disable outputs from tested classes to the standard output + std::cout.setstate(std::ios::badbit); + + SynchronizerManager::create(); + MeshDataManager::create(); + DiamondDualConnectivityManager::create(); + DiamondDualMeshManager::create(); + + MeshDataBaseForTests::create(); + + if (parallel::rank() == 0) { + if (parallel::size() > 1) { + session.config().stream() << rang::fgB::green << "Other rank outputs are stored in corresponding files" + << rang::style::reset << '\n'; + + for (size_t i_rank = 1; i_rank < parallel::size(); ++i_rank) { + std::filesystem::path parallel_output(std::string{PUGS_BINARY_DIR}); + parallel_output /= output_base_name + std::to_string(i_rank); + session.config().stream() << " - " << rang::fg::green << parallel_output.parent_path().string() + << parallel_output.preferred_separator << rang::style::reset << rang::fgB::green + << parallel_output.filename().string() << rang::style::reset << '\n'; + } + } + } + + OperatorRepository::create(); + + result = session.run(); + + OperatorRepository::destroy(); + + MeshDataBaseForTests::destroy(); + + DiamondDualMeshManager::destroy(); + DiamondDualConnectivityManager::destroy(); + MeshDataManager::destroy(); + SynchronizerManager::destroy(); + } } + PETScWrapper::finalize(); + Kokkos::finalize(); parallel::Messenger::destroy(); diff --git a/tests/test_ASTBuilder.cpp b/tests/test_ASTBuilder.cpp index 2399a78617d6a6f30774fbc70be735d5850967d3..7353fa9a46d6d08542c62d4bee89299c930fc5a1 100644 --- a/tests/test_ASTBuilder.cpp +++ b/tests/test_ASTBuilder.cpp @@ -620,6 +620,41 @@ clog << "log " << l << "\n"; +-(language::literal:"log ") +-(language::name:l) `-(language::literal:"\n") +)"; + CHECK_AST(data, result); + } + + SECTION("tuple list simplification") + { + std::string_view data = R"( +let x:(R^2), x=((0,0),(2,3)); +let y:(R^2), y=((0)); +)"; + + std::string_view result = R"( +(root) + +-(language::var_declaration) + | +-(language::name:x) + | +-(language::tuple_type_specifier) + | | `-(language::vector_type) + | | +-(language::R_set) + | | `-(language::integer:2) + | +-(language::name:x) + | `-(language::expression_list) + | +-(language::tuple_expression) + | | +-(language::integer:0) + | | `-(language::integer:0) + | `-(language::tuple_expression) + | +-(language::integer:2) + | `-(language::integer:3) + `-(language::var_declaration) + +-(language::name:y) + +-(language::tuple_type_specifier) + | `-(language::vector_type) + | +-(language::R_set) + | `-(language::integer:2) + +-(language::name:y) + `-(language::integer:0) )"; CHECK_AST(data, result); } diff --git a/tests/test_ASTModulesImporter.cpp b/tests/test_ASTModulesImporter.cpp index 42638d918a362516b53a0f05772096b477f943b4..a3a4e3d7a6bd6bfc7d624b9d1ee4a9d69e34b876 100644 --- a/tests/test_ASTModulesImporter.cpp +++ b/tests/test_ASTModulesImporter.cpp @@ -4,11 +4,22 @@ #include <language/ast/ASTModulesImporter.hpp> #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/utils/ASTExecutionInfo.hpp> #include <language/utils/ASTPrinter.hpp> #include <language/utils/SymbolTable.hpp> #include <pegtl/string_input.hpp> +inline void +test_ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_repository) +{ + ASTExecutionInfo execution_info{root_node, module_repository}; + REQUIRE(&root_node == &execution_info.rootNode()); + REQUIRE(&module_repository == &execution_info.moduleRepository()); + + REQUIRE(&ASTExecutionInfo::current() == &execution_info); +} + #define CHECK_AST(data, expected_output) \ { \ static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ @@ -17,10 +28,13 @@ string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ \ - ASTModulesImporter{*ast}; \ + ASTModulesImporter importer{*ast}; \ ASTNodeTypeCleaner<language::import_instruction>{*ast}; \ \ ASTNodeExpressionBuilder{*ast}; \ + const auto& module_repository = importer.moduleRepository(); \ + test_ASTExecutionInfo(*ast, module_repository); \ + \ ExecutionPolicy exec_policy; \ ast->execute(exec_policy); \ \ diff --git a/tests/test_ASTNodeAffectationExpressionBuilder.cpp b/tests/test_ASTNodeAffectationExpressionBuilder.cpp index 895fa860eafe2f6216066d422597e539c1a1cd65..d2aeabb06ed1e69c51e04a9c83842c260b796852 100644 --- a/tests/test_ASTNodeAffectationExpressionBuilder.cpp +++ b/tests/test_ASTNodeAffectationExpressionBuilder.cpp @@ -10,85 +10,96 @@ #include <language/ast/ASTSymbolTableBuilder.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/ASTPrinter.hpp> +#include <language/utils/BasicAffectationRegistrerFor.hpp> +#include <language/utils/EmbeddedData.hpp> #include <language/utils/TypeDescriptor.hpp> #include <utils/Demangle.hpp> #include <utils/Exceptions.hpp> #include <pegtl/string_input.hpp> -#define CHECK_AST(data, expected_output) \ - { \ - static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ - static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or \ - std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>); \ - \ - string_input input{data, "test.pgs"}; \ - auto ast = ASTBuilder::build(input); \ - \ - ASTSymbolTableBuilder{*ast}; \ - ASTNodeDataTypeBuilder{*ast}; \ - \ - ASTNodeDeclarationToAffectationConverter{*ast}; \ - ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ - \ - ASTNodeExpressionBuilder{*ast}; \ - \ - std::stringstream ast_output; \ - ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \ - \ - REQUIRE(ast_output.str() == expected_output); \ +#define CHECK_AST(data, expected_output) \ + { \ + static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ + static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or \ + std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>); \ + \ + string_input input{data, "test.pgs"}; \ + \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ + \ + auto ast = ASTBuilder::build(input); \ + \ + ASTSymbolTableBuilder{*ast}; \ + ASTNodeDataTypeBuilder{*ast}; \ + \ + ASTNodeDeclarationToAffectationConverter{*ast}; \ + ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ + \ + ASTNodeExpressionBuilder{*ast}; \ + \ + std::stringstream ast_output; \ + ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \ + \ + REQUIRE(ast_output.str() == expected_output); \ + \ + OperatorRepository::instance().reset(); \ } template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = {ASTNodeDataType::type_id_t, - "builtin_t"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t"); const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const double>>; -#define CHECK_AST_WITH_BUILTIN(data, expected_output) \ - { \ - static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ - static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or \ - std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>); \ - \ - string_input input{data, "test.pgs"}; \ - auto ast = ASTBuilder::build(input); \ - \ - SymbolTable& symbol_table = *ast->m_symbol_table; \ - auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ - if (not success) { \ - throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); \ - } \ - \ - i_symbol->attributes().setDataType(ASTNodeDataType::type_name_id_t); \ - i_symbol->attributes().setIsInitialized(); \ - i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); \ - symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); \ - \ - auto [i_symbol_a, success_a] = symbol_table.add("a", ast->begin()); \ - if (not success_a) { \ - throw UnexpectedError("cannot add 'a' of type builtin_t for testing"); \ - } \ - i_symbol_a->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ - i_symbol_a->attributes().setIsInitialized(); \ - auto [i_symbol_b, success_b] = symbol_table.add("b", ast->begin()); \ - if (not success_b) { \ - throw UnexpectedError("cannot add 'b' of type builtin_t for testing"); \ - } \ - i_symbol_b->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ - i_symbol_b->attributes().setIsInitialized(); \ - \ - ASTSymbolTableBuilder{*ast}; \ - ASTNodeDataTypeBuilder{*ast}; \ - \ - ASTNodeDeclarationToAffectationConverter{*ast}; \ - ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ - \ - ASTNodeExpressionBuilder{*ast}; \ - \ - std::stringstream ast_output; \ - ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \ - \ - REQUIRE(ast_output.str() == expected_output); \ +#define CHECK_AST_WITH_BUILTIN(data, expected_output) \ + { \ + static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ + static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or \ + std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>); \ + \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ + \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + SymbolTable& symbol_table = *ast->m_symbol_table; \ + auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ + if (not success) { \ + throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); \ + } \ + \ + i_symbol->attributes().setDataType(ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>()); \ + i_symbol->attributes().setIsInitialized(); \ + i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); \ + symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); \ + \ + auto [i_symbol_a, success_a] = symbol_table.add("a", ast->begin()); \ + if (not success_a) { \ + throw UnexpectedError("cannot add 'a' of type builtin_t for testing"); \ + } \ + i_symbol_a->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ + i_symbol_a->attributes().setIsInitialized(); \ + auto [i_symbol_b, success_b] = symbol_table.add("b", ast->begin()); \ + if (not success_b) { \ + throw UnexpectedError("cannot add 'b' of type builtin_t for testing"); \ + } \ + i_symbol_b->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ + i_symbol_b->attributes().setIsInitialized(); \ + \ + ASTSymbolTableBuilder{*ast}; \ + ASTNodeDataTypeBuilder{*ast}; \ + \ + ASTNodeDeclarationToAffectationConverter{*ast}; \ + ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ + \ + ASTNodeExpressionBuilder{*ast}; \ + \ + std::stringstream ast_output; \ + ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \ + \ + REQUIRE(ast_output.str() == expected_output); \ + \ + OperatorRepository::instance().reset(); \ } #define CHECK_AST_THROWS_WITH(data, expected_error) \ @@ -109,46 +120,50 @@ const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const dou REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast}, expected_error); \ } -#define CHECK_AST_WITH_BUILTIN_THROWS_WITH(data, expected_error) \ - { \ - static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ - static_assert(std::is_same_v<std::decay_t<decltype(expected_error)>, std::string_view> or \ - std::is_same_v<std::decay_t<decltype(expected_error)>, std::string>); \ - \ - string_input input{data, "test.pgs"}; \ - auto ast = ASTBuilder::build(input); \ - \ - SymbolTable& symbol_table = *ast->m_symbol_table; \ - auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ - if (not success) { \ - throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); \ - } \ - \ - i_symbol->attributes().setDataType(ASTNodeDataType::type_name_id_t); \ - i_symbol->attributes().setIsInitialized(); \ - i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); \ - symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); \ - \ - auto [i_symbol_a, success_a] = symbol_table.add("a", ast->begin()); \ - if (not success_a) { \ - throw UnexpectedError("cannot add 'a' of type builtin_t for testing"); \ - } \ - i_symbol_a->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ - i_symbol_a->attributes().setIsInitialized(); \ - auto [i_symbol_b, success_b] = symbol_table.add("b", ast->begin()); \ - if (not success_b) { \ - throw UnexpectedError("cannot add 'b' of type builtin_t for testing"); \ - } \ - i_symbol_b->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ - i_symbol_b->attributes().setIsInitialized(); \ - \ - ASTSymbolTableBuilder{*ast}; \ - ASTNodeDataTypeBuilder{*ast}; \ - \ - ASTNodeDeclarationToAffectationConverter{*ast}; \ - ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ - \ - REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast}, expected_error); \ +#define CHECK_AST_WITH_BUILTIN_THROWS_WITH(data, expected_error) \ + { \ + static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ + static_assert(std::is_same_v<std::decay_t<decltype(expected_error)>, std::string_view> or \ + std::is_same_v<std::decay_t<decltype(expected_error)>, std::string>); \ + \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ + \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + SymbolTable& symbol_table = *ast->m_symbol_table; \ + auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ + if (not success) { \ + throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); \ + } \ + \ + i_symbol->attributes().setDataType(ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>()); \ + i_symbol->attributes().setIsInitialized(); \ + i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); \ + symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); \ + \ + auto [i_symbol_a, success_a] = symbol_table.add("a", ast->begin()); \ + if (not success_a) { \ + throw UnexpectedError("cannot add 'a' of type builtin_t for testing"); \ + } \ + i_symbol_a->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ + i_symbol_a->attributes().setIsInitialized(); \ + auto [i_symbol_b, success_b] = symbol_table.add("b", ast->begin()); \ + if (not success_b) { \ + throw UnexpectedError("cannot add 'b' of type builtin_t for testing"); \ + } \ + i_symbol_b->attributes().setDataType(ast_node_data_type_from<std::shared_ptr<const double>>); \ + i_symbol_b->attributes().setIsInitialized(); \ + \ + ASTSymbolTableBuilder{*ast}; \ + ASTNodeDataTypeBuilder{*ast}; \ + \ + ASTNodeDeclarationToAffectationConverter{*ast}; \ + ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ + \ + REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast}, expected_error); \ + \ + OperatorRepository::instance().reset(); \ } // clazy:excludeall=non-pod-global-static @@ -672,7 +687,7 @@ let t : (B), t = (true, false); std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, bool>) + `-(language::eq_op:AffectationToTupleFromListProcessor<bool>) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::true_kw:ValueProcessor) @@ -690,7 +705,7 @@ let t : (N), t = (1, 2, 3, 5); std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, unsigned long>) + `-(language::eq_op:AffectationToTupleFromListProcessor<unsigned long>) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::integer:1:ValueProcessor) @@ -714,7 +729,7 @@ let t : (Z), t = (2, n, true); +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>) | +-(language::name:n:NameProcessor) | `-(language::integer:3:ValueProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, long>) + `-(language::eq_op:AffectationToTupleFromListProcessor<long>) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::integer:2:ValueProcessor) @@ -733,7 +748,7 @@ let t : (R), t = (2, 3.1, 5); std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, double>) + `-(language::eq_op:AffectationToTupleFromListProcessor<double>) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::integer:2:ValueProcessor) @@ -760,7 +775,7 @@ let t3 : (R^1), t3 = (1, 2.3, 0); | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::integer:2:ValueProcessor) | `-(language::real:3.1:ValueProcessor) - +-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, TinyVector<2ul, double> >) + +-(language::eq_op:AffectationToTupleFromListProcessor<TinyVector<2ul, double> >) | +-(language::name:t1:NameProcessor) | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::name:a:NameProcessor) @@ -768,12 +783,12 @@ let t3 : (R^1), t3 = (1, 2.3, 0); | | +-(language::integer:1:ValueProcessor) | | `-(language::integer:2:ValueProcessor) | `-(language::integer:0:ValueProcessor) - +-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, TinyVector<3ul, double> >) + +-(language::eq_op:AffectationToTupleFromListProcessor<TinyVector<3ul, double> >) | +-(language::name:t2:NameProcessor) | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::integer:0:ValueProcessor) | `-(language::integer:0:ValueProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, TinyVector<1ul, double> >) + `-(language::eq_op:AffectationToTupleFromListProcessor<TinyVector<1ul, double> >) +-(language::name:t3:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::integer:1:ValueProcessor) @@ -792,7 +807,7 @@ let t : (string), t = ("foo", "bar"); std::string result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, )" + + `-(language::eq_op:AffectationToTupleFromListProcessor<)" + demangled_stdstring + R"( >) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) @@ -811,7 +826,7 @@ let t : (builtin_t), t= (a,b,a); std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleFromListProcessor<language::eq_op, EmbeddedData>) + `-(language::eq_op:AffectationToTupleFromListProcessor<EmbeddedData>) +-(language::name:t:NameProcessor) `-(language::expression_list:ASTNodeExpressionListProcessor) +-(language::name:a:NameProcessor) @@ -833,7 +848,7 @@ let t : (B), t = true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, bool>) + `-(language::eq_op:AffectationToTupleProcessor<bool>) +-(language::name:t:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -849,7 +864,7 @@ let t : (N), t = 1; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, unsigned long>) + `-(language::eq_op:AffectationToTupleProcessor<unsigned long>) +-(language::name:t:NameProcessor) `-(language::integer:1:ValueProcessor) )"; @@ -869,7 +884,7 @@ let t : (Z), t = n; +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>) | +-(language::name:n:NameProcessor) | `-(language::integer:3:ValueProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, long>) + `-(language::eq_op:AffectationToTupleProcessor<long>) +-(language::name:t:NameProcessor) `-(language::name:n:NameProcessor) )"; @@ -885,7 +900,7 @@ let t : (R), t = 3.1; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, double>) + `-(language::eq_op:AffectationToTupleProcessor<double>) +-(language::name:t:NameProcessor) `-(language::real:3.1:ValueProcessor) )"; @@ -909,13 +924,13 @@ let t3 : (R^1), t3 = 2.3; | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::integer:2:ValueProcessor) | `-(language::real:3.1:ValueProcessor) - +-(language::eq_op:AffectationToTupleProcessor<language::eq_op, TinyVector<2ul, double> >) + +-(language::eq_op:AffectationToTupleProcessor<TinyVector<2ul, double> >) | +-(language::name:t1:NameProcessor) | `-(language::name:a:NameProcessor) - +-(language::eq_op:AffectationToTupleProcessor<language::eq_op, TinyVector<3ul, double> >) + +-(language::eq_op:AffectationToTupleProcessor<TinyVector<3ul, double> >) | +-(language::name:t2:NameProcessor) | `-(language::integer:0:ValueProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, TinyVector<1ul, double> >) + `-(language::eq_op:AffectationToTupleProcessor<TinyVector<1ul, double> >) +-(language::name:t3:NameProcessor) `-(language::real:2.3:ValueProcessor) )"; @@ -931,7 +946,7 @@ let t : (string), t = "foo"; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, )" + + `-(language::eq_op:AffectationToTupleProcessor<)" + demangled_stdstring + R"( >) +-(language::name:t:NameProcessor) `-(language::literal:"foo":ValueProcessor) @@ -948,7 +963,7 @@ let t : (builtin_t), t = a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eq_op:AffectationToTupleProcessor<language::eq_op, EmbeddedData>) + `-(language::eq_op:AffectationToTupleProcessor<EmbeddedData>) +-(language::name:t:NameProcessor) `-(language::name:a:NameProcessor) )"; @@ -1511,43 +1526,29 @@ let x : R, x=1; x/=2.3; { SECTION("Invalid affectation operator") { - auto ast = std::make_unique<ASTNode>(); + auto ast = std::make_unique<ASTNode>(); + ast->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + { + auto child_0 = std::make_unique<ASTNode>(); + child_0->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + auto child_1 = std::make_unique<ASTNode>(); + child_1->m_data_type = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + ast->children.emplace_back(std::move(child_0)); + ast->children.emplace_back(std::move(child_1)); + } REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, "unexpected error: undefined affectation operator"); } - SECTION("Invalid lhs") - { - auto ast = std::make_unique<ASTNode>(); - ast->set_type<language::eq_op>(); - ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, - "unexpected error: undefined value type for affectation"); - } - - SECTION("Invalid rhs") - { - auto ast = std::make_unique<ASTNode>(); - ast->set_type<language::eq_op>(); - ast->m_data_type = ASTNodeDataType::int_t; - - ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, - "unexpected error: invalid implicit conversion: undefined -> Z"); - } - SECTION("Invalid string rhs") { auto ast = std::make_unique<ASTNode>(); ast->set_type<language::eq_op>(); - ast->m_data_type = ASTNodeDataType::string_t; ast->children.emplace_back(std::make_unique<ASTNode>()); + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, - "unexpected error: invalid implicit conversion: undefined -> string"); + REQUIRE_THROWS_WITH(ASTNodeAffectationExpressionBuilder{*ast}, "undefined affectation type: string = undefined"); } SECTION("Invalid string affectation operator") @@ -1558,7 +1559,7 @@ let x : R, x=1; x/=2.3; let s : string, s="foo"; s-="bar"; )"; - std::string error_message = "invalid affectation operator for string"; + std::string error_message = "undefined affectation type: string -= string"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1569,7 +1570,7 @@ let s : string, s="foo"; s-="bar"; let s : string, s="foo"; s*=2; )"; - std::string error_message = "invalid affectation operator for string"; + std::string error_message = "undefined affectation type: string *= Z"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1580,7 +1581,7 @@ let s : string, s="foo"; s*=2; let s : string, s="foo"; s/="bar"; )"; - std::string error_message = "invalid affectation operator for string"; + std::string error_message = "undefined affectation type: string /= string"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1592,7 +1593,7 @@ let s : string, s="foo"; s*=2; let s :builtin_t, s = a; s *= b; )"; - std::string error_message = "invalid affectation operator for 'builtin_t'"; + std::string error_message = "undefined affectation type: builtin_t *= builtin_t"; CHECK_AST_WITH_BUILTIN_THROWS_WITH(data, error_message); } @@ -1603,7 +1604,7 @@ let s : string, s="foo"; s*=2; let s :(R), s=(1,2,3); s *= 4; )"; - std::string error_message = "invalid affectation operator for 'tuple(R)'"; + std::string error_message = "undefined affectation type: tuple(R) *= Z"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1614,7 +1615,7 @@ let s : string, s="foo"; s*=2; let s : (builtin_t), s =(a,b); s *= b; )"; - std::string error_message = "invalid affectation operator for 'tuple(builtin_t)'"; + std::string error_message = "undefined affectation type: tuple(builtin_t) *= builtin_t"; CHECK_AST_WITH_BUILTIN_THROWS_WITH(data, error_message); } @@ -1627,7 +1628,7 @@ let s : string, s="foo"; s*=2; let x : R^3; let y : R^1; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^3 = R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1638,7 +1639,7 @@ let x : R^3; let y : R^1; x = y; let x : R^3; let y : R^2; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^3 = R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1649,7 +1650,7 @@ let x : R^3; let y : R^2; x = y; let x : R^2; let y : R^1; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^2 = R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1660,7 +1661,7 @@ let x : R^2; let y : R^1; x = y; let x : R^2; let y : R^3; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^2 = R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1671,7 +1672,7 @@ let x : R^2; let y : R^3; x = y; let x : R^1; let y : R^2; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^1 = R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1679,10 +1680,10 @@ let x : R^1; let y : R^2; x = y; SECTION("R^1 <- R^3") { std::string_view data = R"( -let x : R^1; let y : R^2; x = y; +let x : R^1; let y : R^3; x = y; )"; - std::string error_message = "incompatible dimensions in affectation"; + std::string error_message = "undefined affectation type: R^1 = R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1696,7 +1697,7 @@ let x : R^1; let y : R^2; x = y; let x : R^3, x = 3; )"; - std::string error_message = "invalid implicit conversion: Z -> R^3"; + std::string error_message = "invalid integral value (0 is the solely valid value)"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1707,7 +1708,7 @@ let x : R^3, x = 3; let x : R^2, x = 2; )"; - std::string error_message = "invalid implicit conversion: Z -> R^2"; + std::string error_message = "invalid integral value (0 is the solely valid value)"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1721,7 +1722,7 @@ let x : R^2, x = 2; let x : R^3; let y : R^3; x /= y; )"; - std::string error_message = "invalid affectation operator for R^3"; + std::string error_message = "undefined affectation type: R^3 /= R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1732,7 +1733,7 @@ let x : R^3; let y : R^3; x /= y; let x : R^2; let y : R^2; x /= y; )"; - std::string error_message = "invalid affectation operator for R^2"; + std::string error_message = "undefined affectation type: R^2 /= R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1743,7 +1744,7 @@ let x : R^2; let y : R^2; x /= y; let x : R^1; let y : R^1; x /= y; )"; - std::string error_message = "invalid affectation operator for R^1"; + std::string error_message = "undefined affectation type: R^1 /= R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1757,7 +1758,7 @@ let x : R^1; let y : R^1; x /= y; let x : R^3; let y : R^3; x *= y; )"; - std::string error_message = "expecting scalar operand type"; + std::string error_message = "undefined affectation type: R^3 *= R^3"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1768,7 +1769,7 @@ let x : R^3; let y : R^3; x *= y; let x : R^2; let y : R^2; x *= y; )"; - std::string error_message = "expecting scalar operand type"; + std::string error_message = "undefined affectation type: R^2 *= R^2"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1779,7 +1780,7 @@ let x : R^2; let y : R^2; x *= y; let x : R^1; let y : R^1; x *= y; )"; - std::string error_message = "expecting scalar operand type"; + std::string error_message = "undefined affectation type: R^1 *= R^1"; CHECK_AST_THROWS_WITH(data, error_message); } @@ -1799,20 +1800,6 @@ let (x,y,z):R*R*R, (x,y) = (2,3); std::string{"invalid number of definition identifiers, expecting 3 found 2"}); } - SECTION("incorrect identifier/expression number of symbols") - { - std::string_view data = R"( -let (x,y,z):R*R*R, (x,y,z) = (2,3); -)"; - - string_input input{data, "test.pgs"}; - auto ast = ASTBuilder::build(input); - - ASTSymbolTableBuilder{*ast}; - REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, - std::string{"invalid number of definition expressions, expecting 3 found 2"}); - } - SECTION("incorrect identifier/expression number of symbols") { std::string_view data = R"( diff --git a/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp b/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp index 941df56ac9d28de2aa5a211c61964b8d9ce05ee6..7eadd2067e54e670021aa2bee03e92fe0abedccd 100644 --- a/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp +++ b/tests/test_ASTNodeArraySubscriptExpressionBuilder.cpp @@ -18,7 +18,7 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") { { std::unique_ptr array_node = std::make_unique<ASTNode>(); - array_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); node->emplace_back(std::move(array_node)); } REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); @@ -31,7 +31,7 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") { { std::unique_ptr array_node = std::make_unique<ASTNode>(); - array_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); node->emplace_back(std::move(array_node)); } REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); @@ -44,7 +44,7 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") { { std::unique_ptr array_node = std::make_unique<ASTNode>(); - array_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); node->emplace_back(std::move(array_node)); } REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); @@ -52,6 +52,45 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") auto& node_processor = *node->m_node_processor; REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyVector<3>>).name()); } + + SECTION("R^1x1") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<1>>).name()); + } + + SECTION("R^2x2") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<2>>).name()); + } + + SECTION("R^3x3") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + node->emplace_back(std::move(array_node)); + } + REQUIRE_NOTHROW(ASTNodeArraySubscriptExpressionBuilder{*node}); + REQUIRE(bool{node->m_node_processor}); + auto& node_processor = *node->m_node_processor; + REQUIRE(typeid(node_processor).name() == typeid(ArraySubscriptProcessor<TinyMatrix<3>>).name()); + } } SECTION("R^d component bad access") @@ -62,16 +101,37 @@ TEST_CASE("ASTNodeArraySubscriptExpressionBuilder", "[language]") { { std::unique_ptr array_node = std::make_unique<ASTNode>(); - array_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 0}; + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(0); node->emplace_back(std::move(array_node)); } REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); } + SECTION("R^d (d > 3)") { { std::unique_ptr array_node = std::make_unique<ASTNode>(); - array_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 4}; + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(4); + node->emplace_back(std::move(array_node)); + } + REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); + } + + SECTION("R^dxd (d < 1)") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(0, 0); + node->emplace_back(std::move(array_node)); + } + REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); + } + + SECTION("R^dxd (d > 3)") + { + { + std::unique_ptr array_node = std::make_unique<ASTNode>(); + array_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(4, 4); node->emplace_back(std::move(array_node)); } REQUIRE_THROWS_WITH(ASTNodeArraySubscriptExpressionBuilder{*node}, "unexpected error: invalid array dimension"); diff --git a/tests/test_ASTNodeBinaryOperatorExpressionBuilder.cpp b/tests/test_ASTNodeBinaryOperatorExpressionBuilder.cpp index 78a3f5f3b16bfdcb804421cc41e59a1c37abde4f..1c030e91791827d6a4b69a087c116b2cdd61fc33 100644 --- a/tests/test_ASTNodeBinaryOperatorExpressionBuilder.cpp +++ b/tests/test_ASTNodeBinaryOperatorExpressionBuilder.cpp @@ -35,6 +35,18 @@ REQUIRE(ast_output.str() == expected_output); \ } +#define REQUIRE_AST_THROWS_WITH(data, expected_output) \ + { \ + static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ + \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + ASTSymbolTableBuilder{*ast}; \ + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, expected_output); \ + } + +/* #define REQUIRE_AST_THROWS_WITH(data, expected_output) \ { \ static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ @@ -50,6 +62,7 @@ \ REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast}, expected_output); \ } +*/ // clazy:excludeall=non-pod-global-static @@ -67,11 +80,11 @@ false*b*true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, bool, bool>) + +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, bool>) - +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, bool, bool>) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, unsigned long, bool>) + +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -90,8 +103,8 @@ n*m*n; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, unsigned long>) - +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, unsigned long>) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, unsigned long, unsigned long>) + +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, unsigned long, unsigned long, unsigned long>) | +-(language::name:n:NameProcessor) | `-(language::name:m:NameProcessor) `-(language::name:n:NameProcessor) @@ -109,8 +122,8 @@ a*3*a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long>) - +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long>) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long, long>) + +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long, long>) | +-(language::name:a:NameProcessor) | `-(language::integer:3:ValueProcessor) `-(language::name:a:NameProcessor) @@ -127,9 +140,9 @@ a*3*a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, bool>) - +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, long>) - | +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, double>) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, double, bool>) + +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, double, long>) + | +-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, double, double, double>) | | +-(language::real:2.3:ValueProcessor) | | `-(language::real:1.2:ValueProcessor) | `-(language::integer:2:ValueProcessor) @@ -151,7 +164,7 @@ let x : R^1, x = 3.7; +-(language::eq_op:AffectationProcessor<language::eq_op, TinyVector<1ul, double>, double>) | +-(language::name:x:NameProcessor) | `-(language::real:3.7:ValueProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, TinyVector<1ul, double> >) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, TinyVector<1ul, double>, long, TinyVector<1ul, double> >) +-(language::integer:2:ValueProcessor) `-(language::name:x:NameProcessor) )"; @@ -173,7 +186,7 @@ let x : R^2, x = (3.2,6); | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::real:3.2:ValueProcessor) | `-(language::integer:6:ValueProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, TinyVector<2ul, double> >) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, TinyVector<2ul, double>, long, TinyVector<2ul, double> >) +-(language::integer:2:ValueProcessor) `-(language::name:x:NameProcessor) )"; @@ -196,7 +209,7 @@ let x : R^3, x = (3.2,6,1.2); | +-(language::real:3.2:ValueProcessor) | +-(language::integer:6:ValueProcessor) | `-(language::real:1.2:ValueProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, TinyVector<3ul, double> >) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, TinyVector<3ul, double>, long, TinyVector<3ul, double> >) +-(language::integer:2:ValueProcessor) `-(language::name:x:NameProcessor) )"; @@ -217,11 +230,11 @@ false/b/true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, bool, bool>) + +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, bool>) - +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, bool, bool>) + `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, unsigned long, bool>) + +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -240,8 +253,8 @@ n/m/n; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, unsigned long>) - +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, unsigned long>) + `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, unsigned long, unsigned long>) + +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, unsigned long, unsigned long, unsigned long>) | +-(language::name:n:NameProcessor) | `-(language::name:m:NameProcessor) `-(language::name:n:NameProcessor) @@ -259,8 +272,8 @@ a/3/a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long>) - +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long>) + `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long, long>) + +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long, long>) | +-(language::name:a:NameProcessor) | `-(language::integer:3:ValueProcessor) `-(language::name:a:NameProcessor) @@ -277,9 +290,9 @@ a/3/a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, bool>) - +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, long>) - | +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, double>) + `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, double, bool>) + +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, double, long>) + | +-(language::divide_op:BinaryExpressionProcessor<language::divide_op, double, double, double>) | | +-(language::real:2.3:ValueProcessor) | | `-(language::real:1.2:ValueProcessor) | `-(language::integer:2:ValueProcessor) @@ -302,11 +315,11 @@ false+b+true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, bool, bool>) + +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, bool>) - +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, bool, bool>) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, unsigned long, bool>) + +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -325,8 +338,8 @@ n+m+n; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, unsigned long>) - +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, unsigned long>) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, unsigned long, unsigned long>) + +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, unsigned long, unsigned long, unsigned long>) | +-(language::name:n:NameProcessor) | `-(language::name:m:NameProcessor) `-(language::name:n:NameProcessor) @@ -344,8 +357,8 @@ a+3+a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long>) - +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long>) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long, long>) + +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long, long>) | +-(language::name:a:NameProcessor) | `-(language::integer:3:ValueProcessor) `-(language::name:a:NameProcessor) @@ -362,9 +375,9 @@ a+3+a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, bool>) - +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, long>) - | +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, double>) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, double, bool>) + +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, double, long>) + | +-(language::plus_op:BinaryExpressionProcessor<language::plus_op, double, double, double>) | | +-(language::real:2.3:ValueProcessor) | | `-(language::real:1.2:ValueProcessor) | `-(language::integer:2:ValueProcessor) @@ -390,7 +403,7 @@ x+y; +-(language::eq_op:AffectationProcessor<language::eq_op, TinyVector<1ul, double>, long>) | +-(language::name:y:NameProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<1ul, double>, TinyVector<1ul, double> >) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<1ul, double>, TinyVector<1ul, double>, TinyVector<1ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -418,7 +431,7 @@ x+y; | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::real:0.3:ValueProcessor) | `-(language::real:0.7:ValueProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<2ul, double>, TinyVector<2ul, double> >) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<2ul, double>, TinyVector<2ul, double>, TinyVector<2ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -448,7 +461,7 @@ x+y; | +-(language::integer:4:ValueProcessor) | +-(language::integer:3:ValueProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<3ul, double>, TinyVector<3ul, double> >) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, TinyVector<3ul, double>, TinyVector<3ul, double>, TinyVector<3ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -556,11 +569,11 @@ false-b-true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, bool, bool>) + +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, bool>) - +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, bool, bool>) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long, bool>) + +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -579,8 +592,8 @@ n-m-n; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, unsigned long, unsigned long>) - +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, unsigned long, unsigned long>) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long, unsigned long>) + +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, unsigned long, unsigned long>) | +-(language::name:n:NameProcessor) | `-(language::name:m:NameProcessor) `-(language::name:n:NameProcessor) @@ -598,8 +611,8 @@ a-3-a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long>) - +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long>) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long, long>) + +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long, long>) | +-(language::name:a:NameProcessor) | `-(language::integer:3:ValueProcessor) `-(language::name:a:NameProcessor) @@ -616,9 +629,9 @@ a-3-a; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, bool>) - +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, long>) - | +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, double>) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, double, bool>) + +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, double, long>) + | +-(language::minus_op:BinaryExpressionProcessor<language::minus_op, double, double, double>) | | +-(language::real:2.3:ValueProcessor) | | `-(language::real:1.2:ValueProcessor) | `-(language::integer:2:ValueProcessor) @@ -644,7 +657,7 @@ x-y; +-(language::eq_op:AffectationProcessor<language::eq_op, TinyVector<1ul, double>, long>) | +-(language::name:y:NameProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<1ul, double>, TinyVector<1ul, double> >) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<1ul, double>, TinyVector<1ul, double>, TinyVector<1ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -672,7 +685,7 @@ x-y; | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::real:0.3:ValueProcessor) | `-(language::real:0.7:ValueProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<2ul, double>, TinyVector<2ul, double> >) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<2ul, double>, TinyVector<2ul, double>, TinyVector<2ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -702,7 +715,7 @@ x-y; | +-(language::integer:4:ValueProcessor) | +-(language::integer:3:ValueProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<3ul, double>, TinyVector<3ul, double> >) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, TinyVector<3ul, double>, TinyVector<3ul, double>, TinyVector<3ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -723,11 +736,11 @@ false or b or true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool>) + +-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool>) - +-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool>) + `-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool, bool>) + +-(language::or_op:BinaryExpressionProcessor<language::or_op, bool, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -749,11 +762,11 @@ false and b and true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool>) + +-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool>) - +-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool>) + `-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool, bool>) + +-(language::and_op:BinaryExpressionProcessor<language::and_op, bool, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -775,11 +788,11 @@ false xor b xor true; std::string_view result = R"( (root:ASTNodeListProcessor) - +-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool>) + +-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool, bool>) | +-(language::name:b:NameProcessor) | `-(language::true_kw:ValueProcessor) - `-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool>) - +-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool>) + `-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool, bool>) + +-(language::xor_op:BinaryExpressionProcessor<language::xor_op, bool, bool, bool>) | +-(language::false_kw:ValueProcessor) | `-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) @@ -800,7 +813,7 @@ b > true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, bool>) + `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -818,7 +831,7 @@ n > m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, unsigned long, unsigned long>) + `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -835,7 +848,7 @@ a > 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, long, long>) + `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -851,7 +864,7 @@ a > 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, double, double>) + `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -871,7 +884,7 @@ b < true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, bool>) + `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -889,7 +902,7 @@ n < m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, unsigned long, unsigned long>) + `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -906,7 +919,7 @@ a < 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, long, long>) + `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -922,7 +935,7 @@ a < 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, double, double>) + `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -941,7 +954,7 @@ b >= true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, bool>) + `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -959,7 +972,7 @@ n >= m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, unsigned long, unsigned long>) + `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -976,7 +989,7 @@ a >= 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, long, long>) + `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -992,7 +1005,7 @@ a >= 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, double, double>) + `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -1012,7 +1025,7 @@ b <= true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, bool>) + `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -1030,7 +1043,7 @@ n <= m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, unsigned long, unsigned long>) + `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -1047,7 +1060,7 @@ a <= 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, long, long>) + `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -1063,7 +1076,7 @@ a <= 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, double, double>) + `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -1083,7 +1096,7 @@ b == true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, bool>) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -1101,7 +1114,7 @@ n == m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, unsigned long, unsigned long>) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -1118,7 +1131,7 @@ a == 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, long, long>) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -1134,7 +1147,7 @@ a == 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, double, double>) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -1152,7 +1165,7 @@ a == 3; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, )" + + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, )" + string_name + ", " + string_name + R"( >) +-(language::literal:"foo":ValueProcessor) `-(language::literal:"bar":ValueProcessor) @@ -1177,7 +1190,7 @@ x==y; +-(language::eq_op:AffectationProcessor<language::eq_op, TinyVector<1ul, double>, long>) | +-(language::name:y:NameProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, TinyVector<1ul, double>, TinyVector<1ul, double> >) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, TinyVector<1ul, double>, TinyVector<1ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1205,7 +1218,7 @@ x==y; | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::real:0.3:ValueProcessor) | `-(language::real:0.7:ValueProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, TinyVector<2ul, double>, TinyVector<2ul, double> >) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, TinyVector<2ul, double>, TinyVector<2ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1235,7 +1248,7 @@ x==y; | +-(language::integer:4:ValueProcessor) | +-(language::integer:3:ValueProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, TinyVector<3ul, double>, TinyVector<3ul, double> >) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, TinyVector<3ul, double>, TinyVector<3ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1255,7 +1268,7 @@ b != true; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, bool>) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, bool, bool>) +-(language::name:b:NameProcessor) `-(language::true_kw:ValueProcessor) )"; @@ -1273,7 +1286,7 @@ n != m; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, unsigned long, unsigned long>) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, unsigned long, unsigned long>) +-(language::name:n:NameProcessor) `-(language::name:m:NameProcessor) )"; @@ -1290,7 +1303,7 @@ a != 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, long, long>) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, long, long>) +-(language::name:a:NameProcessor) `-(language::integer:3:ValueProcessor) )"; @@ -1306,7 +1319,7 @@ a != 3; std::string_view result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, double, double>) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, double, double>) +-(language::real:2.3:ValueProcessor) `-(language::real:1.2:ValueProcessor) )"; @@ -1330,7 +1343,7 @@ x!=y; +-(language::eq_op:AffectationProcessor<language::eq_op, TinyVector<1ul, double>, long>) | +-(language::name:y:NameProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, TinyVector<1ul, double>, TinyVector<1ul, double> >) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, TinyVector<1ul, double>, TinyVector<1ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1358,7 +1371,7 @@ x!=y; | `-(language::expression_list:ASTNodeExpressionListProcessor) | +-(language::real:0.3:ValueProcessor) | `-(language::real:0.7:ValueProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, TinyVector<2ul, double>, TinyVector<2ul, double> >) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, TinyVector<2ul, double>, TinyVector<2ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1388,7 +1401,7 @@ x!=y; | +-(language::integer:4:ValueProcessor) | +-(language::integer:3:ValueProcessor) | `-(language::integer:2:ValueProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, TinyVector<3ul, double>, TinyVector<3ul, double> >) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, TinyVector<3ul, double>, TinyVector<3ul, double> >) +-(language::name:x:NameProcessor) `-(language::name:y:NameProcessor) )"; @@ -1406,7 +1419,7 @@ x!=y; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, )" + + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, )" + string_name + ", " + string_name + R"( >) +-(language::literal:"foo":ValueProcessor) `-(language::literal:"bar":ValueProcessor) @@ -1422,6 +1435,11 @@ x!=y; SECTION("Invalid binary operator type") { auto ast = std::make_unique<ASTNode>(); + ast->set_type<language::ignored>(); + ast->children.emplace_back(std::make_unique<ASTNode>()); + ast->children.emplace_back(std::make_unique<ASTNode>()); + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "unexpected error: undefined binary operator"); } @@ -1434,10 +1452,10 @@ x!=y; ast->set_type<language::multiply_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void * Z"); } SECTION("left string multiply") @@ -1446,10 +1464,10 @@ x!=y; ast->set_type<language::multiply_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: string * Z"); } SECTION("right string multiply") @@ -1458,10 +1476,10 @@ x!=y; ast->set_type<language::multiply_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z * string"); } SECTION("lhs bad divide") @@ -1470,10 +1488,10 @@ x!=y; ast->set_type<language::divide_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void / Z"); } SECTION("left string divide") @@ -1482,10 +1500,10 @@ x!=y; ast->set_type<language::divide_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: string / Z"); } SECTION("right string divide") @@ -1494,10 +1512,10 @@ x!=y; ast->set_type<language::divide_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z / string"); } SECTION("lhs bad plus") @@ -1506,10 +1524,10 @@ x!=y; ast->set_type<language::plus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void + Z"); } SECTION("left string plus bad rhs") @@ -1518,10 +1536,11 @@ x!=y; ast->set_type<language::plus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::void_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string + void"); } SECTION("right string plus") @@ -1530,10 +1549,10 @@ x!=y; ast->set_type<language::plus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z + string"); } SECTION("lhs bad minus") @@ -1542,10 +1561,10 @@ x!=y; ast->set_type<language::minus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void - Z"); } SECTION("left string minus") @@ -1554,10 +1573,10 @@ x!=y; ast->set_type<language::minus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: string - Z"); } SECTION("right string minus") @@ -1566,10 +1585,10 @@ x!=y; ast->set_type<language::minus_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z - string"); } SECTION("lhs bad or") @@ -1578,10 +1597,10 @@ x!=y; ast->set_type<language::or_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void or Z"); } SECTION("left string or") @@ -1590,10 +1609,11 @@ x!=y; ast->set_type<language::or_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string or Z"); } SECTION("right string or") @@ -1602,10 +1622,11 @@ x!=y; ast->set_type<language::or_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z or string"); } SECTION("lhs bad and") @@ -1614,10 +1635,10 @@ x!=y; ast->set_type<language::and_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void and Z"); } SECTION("left string and") @@ -1626,10 +1647,11 @@ x!=y; ast->set_type<language::and_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string and Z"); } SECTION("right string and") @@ -1638,10 +1660,11 @@ x!=y; ast->set_type<language::and_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z and string"); } SECTION("lhs bad xor") @@ -1650,10 +1673,10 @@ x!=y; ast->set_type<language::xor_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void xor Z"); } SECTION("left string xor") @@ -1662,10 +1685,11 @@ x!=y; ast->set_type<language::xor_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string xor Z"); } SECTION("right string xor") @@ -1674,10 +1698,11 @@ x!=y; ast->set_type<language::xor_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z xor string"); } SECTION("lhs bad >") @@ -1686,10 +1711,10 @@ x!=y; ast->set_type<language::greater_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void > Z"); } SECTION("left string >") @@ -1698,10 +1723,10 @@ x!=y; ast->set_type<language::greater_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: string > Z"); } SECTION("right string >") @@ -1710,10 +1735,10 @@ x!=y; ast->set_type<language::greater_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z > string"); } SECTION("lhs bad <") @@ -1722,10 +1747,10 @@ x!=y; ast->set_type<language::lesser_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void < Z"); } SECTION("left string <") @@ -1734,10 +1759,10 @@ x!=y; ast->set_type<language::lesser_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: string < Z"); } SECTION("right string <") @@ -1746,10 +1771,10 @@ x!=y; ast->set_type<language::lesser_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: Z < string"); } SECTION("lhs bad >=") @@ -1758,10 +1783,10 @@ x!=y; ast->set_type<language::greater_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void >= Z"); } SECTION("left string >=") @@ -1770,10 +1795,11 @@ x!=y; ast->set_type<language::greater_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string >= Z"); } SECTION("right string >=") @@ -1782,10 +1808,11 @@ x!=y; ast->set_type<language::greater_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z >= string"); } SECTION("lhs bad <=") @@ -1794,10 +1821,10 @@ x!=y; ast->set_type<language::lesser_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void <= Z"); } SECTION("left string <=") @@ -1806,10 +1833,11 @@ x!=y; ast->set_type<language::lesser_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string <= Z"); } SECTION("right string <=") @@ -1818,10 +1846,11 @@ x!=y; ast->set_type<language::lesser_or_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z <= string"); } SECTION("lhs bad ==") @@ -1830,10 +1859,10 @@ x!=y; ast->set_type<language::eqeq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void == Z"); } SECTION("left string ==") @@ -1842,10 +1871,11 @@ x!=y; ast->set_type<language::eqeq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string == Z"); } SECTION("right string ==") @@ -1854,10 +1884,11 @@ x!=y; ast->set_type<language::eqeq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z == string"); } SECTION("lhs bad !=") @@ -1866,34 +1897,36 @@ x!=y; ast->set_type<language::not_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::void_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined binary operator type: void != Z"); } - SECTION("left string ==") + SECTION("left string !=") { auto ast = std::make_unique<ASTNode>(); ast->set_type<language::not_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::string_t; - ast->children[1]->m_data_type = ASTNodeDataType::int_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: string != Z"); } - SECTION("right string ==") + SECTION("right string !=") { auto ast = std::make_unique<ASTNode>(); ast->set_type<language::not_eq_op>(); ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children.emplace_back(std::make_unique<ASTNode>()); - ast->children[0]->m_data_type = ASTNodeDataType::int_t; - ast->children[1]->m_data_type = ASTNodeDataType::string_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + ast->children[1]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::string_t>(); - REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, "undefined operand type for binary operator"); + REQUIRE_THROWS_WITH(ASTNodeBinaryOperatorExpressionBuilder{*ast}, + "undefined binary operator type: Z != string"); } } @@ -1907,7 +1940,10 @@ let y : R^1, y = 0; x > y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^1 >= R^1") @@ -1918,7 +1954,10 @@ let y : R^1, y = 0; x >= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^1 < R^1") @@ -1929,7 +1968,10 @@ let y : R^1, y = 0; x < y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^1 <= R^1") @@ -1940,7 +1982,10 @@ let y : R^1, y = 1; x <= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^1 * R^1") @@ -1951,7 +1996,10 @@ let y : R^1, y = 0; x * y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^1 / R^1") @@ -1962,7 +2010,10 @@ let y : R^1, y = 0; x / y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 > R^2") @@ -1973,7 +2024,10 @@ let y : R^2, y = 0; x > y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 >= R^2") @@ -1984,7 +2038,10 @@ let y : R^2, y = 0; x >= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 < R^2") @@ -1995,7 +2052,10 @@ let y : R^2, y = 0; x < y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 <= R^2") @@ -2006,7 +2066,10 @@ let y : R^2, y = 0; x <= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 * R^2") @@ -2017,7 +2080,10 @@ let y : R^2, y = 0; x * y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^2 / R^2") @@ -2028,7 +2094,10 @@ let y : R^2, y = 0; x / y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^2 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 > R^3") @@ -2039,7 +2108,10 @@ let y : R^3, y = 0; x > y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 >= R^3") @@ -2050,7 +2122,10 @@ let y : R^3, y = 0; x >= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 < R^3") @@ -2061,7 +2136,10 @@ let y : R^3, y = 0; x < y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 <= R^3") @@ -2072,7 +2150,10 @@ let y : R^3, y = 0; x <= y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 * R^3") @@ -2083,7 +2164,10 @@ let y : R^3, y = 0; x * y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("invalid operator R^3 / R^3") @@ -2094,7 +2178,10 @@ let y : R^3, y = 0; x / y; )"; - REQUIRE_AST_THROWS_WITH(data, "invalid binary operator"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^3)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } } @@ -2108,7 +2195,10 @@ let y : R^1, y = 0; x + y; )"; - REQUIRE_AST_THROWS_WITH(data, "incompatible dimensions of operands"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^1)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("incompatible operand dimensions") @@ -2119,7 +2209,10 @@ let y : R^2, y = 0; x - y; )"; - REQUIRE_AST_THROWS_WITH(data, "incompatible dimensions of operands"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("incompatible operand dimensions") @@ -2130,7 +2223,10 @@ let y : R^2, y = 0; x == y; )"; - REQUIRE_AST_THROWS_WITH(data, "incompatible dimensions of operands"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^3 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } SECTION("incompatible operand dimensions") @@ -2141,7 +2237,10 @@ let y : R^2, y = 0; x != y; )"; - REQUIRE_AST_THROWS_WITH(data, "incompatible dimensions of operands"); + const std::string error_mesh = R"(undefined binary operator +note: incompatible operand types R^1 and R^2)"; + + REQUIRE_AST_THROWS_WITH(data, error_mesh); } } } diff --git a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp index 2954cd8db2e691ea2206b30b158f9596a8290081..25115bfc2ff3011e582029ab495e17f4a9d108e2 100644 --- a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -133,7 +133,7 @@ RtoR(true); } } - SECTION("R -> R1") + SECTION("R -> R^1") { SECTION("from R") { @@ -201,9 +201,77 @@ RtoR1(true); } } - SECTION("R1 -> R") + SECTION("R -> R^1x1") { - SECTION("from R1") + SECTION("from R") + { + std::string_view data = R"( +RtoR11(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +RtoR11(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +RtoR11(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +RtoR11(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^1 -> R") + { + SECTION("from R^1") { std::string_view data = R"( let x : R^1, x = 2; @@ -286,7 +354,92 @@ R1toR(true); } } - SECTION("R2 -> R") + SECTION("R^1x1 -> R") + { + SECTION("from R^1x1") + { + std::string_view data = R"( +let x : R^1x1, x = 2; +R11toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R") + { + std::string_view data = R"( +R11toR(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +R11toR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +R11toR(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +R11toR(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^2 -> R") { SECTION("from 0") { @@ -340,7 +493,63 @@ R2toR((1,2)); } } - SECTION("R3 -> R") + SECTION("R^2x2 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R22toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^2x2") + { + std::string_view data = R"( +let x:R^2x2, x = (1,2,3,4); +R22toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R22toR((1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^3 -> R") { SECTION("from 0") { @@ -395,6 +604,67 @@ R3toR((1,2,3)); } } + SECTION("R^3x3 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R33toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^3x3") + { + std::string_view data = R"( +let x:R^3x3, x = (1,2,3,4,5,6,7,8,9); +R33toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R33toR((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("Z -> R") { SECTION("from Z") @@ -603,6 +873,87 @@ R3R2toR((1,2,3),0); } } + SECTION("R^3x3*R^2x2 -> R") + { + SECTION("from R^3x3*R^2x2") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let y : R^2x2, y = (1,2,3,4); +R33R22toR(x,y); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::name:x:NameProcessor) + `-(language::name:y:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(R,R,R,R)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),(1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(0)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("string -> B") { std::string_view data = R"( @@ -1012,7 +1363,7 @@ tuple_builtinToB(t); CHECK_AST(data, result); } - SECTION("Z -> tuple(R1)") + SECTION("Z -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1); @@ -1028,7 +1379,7 @@ tuple_R1ToR(1); CHECK_AST(data, result); } - SECTION("R -> tuple(R1)") + SECTION("R -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1.2); @@ -1044,7 +1395,7 @@ tuple_R1ToR(1.2); CHECK_AST(data, result); } - SECTION("R1 -> tuple(R1)") + SECTION("R^1 -> tuple(R^1)") { std::string_view data = R"( let r:R^1, r = 3; @@ -1061,7 +1412,56 @@ tuple_R1ToR(r); CHECK_AST(data, result); } - SECTION("0 -> tuple(R2)") + SECTION("Z -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1.2); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::real:1.2:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^1x1 -> tuple(R^1x1)") + { + std::string_view data = R"( +let r:R^1x1, r = 3; +tuple_R11ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^2)") { std::string_view data = R"( tuple_R2ToR(0); @@ -1077,7 +1477,7 @@ tuple_R2ToR(0); CHECK_AST(data, result); } - SECTION("R2 -> tuple(R2)") + SECTION("R^2 -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1094,7 +1494,7 @@ tuple_R2ToR(r); CHECK_AST(data, result); } - SECTION("compound_list -> tuple(R2)") + SECTION("compound_list -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1113,7 +1513,59 @@ tuple_R2ToR((r,r)); CHECK_AST(data, result); } - SECTION("0 -> tuple(R3)") + SECTION("0 -> tuple(R^2x2)") + { + std::string_view data = R"( +tuple_R22ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^2x2 -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("compound_list -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR((r,r)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::name:r:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^3)") { std::string_view data = R"( tuple_R3ToR(0); @@ -1129,7 +1581,7 @@ tuple_R3ToR(0); CHECK_AST(data, result); } - SECTION("R3 -> tuple(R3)") + SECTION("R^3 -> tuple(R^3)") { std::string_view data = R"( let r:R^3, r = (1,2,3); @@ -1146,6 +1598,39 @@ tuple_R3ToR(r); CHECK_AST(data, result); } + SECTION("0 -> tuple(R^3x3)") + { + std::string_view data = R"( +tuple_R33ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^3x3 -> tuple(R^3x3)") + { + std::string_view data = R"( +let r:R^3x3, r = (1,2,3,4,5,6,7,8,9); +tuple_R33ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("FunctionSymbolId -> R") { std::string_view data = R"( diff --git a/tests/test_ASTNodeDataType.cpp b/tests/test_ASTNodeDataType.cpp index 5a7b09eb32ace24066a3264d1acb055c33364dc7..4e4e5759b68c3f94a288bc785f313711b41ec7d0 100644 --- a/tests/test_ASTNodeDataType.cpp +++ b/tests/test_ASTNodeDataType.cpp @@ -2,79 +2,105 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeDataType.hpp> +#include <language/utils/ASTNodeDataType.hpp> namespace language { struct integer; struct real; struct vector_type; +struct matrix_type; } // namespace language // clazy:excludeall=non-pod-global-static TEST_CASE("ASTNodeDataType", "[language]") { + const ASTNodeDataType undefined_dt = ASTNodeDataType{}; + const ASTNodeDataType bool_dt = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + const ASTNodeDataType unsigned_int_dt = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + const ASTNodeDataType int_dt = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + const ASTNodeDataType double_dt = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + const ASTNodeDataType string_dt = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + const ASTNodeDataType void_dt = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + const ASTNodeDataType function_dt = ASTNodeDataType::build<ASTNodeDataType::function_t>(); + const ASTNodeDataType builtin_function_dt = ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(); + + std::vector<std::shared_ptr<const ASTNodeDataType>> type_list; + type_list.push_back(std::make_shared<const ASTNodeDataType>(double_dt)); + type_list.push_back(std::make_shared<const ASTNodeDataType>(int_dt)); + + const ASTNodeDataType list_dt = ASTNodeDataType::build<ASTNodeDataType::list_t>(type_list); + + const ASTNodeDataType empty_list_dt = + ASTNodeDataType::build<ASTNodeDataType::list_t>(std::vector<std::shared_ptr<const ASTNodeDataType>>{}); + SECTION("dataTypeName") { - REQUIRE(dataTypeName(ASTNodeDataType::undefined_t) == "undefined"); - REQUIRE(dataTypeName(ASTNodeDataType::bool_t) == "B"); - REQUIRE(dataTypeName(ASTNodeDataType::unsigned_int_t) == "N"); - REQUIRE(dataTypeName(ASTNodeDataType::int_t) == "Z"); - REQUIRE(dataTypeName(ASTNodeDataType::double_t) == "R"); - REQUIRE(dataTypeName(ASTNodeDataType::string_t) == "string"); - REQUIRE(dataTypeName(ASTNodeDataType::typename_t) == "typename"); - REQUIRE(dataTypeName(ASTNodeDataType::void_t) == "void"); - REQUIRE(dataTypeName(ASTNodeDataType::function_t) == "function"); - REQUIRE(dataTypeName(ASTNodeDataType::builtin_function_t) == "builtin_function"); - REQUIRE(dataTypeName(ASTNodeDataType::list_t) == "list"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}) == - "tuple(B)"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}) == - "tuple(N)"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}) == - "tuple(Z)"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}) == - "tuple(R)"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}) == - "tuple(R)"); - - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::type_name_id_t, 1}) == "type_name_id"); - - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::type_id_t, "user_type"}) == "user_type"); - - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 1}) == "R^1"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 2}) == "R^2"); - REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 3}) == "R^3"); + REQUIRE(dataTypeName(undefined_dt) == "undefined"); + REQUIRE(dataTypeName(bool_dt) == "B"); + REQUIRE(dataTypeName(unsigned_int_dt) == "N"); + REQUIRE(dataTypeName(int_dt) == "Z"); + REQUIRE(dataTypeName(double_dt) == "R"); + REQUIRE(dataTypeName(string_dt) == "string"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::typename_t>(double_dt)) == "typename(R)"); + REQUIRE(dataTypeName(void_dt) == "void"); + REQUIRE(dataTypeName(function_dt) == "function"); + REQUIRE(dataTypeName(builtin_function_dt) == "builtin_function"); + REQUIRE(dataTypeName(list_dt) == "list(R*Z)"); + REQUIRE(dataTypeName(empty_list_dt) == "list(void)"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt)) == "tuple(B)"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt)) == "tuple(N)"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt)) == "tuple(Z)"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt)) == "tuple(R)"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt)) == "tuple(Z)"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>()) == "type_name_id"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::type_id_t>("user_type")) == "user_type"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)) == "R^1"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)) == "R^2"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)) == "R^3"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(7)) == "R^7"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)) == "R^1x1"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)) == "R^2x2"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)) == "R^3x3"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(7, 3)) == "R^7x3"); + + REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{}) == "void"); + REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt}) == "B"); + REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt, unsigned_int_dt}) == "(B,N)"); } SECTION("promotion") { - REQUIRE(dataTypePromotion(ASTNodeDataType::undefined_t, ASTNodeDataType::undefined_t) == - ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::void_t, ASTNodeDataType::double_t) == ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::undefined_t) == ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::bool_t) == ASTNodeDataType::double_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::unsigned_int_t) == ASTNodeDataType::double_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::int_t) == ASTNodeDataType::double_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::unsigned_int_t) == - ASTNodeDataType::unsigned_int_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::bool_t) == ASTNodeDataType::int_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::bool_t) == - ASTNodeDataType::unsigned_int_t); - - REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::bool_t) == ASTNodeDataType::string_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::int_t) == ASTNodeDataType::string_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::unsigned_int_t) == ASTNodeDataType::string_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::double_t) == ASTNodeDataType::string_t); - - REQUIRE(dataTypePromotion(ASTNodeDataType::bool_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::string_t) == - ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::bool_t, ASTNodeDataType::vector_t) == ASTNodeDataType::vector_t); - REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::vector_t) == ASTNodeDataType::vector_t); + const ASTNodeDataType vector_dt = ASTNodeDataType::build<ASTNodeDataType::vector_t>(5); + + REQUIRE(dataTypePromotion(undefined_dt, undefined_dt) == undefined_dt); + REQUIRE(dataTypePromotion(void_dt, double_dt) == undefined_dt); + REQUIRE(dataTypePromotion(int_dt, undefined_dt) == undefined_dt); + REQUIRE(dataTypePromotion(double_dt, bool_dt) == double_dt); + REQUIRE(dataTypePromotion(double_dt, unsigned_int_dt) == double_dt); + REQUIRE(dataTypePromotion(double_dt, int_dt) == double_dt); + REQUIRE(dataTypePromotion(int_dt, unsigned_int_dt) == unsigned_int_dt); + REQUIRE(dataTypePromotion(int_dt, bool_dt) == int_dt); + REQUIRE(dataTypePromotion(unsigned_int_dt, bool_dt) == unsigned_int_dt); + + REQUIRE(dataTypePromotion(string_dt, bool_dt) == string_dt); + REQUIRE(dataTypePromotion(string_dt, int_dt) == string_dt); + REQUIRE(dataTypePromotion(string_dt, unsigned_int_dt) == string_dt); + REQUIRE(dataTypePromotion(string_dt, double_dt) == string_dt); + + REQUIRE(dataTypePromotion(bool_dt, string_dt) == undefined_dt); + REQUIRE(dataTypePromotion(int_dt, string_dt) == undefined_dt); + REQUIRE(dataTypePromotion(unsigned_int_dt, string_dt) == undefined_dt); + REQUIRE(dataTypePromotion(double_dt, string_dt) == undefined_dt); + REQUIRE(dataTypePromotion(bool_dt, vector_dt) == vector_dt); + REQUIRE(dataTypePromotion(double_dt, vector_dt) == vector_dt); } SECTION("getVectorDataType") @@ -87,7 +113,7 @@ TEST_CASE("ASTNodeDataType", "[language]") { std::unique_ptr dimension_node = std::make_unique<ASTNode>(); dimension_node->set_type<language::integer>(); - dimension_node->source = "17"; + dimension_node->source = "3"; auto& source = dimension_node->source; dimension_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; dimension_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; @@ -96,8 +122,8 @@ TEST_CASE("ASTNodeDataType", "[language]") SECTION("good node") { - REQUIRE(getVectorDataType(*type_node) == ASTNodeDataType::vector_t); - REQUIRE(getVectorDataType(*type_node).dimension() == 17); + REQUIRE(getVectorDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + REQUIRE(getVectorDataType(*type_node).dimension() == 3); } SECTION("bad node type") @@ -123,148 +149,299 @@ TEST_CASE("ASTNodeDataType", "[language]") type_node->children[1]->set_type<language::real>(); REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected non integer constant dimension"); } + + SECTION("bad dimension value") + { + type_node->children[1]->source = "0"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + } + + SECTION("getMatrixDataType") + { + std::unique_ptr type_node = std::make_unique<ASTNode>(); + type_node->set_type<language::matrix_type>(); + + type_node->emplace_back(std::make_unique<ASTNode>()); + + { + { + std::unique_ptr dimension0_node = std::make_unique<ASTNode>(); + dimension0_node->set_type<language::integer>(); + dimension0_node->source = "3"; + auto& source0 = dimension0_node->source; + dimension0_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[0]}; + dimension0_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[source0.size()]}; + type_node->emplace_back(std::move(dimension0_node)); + } + { + std::unique_ptr dimension1_node = std::make_unique<ASTNode>(); + dimension1_node->set_type<language::integer>(); + dimension1_node->source = "3"; + auto& source1 = dimension1_node->source; + dimension1_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[0]}; + dimension1_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[source1.size()]}; + type_node->emplace_back(std::move(dimension1_node)); + } + } + + SECTION("good node") + { + REQUIRE(getMatrixDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + REQUIRE(getMatrixDataType(*type_node).nbRows() == 3); + REQUIRE(getMatrixDataType(*type_node).nbColumns() == 3); + } + + SECTION("bad node type") + { + type_node->set_type<language::integer>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.clear(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.emplace_back(std::unique_ptr<ASTNode>()); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad dimension 0 type") + { + type_node->children[1]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad dimension 1 type") + { + type_node->children[2]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad nb rows value") + { + type_node->children[1]->source = "0"; + type_node->children[2]->source = "0"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + type_node->children[2]->source = "4"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + + SECTION("none square matrices") + { + type_node->children[1]->source = "1"; + type_node->children[2]->source = "2"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "only square matrices are supported"); + } } SECTION("isNaturalConversion") { SECTION("-> B") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::bool_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::bool_t)); + REQUIRE(isNaturalConversion(bool_dt, bool_dt)); + REQUIRE(not isNaturalConversion(unsigned_int_dt, bool_dt)); + REQUIRE(not isNaturalConversion(int_dt, bool_dt)); + REQUIRE(not isNaturalConversion(double_dt, bool_dt)); + REQUIRE(not isNaturalConversion(string_dt, bool_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt), bool_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), bool_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), bool_dt)); } SECTION("-> N") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::unsigned_int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::unsigned_int_t)); + REQUIRE(isNaturalConversion(bool_dt, unsigned_int_dt)); + REQUIRE(isNaturalConversion(unsigned_int_dt, unsigned_int_dt)); + REQUIRE(isNaturalConversion(int_dt, unsigned_int_dt)); + REQUIRE(not isNaturalConversion(double_dt, unsigned_int_dt)); + REQUIRE(not isNaturalConversion(string_dt, unsigned_int_dt)); + REQUIRE( + not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt), unsigned_int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), unsigned_int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), unsigned_int_dt)); } SECTION("-> Z") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::int_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::int_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::int_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::int_t)); + REQUIRE(isNaturalConversion(bool_dt, int_dt)); + REQUIRE(isNaturalConversion(unsigned_int_dt, int_dt)); + REQUIRE(isNaturalConversion(int_dt, int_dt)); + REQUIRE(not isNaturalConversion(double_dt, int_dt)); + REQUIRE(not isNaturalConversion(string_dt, int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt), int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), int_dt)); } SECTION("-> R") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::double_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::double_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::double_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::double_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::double_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::double_t)); - REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::double_t)); + REQUIRE(isNaturalConversion(bool_dt, double_dt)); + REQUIRE(isNaturalConversion(unsigned_int_dt, double_dt)); + REQUIRE(isNaturalConversion(int_dt, double_dt)); + REQUIRE(isNaturalConversion(double_dt, double_dt)); + REQUIRE(not isNaturalConversion(string_dt, double_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), double_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), double_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), double_dt)); } SECTION("-> string") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::string_t)); - REQUIRE(isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::string_t)); + REQUIRE(isNaturalConversion(bool_dt, string_dt)); + REQUIRE(isNaturalConversion(unsigned_int_dt, string_dt)); + REQUIRE(isNaturalConversion(int_dt, string_dt)); + REQUIRE(isNaturalConversion(double_dt, string_dt)); + REQUIRE(isNaturalConversion(string_dt, string_dt)); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt), string_dt)); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), string_dt)); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), string_dt)); } SECTION("-> tuple") { - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}})); - REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}})); - REQUIRE( - not isNaturalConversion(ASTNodeDataType::int_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}})); - - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}})); - REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}})); - - REQUIRE( - isNaturalConversion(ASTNodeDataType::bool_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}})); - REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::double_t}})); - REQUIRE( - isNaturalConversion(ASTNodeDataType::unsigned_int_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}})); - REQUIRE( - isNaturalConversion(ASTNodeDataType::double_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}})); - REQUIRE( - not isNaturalConversion(ASTNodeDataType::string_t, - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}})); + REQUIRE(isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt))); + REQUIRE(isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt))); + + REQUIRE(isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt))); + REQUIRE(isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt))); + + REQUIRE(isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt))); + REQUIRE(isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt))); + REQUIRE(isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt))); + REQUIRE(isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt))); } SECTION("-> vector") { - REQUIRE(not isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1})); - - REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1}, - ASTNodeDataType{ASTNodeDataType::vector_t, 1})); - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2}, - ASTNodeDataType{ASTNodeDataType::vector_t, 1})); - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3}, - ASTNodeDataType{ASTNodeDataType::vector_t, 1})); - - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1}, - ASTNodeDataType{ASTNodeDataType::vector_t, 2})); - REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2}, - ASTNodeDataType{ASTNodeDataType::vector_t, 2})); - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3}, - ASTNodeDataType{ASTNodeDataType::vector_t, 2})); - - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1}, - ASTNodeDataType{ASTNodeDataType::vector_t, 3})); - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2}, - ASTNodeDataType{ASTNodeDataType::vector_t, 3})); - REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3}, - ASTNodeDataType{ASTNodeDataType::vector_t, 3})); + REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(4))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(9))); + + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + } + + SECTION("-> matrix") + { + REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(4), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(9), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); } SECTION("-> type_id") { - REQUIRE(not isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE( - not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - - REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}, - ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"})); - - REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}, - ASTNodeDataType{ASTNodeDataType::type_id_t, "bar"})); + REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt), + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"), + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"), + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("bar"))); } } + +#ifndef NDEBUG + SECTION("errors") + { + REQUIRE_THROWS_AS(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(undefined_dt), AssertError); + } +#endif // NDEBUG } diff --git a/tests/test_ASTNodeDataTypeBuilder.cpp b/tests/test_ASTNodeDataTypeBuilder.cpp index c5c691a28db39749f1910521b875eca0af0e8271..369374986fa27f69b12e147434ab3605872cac16 100644 --- a/tests/test_ASTNodeDataTypeBuilder.cpp +++ b/tests/test_ASTNodeDataTypeBuilder.cpp @@ -29,8 +29,8 @@ } template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = {ASTNodeDataType::type_id_t, - "builtin_t"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t"); const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const double>>; #define CHECK_AST_WITH_BUILTIN(data, expected_output) \ @@ -47,7 +47,7 @@ const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const dou throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); \ } \ \ - i_symbol->attributes().setDataType(ASTNodeDataType::type_name_id_t); \ + i_symbol->attributes().setDataType(ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>()); \ i_symbol->attributes().setIsInitialized(); \ i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); \ symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); \ @@ -277,17 +277,17 @@ let (x,b,n,s) : R*B*N*string; std::string_view result = R"( (root:void) - `-(language::var_declaration:typename) - +-(language::name_list:list) + `-(language::var_declaration:void) + +-(language::name_list:list(R*B*N*string)) | +-(language::name:x:R) | +-(language::name:b:B) | +-(language::name:n:N) | `-(language::name:s:string) - `-(language::type_expression:typename) - +-(language::R_set:typename) - +-(language::B_set:typename) - +-(language::N_set:typename) - `-(language::string_type:typename) + `-(language::type_expression:typename(list(typename(R)*typename(B)*typename(N)*typename(string)))) + +-(language::R_set:typename(R)) + +-(language::B_set:typename(B)) + +-(language::N_set:typename(N)) + `-(language::string_type:typename(string)) )"; CHECK_AST(data, result); @@ -305,7 +305,46 @@ let x : R; x[2]; auto ast = ASTBuilder::build(input); ASTSymbolTableBuilder{*ast}; - REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid types 'R[Z]' for array subscript"); + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid subscript expression: R cannot be indexed"); + } + + SECTION("invalid R^d subscript index list") + { + std::string_view data = R"( +let x : R^2; x[2,2]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2 requires a single integer"); + } + + SECTION("invalid R^dxd subscript index list 1") + { + std::string_view data = R"( +let x : R^2x2; x[2]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2x2 requires two integers"); + } + + SECTION("invalid R^dxd subscript index list 2") + { + std::string_view data = R"( +let x : R^2x2; x[2,3,1]; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "invalid index type: R^2x2 requires two integers"); } SECTION("too many variables") @@ -377,12 +416,12 @@ let t : (B), t = (true, false); std::string_view result = R"( (root:void) - `-(language::var_declaration:tuple(B)) + `-(language::var_declaration:void) +-(language::name:t:tuple(B)) - +-(language::tuple_type_specifier:tuple(B)) - | `-(language::B_set:typename) + +-(language::tuple_type_specifier:typename(tuple(B))) + | `-(language::B_set:typename(B)) +-(language::name:t:tuple(B)) - `-(language::expression_list:list) + `-(language::expression_list:list(B*B)) +-(language::true_kw:B) `-(language::false_kw:B) )"; @@ -398,12 +437,12 @@ let t : (N), t = (1, 2, 3, 5); std::string_view result = R"( (root:void) - `-(language::var_declaration:tuple(N)) + `-(language::var_declaration:void) +-(language::name:t:tuple(N)) - +-(language::tuple_type_specifier:tuple(N)) - | `-(language::N_set:typename) + +-(language::tuple_type_specifier:typename(tuple(N))) + | `-(language::N_set:typename(N)) +-(language::name:t:tuple(N)) - `-(language::expression_list:list) + `-(language::expression_list:list(Z*Z*Z*Z)) +-(language::integer:1:Z) +-(language::integer:2:Z) +-(language::integer:3:Z) @@ -422,17 +461,17 @@ let t : (Z), t = (2, n, true); std::string_view result = R"( (root:void) - +-(language::var_declaration:N) + +-(language::var_declaration:void) | +-(language::name:n:N) - | +-(language::N_set:typename) + | +-(language::N_set:typename(N)) | +-(language::name:n:N) | `-(language::integer:3:Z) - `-(language::var_declaration:tuple(Z)) + `-(language::var_declaration:void) +-(language::name:t:tuple(Z)) - +-(language::tuple_type_specifier:tuple(Z)) - | `-(language::Z_set:typename) + +-(language::tuple_type_specifier:typename(tuple(Z))) + | `-(language::Z_set:typename(Z)) +-(language::name:t:tuple(Z)) - `-(language::expression_list:list) + `-(language::expression_list:list(Z*N*B)) +-(language::integer:2:Z) +-(language::name:n:N) `-(language::true_kw:B) @@ -449,12 +488,12 @@ let t : (R), t = (2, 3.1, 5); std::string_view result = R"( (root:void) - `-(language::var_declaration:tuple(R)) + `-(language::var_declaration:void) +-(language::name:t:tuple(R)) - +-(language::tuple_type_specifier:tuple(R)) - | `-(language::R_set:typename) + +-(language::tuple_type_specifier:typename(tuple(R))) + | `-(language::R_set:typename(R)) +-(language::name:t:tuple(R)) - `-(language::expression_list:list) + `-(language::expression_list:list(Z*R*Z)) +-(language::integer:2:Z) +-(language::real:3.1:R) `-(language::integer:5:Z) @@ -473,36 +512,91 @@ let t2 : (R^3), t2 = (0, 0); std::string_view result = R"( (root:void) - +-(language::var_declaration:R^2) + +-(language::var_declaration:void) | +-(language::name:a:R^2) - | +-(language::vector_type:typename) - | | +-(language::R_set:typename) + | +-(language::vector_type:typename(R^2)) + | | +-(language::R_set:typename(R)) | | `-(language::integer:2:Z) | +-(language::name:a:R^2) - | `-(language::expression_list:list) + | `-(language::expression_list:list(Z*R)) | +-(language::integer:2:Z) | `-(language::real:3.1:R) - +-(language::var_declaration:tuple(R^2)) + +-(language::var_declaration:void) | +-(language::name:t1:tuple(R^2)) - | +-(language::tuple_type_specifier:tuple(R^2)) - | | `-(language::vector_type:typename) - | | +-(language::R_set:typename) + | +-(language::tuple_type_specifier:typename(tuple(R^2))) + | | `-(language::vector_type:typename(R^2)) + | | +-(language::R_set:typename(R)) | | `-(language::integer:2:Z) | +-(language::name:t1:tuple(R^2)) - | `-(language::expression_list:list) + | `-(language::expression_list:list(R^2*list(Z*Z)*Z)) | +-(language::name:a:R^2) - | +-(language::tuple_expression:list) + | +-(language::tuple_expression:list(Z*Z)) | | +-(language::integer:1:Z) | | `-(language::integer:2:Z) | `-(language::integer:0:Z) - `-(language::var_declaration:tuple(R^3)) + `-(language::var_declaration:void) +-(language::name:t2:tuple(R^3)) - +-(language::tuple_type_specifier:tuple(R^3)) - | `-(language::vector_type:typename) - | +-(language::R_set:typename) + +-(language::tuple_type_specifier:typename(tuple(R^3))) + | `-(language::vector_type:typename(R^3)) + | +-(language::R_set:typename(R)) | `-(language::integer:3:Z) +-(language::name:t2:tuple(R^3)) - `-(language::expression_list:list) + `-(language::expression_list:list(Z*Z)) + +-(language::integer:0:Z) + `-(language::integer:0:Z) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^dxd tuples") + { + std::string_view data = R"( +let a : R^2x2, a = (2, 3.1, -1.2, 4); +let t1 : (R^2x2), t1 = (a, (1,2,1,3), 0); +let t2 : (R^3x3), t2 = (0, 0); +)"; + + std::string_view result = R"( +(root:void) + +-(language::var_declaration:void) + | +-(language::name:a:R^2x2) + | +-(language::matrix_type:typename(R^2x2)) + | | +-(language::R_set:typename(R)) + | | +-(language::integer:2:Z) + | | `-(language::integer:2:Z) + | +-(language::name:a:R^2x2) + | `-(language::expression_list:list(Z*R*R*Z)) + | +-(language::integer:2:Z) + | +-(language::real:3.1:R) + | +-(language::unary_minus:R) + | | `-(language::real:1.2:R) + | `-(language::integer:4:Z) + +-(language::var_declaration:void) + | +-(language::name:t1:tuple(R^2x2)) + | +-(language::tuple_type_specifier:typename(tuple(R^2x2))) + | | `-(language::matrix_type:typename(R^2x2)) + | | +-(language::R_set:typename(R)) + | | +-(language::integer:2:Z) + | | `-(language::integer:2:Z) + | +-(language::name:t1:tuple(R^2x2)) + | `-(language::expression_list:list(R^2x2*list(Z*Z*Z*Z)*Z)) + | +-(language::name:a:R^2x2) + | +-(language::tuple_expression:list(Z*Z*Z*Z)) + | | +-(language::integer:1:Z) + | | +-(language::integer:2:Z) + | | +-(language::integer:1:Z) + | | `-(language::integer:3:Z) + | `-(language::integer:0:Z) + `-(language::var_declaration:void) + +-(language::name:t2:tuple(R^3x3)) + +-(language::tuple_type_specifier:typename(tuple(R^3x3))) + | `-(language::matrix_type:typename(R^3x3)) + | +-(language::R_set:typename(R)) + | +-(language::integer:3:Z) + | `-(language::integer:3:Z) + +-(language::name:t2:tuple(R^3x3)) + `-(language::expression_list:list(Z*Z)) +-(language::integer:0:Z) `-(language::integer:0:Z) )"; @@ -518,12 +612,12 @@ let t : (string), t = ("foo", "bar"); std::string_view result = R"( (root:void) - `-(language::var_declaration:tuple(string)) + `-(language::var_declaration:void) +-(language::name:t:tuple(string)) - +-(language::tuple_type_specifier:tuple(string)) - | `-(language::string_type:typename) + +-(language::tuple_type_specifier:typename(tuple(string))) + | `-(language::string_type:typename(string)) +-(language::name:t:tuple(string)) - `-(language::expression_list:list) + `-(language::expression_list:list(string*string)) +-(language::literal:"foo":string) `-(language::literal:"bar":string) )"; @@ -540,12 +634,12 @@ let t : (builtin_t), t= (1,2,3); std::string_view result = R"( (root:void) - `-(language::var_declaration:tuple(builtin_t)) + `-(language::var_declaration:void) +-(language::name:t:tuple(builtin_t)) - +-(language::tuple_type_specifier:tuple(builtin_t)) - | `-(language::type_name_id:typename) + +-(language::tuple_type_specifier:typename(tuple(builtin_t))) + | `-(language::type_name_id:builtin_t) +-(language::name:t:tuple(builtin_t)) - `-(language::expression_list:list) + `-(language::expression_list:list(Z*Z*Z)) +-(language::integer:1:Z) +-(language::integer:2:Z) `-(language::integer:3:Z) @@ -647,6 +741,84 @@ let square : R -> R^2, x -> (x, x*x); } } + SECTION("R^dxd-functions") + { + SECTION("matrix function") + { + std::string_view data = R"( +let double : R^2x2 -> R^2x2, x -> 2*x; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:double:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("matrix vector product") + { + std::string_view data = R"( +let prod : R^2x2*R^2 -> R^2, (A,x) -> A*x; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:prod:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("matrix function") + { + std::string_view data = R"( +let det : R^2x2 -> R, x -> x[0,0]*x[1,1]-x[1,0]*x[0,1]; +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:det:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("R-list -> R^dxd") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x, x*x, 2-x, 0); +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:f:function) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^d*R^d -> R^dxd") + { + std::string_view data = R"( +let f : R^2*R^2 -> R^2x2, (x,y) -> (x[0], y[0], x[1], y[1]); +)"; + + std::string_view result = R"( +(root:void) + `-(language::fct_declaration:void) + `-(language::name:f:function) +)"; + + CHECK_AST(data, result); + } + } + SECTION("R-functions") { SECTION("multiple variable") @@ -866,6 +1038,19 @@ let f : R -> R*R, x -> x*x*x; REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, "number of image spaces (2) R*R differs from number of expressions (1) x*x*x"); } + + SECTION("wrong image size 3") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x, 2*x, 2); +)"; + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, + "expecting 4 scalar expressions or an R^2x2, found 3 scalar expressions"); + } } } @@ -885,16 +1070,16 @@ x = f(x); (root:void) +-(language::fct_declaration:void) | `-(language::name:f:function) - +-(language::var_declaration:R^2) + +-(language::var_declaration:void) | +-(language::name:x:R^2) - | +-(language::vector_type:typename) - | | +-(language::R_set:typename) + | +-(language::vector_type:typename(R^2)) + | | +-(language::R_set:typename(R)) | | `-(language::integer:2:Z) | +-(language::name:x:R^2) - | `-(language::expression_list:list) + | `-(language::expression_list:list(Z*Z)) | +-(language::integer:1:Z) | `-(language::integer:2:Z) - `-(language::eq_op:R^2) + `-(language::eq_op:void) +-(language::name:x:R^2) `-(language::function_evaluation:R^2) +-(language::name:f:function) @@ -918,9 +1103,9 @@ let x : R, x = incr(3); (root:void) +-(language::fct_declaration:void) | `-(language::name:incr:function) - `-(language::var_declaration:R) + `-(language::var_declaration:void) +-(language::name:x:R) - +-(language::R_set:typename) + +-(language::R_set:typename(R)) +-(language::name:x:R) `-(language::function_evaluation:R) +-(language::name:incr:function) @@ -941,13 +1126,13 @@ let diff : R, diff = substract(3,2); (root:void) +-(language::fct_declaration:void) | `-(language::name:substract:function) - `-(language::var_declaration:R) + `-(language::var_declaration:void) +-(language::name:diff:R) - +-(language::R_set:typename) + +-(language::R_set:typename(R)) +-(language::name:diff:R) `-(language::function_evaluation:R) +-(language::name:substract:function) - `-(language::function_argument_list:list) + `-(language::function_argument_list:list(Z*Z)) +-(language::integer:3:Z) `-(language::integer:2:Z) )"; @@ -967,9 +1152,9 @@ let z : Z, z = incr(3); (root:void) +-(language::fct_declaration:void) | `-(language::name:incr:function) - `-(language::var_declaration:Z) + `-(language::var_declaration:void) +-(language::name:z:Z) - +-(language::Z_set:typename) + +-(language::Z_set:typename(Z)) +-(language::name:z:Z) `-(language::function_evaluation:Z) +-(language::name:incr:function) @@ -990,9 +1175,9 @@ let n : N, n = double(3); (root:void) +-(language::fct_declaration:void) | `-(language::name:double:function) - `-(language::var_declaration:N) + `-(language::var_declaration:void) +-(language::name:n:N) - +-(language::N_set:typename) + +-(language::N_set:typename(N)) +-(language::name:n:N) `-(language::function_evaluation:N) +-(language::name:double:function) @@ -1013,9 +1198,9 @@ let b : B, b = greater_than_2(3); (root:void) +-(language::fct_declaration:void) | `-(language::name:greater_than_2:function) - `-(language::var_declaration:B) + `-(language::var_declaration:void) +-(language::name:b:B) - +-(language::B_set:typename) + +-(language::B_set:typename(B)) +-(language::name:b:B) `-(language::function_evaluation:B) +-(language::name:greater_than_2:function) @@ -1036,13 +1221,13 @@ let s : string, s = cat("foo", "bar"); (root:void) +-(language::fct_declaration:void) | `-(language::name:cat:function) - `-(language::var_declaration:string) + `-(language::var_declaration:void) +-(language::name:s:string) - +-(language::string_type:typename) + +-(language::string_type:typename(string)) +-(language::name:s:string) `-(language::function_evaluation:string) +-(language::name:cat:function) - `-(language::function_argument_list:list) + `-(language::function_argument_list:list(string*string)) +-(language::literal:"foo":string) `-(language::literal:"bar":string) )"; @@ -1061,17 +1246,17 @@ let (x,x2) : R*R, (x,x2) = x_x2(3); (root:void) +-(language::fct_declaration:void) | `-(language::name:x_x2:function) - `-(language::var_declaration:typename) - +-(language::name_list:list) + `-(language::var_declaration:void) + +-(language::name_list:list(R*R)) | +-(language::name:x:R) | `-(language::name:x2:R) - +-(language::type_expression:typename) - | +-(language::R_set:typename) - | `-(language::R_set:typename) - +-(language::name_list:list) + +-(language::type_expression:typename(list(typename(R)*typename(R)))) + | +-(language::R_set:typename(R)) + | `-(language::R_set:typename(R)) + +-(language::name_list:list(R*R)) | +-(language::name:x:R) | `-(language::name:x2:R) - `-(language::function_evaluation:typename) + `-(language::function_evaluation:list(typename(R)*typename(R))) +-(language::name:x_x2:function) `-(language::integer:3:Z) )"; @@ -1128,9 +1313,9 @@ for (let i : N, i=0; i<3; ++i){ std::string_view result = R"( (root:void) `-(language::for_statement:void) - +-(language::var_declaration:N) + +-(language::var_declaration:void) | +-(language::name:i:N) - | +-(language::N_set:typename) + | +-(language::N_set:typename(N)) | +-(language::name:i:N) | `-(language::integer:0:Z) +-(language::lesser_op:B) @@ -1154,9 +1339,9 @@ let b:B; std::string_view result = R"( (root:void) - `-(language::var_declaration:B) + `-(language::var_declaration:void) +-(language::name:b:B) - `-(language::B_set:typename) + `-(language::B_set:typename(B)) )"; CHECK_AST(data, result); @@ -1170,9 +1355,9 @@ let n :N; std::string_view result = R"( (root:void) - `-(language::var_declaration:N) + `-(language::var_declaration:void) +-(language::name:n:N) - `-(language::N_set:typename) + `-(language::N_set:typename(N)) )"; CHECK_AST(data, result); @@ -1186,9 +1371,9 @@ let z:Z; std::string_view result = R"( (root:void) - `-(language::var_declaration:Z) + `-(language::var_declaration:void) +-(language::name:z:Z) - `-(language::Z_set:typename) + `-(language::Z_set:typename(Z)) )"; CHECK_AST(data, result); @@ -1202,9 +1387,9 @@ let r:R; std::string_view result = R"( (root:void) - `-(language::var_declaration:R) + `-(language::var_declaration:void) +-(language::name:r:R) - `-(language::R_set:typename) + `-(language::R_set:typename(R)) )"; CHECK_AST(data, result); @@ -1218,9 +1403,9 @@ let s: string; std::string_view result = R"( (root:void) - `-(language::var_declaration:string) + `-(language::var_declaration:void) +-(language::name:s:string) - `-(language::string_type:typename) + `-(language::string_type:typename(string)) )"; CHECK_AST(data, result); @@ -1235,9 +1420,9 @@ let t : builtin_t, t= 1; std::string_view result = R"( (root:void) - `-(language::var_declaration:builtin_t) + `-(language::var_declaration:void) +-(language::name:t:builtin_t) - +-(language::type_name_id:typename) + +-(language::type_name_id:typename(builtin_t)) +-(language::name:t:builtin_t) `-(language::integer:1:Z) )"; @@ -1282,10 +1467,10 @@ a = 1; std::string_view result = R"( (root:void) - +-(language::var_declaration:N) + +-(language::var_declaration:void) | +-(language::name:a:N) - | `-(language::N_set:typename) - `-(language::eq_op:N) + | `-(language::N_set:typename(N)) + `-(language::eq_op:void) +-(language::name:a:N) `-(language::integer:1:Z) )"; @@ -1302,12 +1487,12 @@ a *= 1.2; std::string_view result = R"( (root:void) - +-(language::var_declaration:N) + +-(language::var_declaration:void) | +-(language::name:a:N) - | +-(language::N_set:typename) + | +-(language::N_set:typename(N)) | +-(language::name:a:N) | `-(language::integer:1:Z) - `-(language::multiplyeq_op:N) + `-(language::multiplyeq_op:void) +-(language::name:a:N) `-(language::real:1.2:R) )"; @@ -1324,12 +1509,12 @@ a /= 2; std::string_view result = R"( (root:void) - +-(language::var_declaration:R) + +-(language::var_declaration:void) | +-(language::name:a:R) - | +-(language::R_set:typename) + | +-(language::R_set:typename(R)) | +-(language::name:a:R) | `-(language::integer:3:Z) - `-(language::divideeq_op:R) + `-(language::divideeq_op:void) +-(language::name:a:R) `-(language::integer:2:Z) )"; @@ -1346,12 +1531,12 @@ a += 2; std::string_view result = R"( (root:void) - +-(language::var_declaration:Z) + +-(language::var_declaration:void) | +-(language::name:a:Z) - | +-(language::Z_set:typename) + | +-(language::Z_set:typename(Z)) | +-(language::name:a:Z) | `-(language::integer:3:Z) - `-(language::pluseq_op:Z) + `-(language::pluseq_op:void) +-(language::name:a:Z) `-(language::integer:2:Z) )"; @@ -1368,12 +1553,12 @@ a -= 2; std::string_view result = R"( (root:void) - +-(language::var_declaration:Z) + +-(language::var_declaration:void) | +-(language::name:a:Z) - | +-(language::Z_set:typename) + | +-(language::Z_set:typename(Z)) | +-(language::name:a:Z) | `-(language::integer:1:Z) - `-(language::minuseq_op:Z) + `-(language::minuseq_op:void) +-(language::name:a:Z) `-(language::integer:2:Z) )"; @@ -1408,18 +1593,18 @@ for (let i:Z, i=0; i<3; i += 1) { i += 2; } std::string_view result = R"( (root:void) `-(language::for_statement:void) - +-(language::var_declaration:Z) + +-(language::var_declaration:void) | +-(language::name:i:Z) - | +-(language::Z_set:typename) + | +-(language::Z_set:typename(Z)) | +-(language::name:i:Z) | `-(language::integer:0:Z) +-(language::lesser_op:B) | +-(language::name:i:Z) | `-(language::integer:3:Z) - +-(language::pluseq_op:Z) + +-(language::pluseq_op:void) | +-(language::name:i:Z) | `-(language::integer:1:Z) - `-(language::pluseq_op:Z) + `-(language::pluseq_op:void) +-(language::name:i:Z) `-(language::integer:2:Z) )"; diff --git a/tests/test_ASTNodeDataTypeChecker.cpp b/tests/test_ASTNodeDataTypeChecker.cpp index e016e1087281b3175eba0076cb21aaab7234c89b..60dfc787a65b6af66f3463fe0871bb0b0a79116d 100644 --- a/tests/test_ASTNodeDataTypeChecker.cpp +++ b/tests/test_ASTNodeDataTypeChecker.cpp @@ -45,7 +45,7 @@ for(let i:Z, i=0; i<10; ++i) { ASTSymbolTableBuilder{*ast}; ASTNodeDataTypeBuilder{*ast}; - ast->children[0]->m_data_type = ASTNodeDataType::undefined_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::undefined_t>(); REQUIRE_THROWS_AS(ASTNodeDataTypeChecker{*ast}, ParseError); } diff --git a/tests/test_ASTNodeExpressionBuilder.cpp b/tests/test_ASTNodeExpressionBuilder.cpp index 956b5a2527e8e80bb1cef2287f9b2c72a16768aa..3b09a0bb554e6692fe1f741229e9ea43725782af 100644 --- a/tests/test_ASTNodeExpressionBuilder.cpp +++ b/tests/test_ASTNodeExpressionBuilder.cpp @@ -373,7 +373,10 @@ let n:N; SECTION("unary not") { - CHECK_AST_THROWS_WITH(R"(not 1;)", "invalid implicit conversion: Z -> B"); + std::string error_message = R"(undefined unary operator +note: unexpected operand type Z)"; + + CHECK_AST_THROWS_WITH(R"(not 1;)", error_message); } SECTION("pre-increment operator") @@ -486,7 +489,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long>) + `-(language::multiply_op:BinaryExpressionProcessor<language::multiply_op, long, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -502,7 +505,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long>) + `-(language::divide_op:BinaryExpressionProcessor<language::divide_op, long, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -518,7 +521,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long>) + `-(language::plus_op:BinaryExpressionProcessor<language::plus_op, long, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -534,7 +537,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long>) + `-(language::minus_op:BinaryExpressionProcessor<language::minus_op, long, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -544,17 +547,26 @@ x[2]; SECTION("or") { - CHECK_AST_THROWS_WITH(R"(1 or 2;)", "invalid implicit conversion: Z -> B"); + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and Z)"; + + CHECK_AST_THROWS_WITH(R"(1 or 2;)", error_message); } SECTION("and") { - CHECK_AST_THROWS_WITH(R"(1 and 2;)", "invalid implicit conversion: Z -> B"); + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and Z)"; + + CHECK_AST_THROWS_WITH(R"(1 and 2;)", error_message); } SECTION("xor") { - CHECK_AST_THROWS_WITH(R"(1 xor 2;)", "invalid implicit conversion: Z -> B"); + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and Z)"; + + CHECK_AST_THROWS_WITH(R"(1 xor 2;)", error_message); } SECTION("lesser") @@ -565,7 +577,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, long, long>) + `-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -581,7 +593,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, long, long>) + `-(language::lesser_or_eq_op:BinaryExpressionProcessor<language::lesser_or_eq_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -597,7 +609,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, long, long>) + `-(language::greater_op:BinaryExpressionProcessor<language::greater_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -613,7 +625,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, long, long>) + `-(language::greater_or_eq_op:BinaryExpressionProcessor<language::greater_or_eq_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -629,7 +641,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, long, long>) + `-(language::eqeq_op:BinaryExpressionProcessor<language::eqeq_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -645,7 +657,7 @@ x[2]; std::string result = R"( (root:ASTNodeListProcessor) - `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, long, long>) + `-(language::not_eq_op:BinaryExpressionProcessor<language::not_eq_op, bool, long, long>) +-(language::integer:1:ValueProcessor) `-(language::integer:2:ValueProcessor) )"; @@ -763,7 +775,7 @@ for(let i:N, i=0; i<10; ++i); +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>) | +-(language::name:i:NameProcessor) | `-(language::integer:0:ValueProcessor) - +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, unsigned long, long>) + +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, unsigned long, long>) | +-(language::name:i:NameProcessor) | `-(language::integer:10:ValueProcessor) +-(language::unary_plusplus:IncDecExpressionProcessor<language::unary_plusplus, unsigned long>) @@ -788,7 +800,7 @@ for(; i<10; ++i); | `-(language::integer:0:ValueProcessor) `-(language::for_statement:ForProcessor) +-(language::for_init:FakeProcessor) - +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, unsigned long, long>) + +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, unsigned long, long>) | +-(language::name:i:NameProcessor) | `-(language::integer:10:ValueProcessor) +-(language::unary_plusplus:IncDecExpressionProcessor<language::unary_plusplus, unsigned long>) @@ -832,7 +844,7 @@ for(let i:N, i=0; i<10;); +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>) | +-(language::name:i:NameProcessor) | `-(language::integer:0:ValueProcessor) - +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, unsigned long, long>) + +-(language::lesser_op:BinaryExpressionProcessor<language::lesser_op, bool, unsigned long, long>) | +-(language::name:i:NameProcessor) | `-(language::integer:10:ValueProcessor) +-(language::for_post:FakeProcessor) diff --git a/tests/test_ASTNodeFunctionExpressionBuilder.cpp b/tests/test_ASTNodeFunctionExpressionBuilder.cpp index 3ce9d041604f5a10cbe9ea97f063c3d24a1fd04d..fd0163b7aed4f94bf99bee19eec4ae46ed9b7f3b 100644 --- a/tests/test_ASTNodeFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeFunctionExpressionBuilder.cpp @@ -53,10 +53,25 @@ \ ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ ASTNodeTypeCleaner<language::fct_declaration>{*ast}; \ - REQUIRE_THROWS_AS(ASTNodeExpressionBuilder{*ast}, ParseError); \ + REQUIRE_THROWS_AS(ASTNodeExpressionBuilder{*ast}, ParseError); \ } -#define CHECK_AST_THROWS_WITH(data, error) \ +#define CHECK_TYPE_BUILDER_THROWS_WITH(data, error) \ + { \ + static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ + static_assert(std::is_same_v<std::decay_t<decltype(error)>, std::string>); \ + \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + ASTModulesImporter{*ast}; \ + ASTNodeTypeCleaner<language::import_instruction>{*ast}; \ + \ + ASTSymbolTableBuilder{*ast}; \ + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, error); \ + } + +#define CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, error) \ { \ static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \ static_assert(std::is_same_v<std::decay_t<decltype(error)>, std::string>); \ @@ -382,6 +397,60 @@ f(x); CHECK_AST(data, result); } + SECTION("Return R^1x1 -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x+x; +let x : R^1x1, x = 1; +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^2x2 -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x+x; +let x : R^2x2, x = (1,2,3,4); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^3x3 -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x+x; +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return scalar -> R^1") { std::string_view data = R"( @@ -438,6 +507,73 @@ f(1,2,3); CHECK_AST(data, result); } + SECTION("Return scalar -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> x+1; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^2x2") + { + std::string_view data = R"( +let f : R*R*R*R -> R^2x2, (x,y,z,t) -> (x,y,z,t); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 2ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^3x3") + { + std::string_view data = R"( +let f : R^3*R^3*R^3 -> R^3x3, (x,y,z) -> (x[0],x[1],x[2],y[0],y[1],y[2],z[0],z[1],z[2]); +f((1,2,3),(4,5,6),(7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 3ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | `-(language::integer:3:ValueProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | `-(language::integer:6:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return '0' -> R^1") { std::string_view data = R"( @@ -489,6 +625,57 @@ f(1); CHECK_AST(data, result); } + SECTION("Return '0' -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<1ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^2x2") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<2ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^3x3") + { + std::string_view data = R"( +let f : R -> R^3x3, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<3ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound") { std::string_view data = R"( @@ -510,6 +697,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, (x), (x,y,z,t), (x,y,z, x,x,x, t,t,t)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound with '0'") { std::string_view data = R"( @@ -531,6 +739,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound with '0'") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, 0, 0, (x, y, z, t, x, y, z, t, x)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments '0' -> R^1") { std::string_view data = R"( @@ -582,6 +811,57 @@ f(0); CHECK_AST(data, result); } + SECTION("Arguments '0' -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments tuple -> R^d") { std::string_view data = R"( @@ -602,11 +882,37 @@ f((1,2,3)); CHECK_AST(data, result); } + SECTION("Arguments tuple -> R^dxd") + { + std::string_view data = R"( +let f: R^3x3 -> R, x -> x[0,0]+x[0,1]+x[0,2]; +f((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments compound with tuple") { std::string_view data = R"( -let f: R*R^3*R^2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0]+y[1]; -f(2,(1,2,3),(2,1.3)); +let f: R*R^3*R^2x2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0,0]+y[1,1]; +f(2,(1,2,3),(2,3,-1,1.3)); )"; std::string_view result = R"( @@ -621,6 +927,9 @@ f(2,(1,2,3),(2,1.3)); | `-(language::integer:3:ValueProcessor) `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>) + | `-(language::integer:1:ValueProcessor) `-(language::real:1.3:ValueProcessor) )"; @@ -656,60 +965,54 @@ sum(2); { std::string_view data = R"( let bad_conv : string -> R, s -> s; -bad_conv(2); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: string -> R"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: string -> R"}); } SECTION("R -> B") { std::string_view data = R"( let bad_B : R -> B, x -> x; -bad_B(2); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> B"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> B"}); } SECTION("R -> N") { std::string_view data = R"( let next : R -> N, x -> x; -next(6); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> N"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> N"}); } SECTION("R -> Z") { std::string_view data = R"( let prev : R -> Z, x -> x; -prev(-3); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); } SECTION("N -> B") { std::string_view data = R"( let bad_B : N -> B, n -> n; -bad_B(3); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> B"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> B"}); } SECTION("Z -> B") { std::string_view data = R"( let bad_B : Z -> B, n -> n; -bad_B(3); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> B"}); + CHECK_TYPE_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> B"}); } } @@ -723,7 +1026,7 @@ let n : N, n = 2; negate(n); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> B"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> B"}); } SECTION("Z -> B") @@ -733,7 +1036,7 @@ let negate : B -> B, b -> not b; negate(3-4); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> B"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> B"}); } SECTION("R -> B") @@ -743,7 +1046,7 @@ let negate : B -> B, b -> not b; negate(3.24); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> B"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> B"}); } SECTION("R -> N") @@ -753,7 +1056,7 @@ let next : N -> N, n -> n+1; next(3.24); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> N"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> N"}); } SECTION("R -> Z") @@ -763,7 +1066,7 @@ let prev : Z -> Z, z -> z-1; prev(3 + .24); )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); } } @@ -776,7 +1079,9 @@ let f : R^2 -> R, x->x[0]; f((1,2,3)); )"; - CHECK_AST_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 2, but provided 3"}); } SECTION("tuple[2] -> R^3") @@ -786,7 +1091,9 @@ let f : R^3 -> R, x->x[0]; f((1,2)); )"; - CHECK_AST_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 3, but provided 2"}); } SECTION("compound tuple[3] -> R^2") @@ -796,7 +1103,9 @@ let f : R*R^2 -> R, (t,x)->x[0]; f(1,(1,2,3)); )"; - CHECK_AST_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 2, but provided 3"}); } SECTION("compound tuple[2] -> R^3") @@ -806,7 +1115,9 @@ let f : R^3*R^2 -> R, (x,y)->x[0]*y[1]; f((1,2),(3,4)); )"; - CHECK_AST_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, + std::string{ + "incompatible dimensions in affectation: expecting 3, but provided 2"}); } SECTION("list instead of tuple -> R^3") @@ -816,7 +1127,7 @@ let f : R^3 -> R, x -> x[0]*x[1]; f(1,2,3); )"; - CHECK_AST_THROWS_WITH(data, std::string{"bad number of arguments: expecting 1, provided 3"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"bad number of arguments: expecting 1, provided 3"}); } SECTION("list instead of tuple -> R^3*R^2") @@ -826,7 +1137,7 @@ let f : R^3*R^2 -> R, (x,y) -> x[0]*x[1]-y[0]; f((1,2,3),2,3); )"; - CHECK_AST_THROWS_WITH(data, std::string{"bad number of arguments: expecting 2, provided 3"}); + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"bad number of arguments: expecting 2, provided 3"}); } } } diff --git a/tests/test_ASTNodeIncDecExpressionBuilder.cpp b/tests/test_ASTNodeIncDecExpressionBuilder.cpp index 60ce2489fd8b1b93996cd2904b8bf9852b8a1c79..2b64fcd03c2aeedb6ab3902cf374791205344815 100644 --- a/tests/test_ASTNodeIncDecExpressionBuilder.cpp +++ b/tests/test_ASTNodeIncDecExpressionBuilder.cpp @@ -8,6 +8,7 @@ #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> #include <language/utils/ASTPrinter.hpp> +#include <language/utils/OperatorRepository.hpp> #include <utils/Demangle.hpp> #include <pegtl/string_input.hpp> @@ -299,22 +300,23 @@ x--; SECTION("Errors") { - SECTION("Invalid operator type") + SECTION("Undefined operator") { - auto ast = std::make_unique<ASTNode>(); - REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, - "unexpected error: undefined increment/decrement operator"); + auto& operator_repository = OperatorRepository::instance(); + auto optional_value_type = operator_repository.getIncDecOperatorValueType("string ++"); + REQUIRE(not optional_value_type.has_value()); } SECTION("Invalid operand type") { auto ast = std::make_unique<ASTNode>(); ast->set_type<language::unary_plusplus>(); - ast->m_data_type = ASTNodeDataType::undefined_t; + ast->m_data_type = ASTNodeDataType::build<ASTNodeDataType::undefined_t>(); ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, "invalid operand type for unary operator"); + REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, + "invalid operand type. ++/-- operators only apply to variables"); } SECTION("Invalid data type") @@ -325,8 +327,7 @@ x--; ast->children.emplace_back(std::make_unique<ASTNode>()); ast->children[0]->set_type<language::name>(); - REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, - "unexpected error: undefined data type for unary operator"); + REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, "undefined affectation type: ++ undefined"); } SECTION("Not allowed chained ++/--") @@ -337,7 +338,7 @@ x--; 1 ++ ++; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -348,7 +349,7 @@ x--; 1 ++ --; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -359,7 +360,7 @@ x--; 1 -- ++; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -370,7 +371,7 @@ x--; 1 -- --; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -381,7 +382,7 @@ x--; ++ ++ 1; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -392,7 +393,7 @@ x--; ++ -- 1; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -403,7 +404,7 @@ x--; -- ++ 1; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -414,7 +415,7 @@ x--; -- -- 1; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -425,7 +426,7 @@ x--; ++ 1 ++; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -436,7 +437,7 @@ x--; ++ 1 --; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -447,7 +448,7 @@ x--; -- 1 ++; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } @@ -458,7 +459,7 @@ x--; -- 1 --; )"; - std::string error_message = R"(chaining ++ or -- operators is not allowed)"; + std::string error_message = R"(invalid operand type. ++/-- operators only apply to variables)"; DISALLOWED_CHAINED_AST(data, error_message) } diff --git a/tests/test_ASTNodeJumpPlacementChecker.cpp b/tests/test_ASTNodeJumpPlacementChecker.cpp index 631b172775f381405bbc8ff830d19aad92b9c8e5..040c2db1c3d7b0ee3be8c663a5c701827e230fb2 100644 --- a/tests/test_ASTNodeJumpPlacementChecker.cpp +++ b/tests/test_ASTNodeJumpPlacementChecker.cpp @@ -75,7 +75,7 @@ do { ASTSymbolTableBuilder{*ast}; ASTNodeDataTypeBuilder{*ast}; - ast->children[0]->m_data_type = ASTNodeDataType::undefined_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::undefined_t>(); REQUIRE_THROWS_AS(ASTNodeJumpPlacementChecker{*ast}, ParseError); } @@ -144,7 +144,7 @@ do { ASTSymbolTableBuilder{*ast}; ASTNodeDataTypeBuilder{*ast}; - ast->children[0]->m_data_type = ASTNodeDataType::undefined_t; + ast->children[0]->m_data_type = ASTNodeDataType::build<ASTNodeDataType::undefined_t>(); REQUIRE_THROWS_AS(ASTNodeJumpPlacementChecker{*ast}, ParseError); } diff --git a/tests/test_ASTNodeListAffectationExpressionBuilder.cpp b/tests/test_ASTNodeListAffectationExpressionBuilder.cpp index 333b98bd998124e676c885025ef1d646ba517829..62b584eaa1f9e0bac8b05ae8f8d1072cf4e46bae 100644 --- a/tests/test_ASTNodeListAffectationExpressionBuilder.cpp +++ b/tests/test_ASTNodeListAffectationExpressionBuilder.cpp @@ -199,6 +199,56 @@ let (x1,x2,x3,x) : R^1*R^2*R^3*R, CHECK_AST(data, result); } + SECTION("without conversion R^1x1*R^2x2*R^3x3*R") + { + std::string_view data = R"( +let a:R^1x1, a = 0; +let b:R^2x2, b = (1, 2, 3, 4); +let c:R^3x3, c = (9, 8, 7, 6, 5, 4, 3, 2, 1); +let (x1,x2,x3,x) : R^1x1*R^2x2*R^3x3*R, + (x1,x2,x3,x) = (a, b, c, 2); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + +-(language::eq_op:AffectationProcessor<language::eq_op, TinyMatrix<1ul, double>, long>) + | +-(language::name:a:NameProcessor) + | `-(language::integer:0:ValueProcessor) + +-(language::eq_op:AffectationToTinyMatrixFromListProcessor<language::eq_op, TinyMatrix<2ul, double> >) + | +-(language::name:b:NameProcessor) + | `-(language::expression_list:ASTNodeExpressionListProcessor) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | `-(language::integer:4:ValueProcessor) + +-(language::eq_op:AffectationToTinyMatrixFromListProcessor<language::eq_op, TinyMatrix<3ul, double> >) + | +-(language::name:c:NameProcessor) + | `-(language::expression_list:ASTNodeExpressionListProcessor) + | +-(language::integer:9:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | `-(language::integer:1:ValueProcessor) + `-(language::eq_op:ListAffectationProcessor<language::eq_op>) + +-(language::name_list:FakeProcessor) + | +-(language::name:x1:NameProcessor) + | +-(language::name:x2:NameProcessor) + | +-(language::name:x3:NameProcessor) + | `-(language::name:x:NameProcessor) + `-(language::expression_list:ASTNodeExpressionListProcessor) + +-(language::name:a:NameProcessor) + +-(language::name:b:NameProcessor) + +-(language::name:c:NameProcessor) + `-(language::integer:2:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Zero initialization") { std::string_view data = R"( @@ -223,6 +273,30 @@ let (x1,x2,x3,x) : R^1*R^2*R^3*R, (x1,x2,x3,x) = (0, 0, 0, 0); CHECK_AST(data, result); } + SECTION("Zero initialization") + { + std::string_view data = R"( +let (x1,x2,x3,x) : R^1x1*R^2x2*R^3x3*R, (x1,x2,x3,x) = (0, 0, 0, 0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::eq_op:ListAffectationProcessor<language::eq_op>) + +-(language::name_list:FakeProcessor) + | +-(language::name:x1:NameProcessor) + | +-(language::name:x2:NameProcessor) + | +-(language::name:x3:NameProcessor) + | `-(language::name:x:NameProcessor) + `-(language::expression_list:ASTNodeExpressionListProcessor) + +-(language::integer:0:ValueProcessor) + +-(language::integer:0:ValueProcessor) + +-(language::integer:0:ValueProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("from function") { std::string_view data = R"( @@ -385,7 +459,7 @@ let x:R^2, x = (1,2); let y:R^3, y = x; )"; - CHECK_AST_THROWS_WITH(data, std::string{"incompatible dimensions in affectation"}); + CHECK_AST_THROWS_WITH(data, std::string{"undefined affectation type: R^3 = R^2"}); } SECTION("invalid Z -> R^d conversion (non-zero)") @@ -394,7 +468,7 @@ let y:R^3, y = x; let x:R^2, x = 1; )"; - CHECK_AST_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2"}); + CHECK_AST_THROWS_WITH(data, std::string{"invalid integral value (0 is the solely valid value)"}); } } } diff --git a/tests/test_ASTNodeNaturalConversionChecker.cpp b/tests/test_ASTNodeNaturalConversionChecker.cpp index fe29cf9d0148d7a4d0a31bd71b97d84820e63299..a9e10344d3f6eaacd1ac6fde102b1d4b82214927 100644 --- a/tests/test_ASTNodeNaturalConversionChecker.cpp +++ b/tests/test_ASTNodeNaturalConversionChecker.cpp @@ -2,7 +2,7 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> -#include <language/ast/ASTNodeNaturalConversionChecker.hpp> +#include <language/utils/ASTNodeNaturalConversionChecker.hpp> namespace language { @@ -13,6 +13,22 @@ struct integer; TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { + const ASTNodeDataType undefined_dt = ASTNodeDataType{}; + const ASTNodeDataType bool_dt = ASTNodeDataType::build<ASTNodeDataType::bool_t>(); + const ASTNodeDataType unsigned_int_dt = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); + const ASTNodeDataType int_dt = ASTNodeDataType::build<ASTNodeDataType::int_t>(); + const ASTNodeDataType double_dt = ASTNodeDataType::build<ASTNodeDataType::double_t>(); + const ASTNodeDataType string_dt = ASTNodeDataType::build<ASTNodeDataType::string_t>(); + const ASTNodeDataType void_dt = ASTNodeDataType::build<ASTNodeDataType::void_t>(); + const ASTNodeDataType function_dt = ASTNodeDataType::build<ASTNodeDataType::function_t>(); + const ASTNodeDataType builtin_function_dt = ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(); + + std::vector<std::shared_ptr<const ASTNodeDataType>> type_list; + type_list.push_back(std::make_shared<const ASTNodeDataType>(double_dt)); + type_list.push_back(std::make_shared<const ASTNodeDataType>(int_dt)); + + const ASTNodeDataType list_dt = ASTNodeDataType::build<ASTNodeDataType::list_t>(type_list); + SECTION("Valid conversions") { std::unique_ptr data_node = std::make_unique<ASTNode>(); @@ -21,50 +37,189 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("string -> string") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = string_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("R^d -> string") { - data_node->m_data_type = ASTNodeDataType::vector_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(5); + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("R -> string") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = double_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("Z -> string") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("N -> string") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = unsigned_int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("B -> string") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("list -> string") { - data_node->m_data_type = ASTNodeDataType::list_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = list_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); } SECTION("tuple -> string") { - data_node->m_data_type = ASTNodeDataType::tuple_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::string_t}); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, string_dt}); + } + } + + SECTION("-> R^dxd") + { + SECTION("R^1x1 -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + + SECTION("list -> R^1x1") + { + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + + SECTION("'0' -> R^dxd") + { + data_node->m_data_type = int_dt; + data_node->set_type<language::integer>(); + data_node->source = "0"; + auto& source = data_node->source; + data_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; + data_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; + + SECTION("d = 1") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}); + } + SECTION("d = 2") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + SECTION("d = 3") + { + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); + } + } + + SECTION("R^2x2 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + + SECTION("list -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list2_node)); + + std::unique_ptr list3_node = std::make_unique<ASTNode>(); + list3_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list3_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}); + } + + SECTION("R^3x3 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); + } + + SECTION("list -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt), std::make_shared<const ASTNodeDataType>(double_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt), std::make_shared<const ASTNodeDataType>(int_dt), + std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list2_node)); + + std::unique_ptr list3_node = std::make_unique<ASTNode>(); + list3_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list3_node)); + + std::unique_ptr list4_node = std::make_unique<ASTNode>(); + list4_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list4_node)); + + std::unique_ptr list5_node = std::make_unique<ASTNode>(); + list5_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list5_node)); + + std::unique_ptr list6_node = std::make_unique<ASTNode>(); + list6_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list6_node)); + + std::unique_ptr list7_node = std::make_unique<ASTNode>(); + list7_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list7_node)); + + std::unique_ptr list8_node = std::make_unique<ASTNode>(); + list8_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list8_node)); + } + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}); } } @@ -72,24 +227,27 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("R^1 -> R^1") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}); } SECTION("list -> R^1") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); } - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}); } SECTION("'0' -> R^d") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::int_t, 1}; + data_node->m_data_type = int_dt; data_node->set_type<language::integer>(); data_node->source = "0"; auto& source = data_node->source; @@ -98,62 +256,73 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("d = 1") { - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}); } SECTION("d = 2") { - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 2}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}); } SECTION("d = 3") { - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 3}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}); } } SECTION("R^2 -> R^2") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 2}}); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}); } SECTION("list -> R^2") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt), + std::make_shared<const ASTNodeDataType>(unsigned_int_dt)}); { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); } - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 2}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}); } SECTION("R^3 -> R^3") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 3}}); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}); } SECTION("list -> R^3") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::vector_t, 3}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}); } } @@ -161,26 +330,26 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("R -> R") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}); + data_node->m_data_type = double_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, double_dt}); } SECTION("Z -> R") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}); + data_node->m_data_type = int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, double_dt}); } SECTION("N -> R") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}); + data_node->m_data_type = unsigned_int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, double_dt}); } SECTION("B -> R") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, double_dt}); } } @@ -188,20 +357,20 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("Z -> Z") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}); + data_node->m_data_type = int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, int_dt}); } SECTION("N -> Z") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}); + data_node->m_data_type = unsigned_int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, int_dt}); } SECTION("B -> Z") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, int_dt}); } } @@ -209,20 +378,20 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("Z -> N") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}); + data_node->m_data_type = int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}); } SECTION("N -> N") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}); + data_node->m_data_type = unsigned_int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}); } SECTION("B -> N") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}); } } @@ -230,8 +399,8 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("B -> B") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, bool_dt}); } } @@ -239,216 +408,204 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("B -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::bool_t; + data_node->m_data_type = bool_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::bool_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt)}); } SECTION("B -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::bool_t; - REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}}); + data_node->m_data_type = bool_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + unsigned_int_dt)}); } SECTION("N -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}}); + data_node->m_data_type = unsigned_int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + unsigned_int_dt)}); } SECTION("Z -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}}); + data_node->m_data_type = int_dt; + REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + unsigned_int_dt)}); } SECTION("B -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType::bool_t; + data_node->m_data_type = bool_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::int_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt)}); } SECTION("N -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; + data_node->m_data_type = unsigned_int_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::int_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt)}); } SECTION("Z -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType::int_t; + data_node->m_data_type = int_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::int_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt)}); } SECTION("B -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType::bool_t; + data_node->m_data_type = bool_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::double_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt)}); } SECTION("N -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; + data_node->m_data_type = unsigned_int_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::double_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt)}); } SECTION("Z -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType::int_t; + data_node->m_data_type = int_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::double_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt)}); } SECTION("R -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType::double_t; + data_node->m_data_type = double_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::double_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt)}); } SECTION("R^1 -> tuple(R^1)") { - auto R1 = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; + auto R1 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); data_node->m_data_type = R1; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R1}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R1)}); } SECTION("R^2 -> tuple(R^2)") { - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); data_node->m_data_type = R2; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R2}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2)}); } SECTION("R^3 -> tuple(R^3)") { - auto R3 = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; + auto R3 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); data_node->m_data_type = R3; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R3}}); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3)}); } SECTION("string -> tuple(string)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::string_t}; + data_node->m_data_type = string_dt; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ASTNodeDataType::string_t}}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt)}); } SECTION("type_id_t -> tuple(type_id_t)") { - auto type_id = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; + auto type_id = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); data_node->m_data_type = type_id; REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, type_id}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id)}); } SECTION("(B, B, B) -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::bool_t; + list0_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::bool_t; + list1_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::bool_t; + list2_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(B, N, Z) -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::bool_t; + list0_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(B, N, Z) -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::bool_t; + list0_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(R, N, Z) -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(R^1, R^1) -> tuple(R^1)") { - auto R1 = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto R1 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = R1; @@ -458,14 +615,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list1_node->m_data_type = R1; data_node->emplace_back(std::move(list1_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, R1}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R1); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(R^2, R^2, R^2) -> tuple(R^2)") { - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = R2; @@ -479,14 +636,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list2_node->m_data_type = R2; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, R2}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(R^3, R^3) -> tuple(R^3)") { - auto R3 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto R3 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = R3; @@ -496,14 +653,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list1_node->m_data_type = R3; data_node->emplace_back(std::move(list1_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, R3}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(string, string) -> tuple(string)") { - auto str_t = ASTNodeDataType{ASTNodeDataType::string_t}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto str_t = string_dt; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = str_t; @@ -513,14 +670,14 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list1_node->m_data_type = str_t; data_node->emplace_back(std::move(list1_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, str_t}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(str_t); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("(type_id_t, type_id_t) -> tuple(type_id_t)") { - auto type_id = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto type_id = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = type_id; @@ -530,169 +687,168 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list1_node->m_data_type = type_id; data_node->emplace_back(std::move(list1_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(B) -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(B) -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(Z) -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(N) -> tuple(N)") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(B) -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(N) -> tuple(Z)") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(Z) -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(B) -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(Z) -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(N) -> tuple(R)") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(R) -> tuple(R)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(B) -> tuple(string)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(Z) -> tuple(string)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(N) -> tuple(string)") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(R) -> tuple(string)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(string) -> tuple(string)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, tuple_t}); } SECTION("tuple(R^1) -> tuple(R^1)") { - auto R1 = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, R1}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R1}}); + auto R1 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R1); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R1)}); } SECTION("tuple(R^2) -> tuple(R^2)") { - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, R2}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R2}}); + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2)}); } SECTION("tuple(R^3) -> tuple(R^3)") { - auto R3 = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, R3}; - REQUIRE_NOTHROW(ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, R3}}); + auto R3 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3); + REQUIRE_NOTHROW( + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3)}); } SECTION("tuple(type_id_t) -> tuple(type_id_t)") { - auto type_id = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id}; + auto type_id = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id); REQUIRE_NOTHROW( - ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, type_id}}); + ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id)}); } } } @@ -701,137 +857,557 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { std::unique_ptr data_node = std::make_unique<ASTNode>(); + SECTION("-> R^dxd") + { + SECTION("R^2x2 -> R^1x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}), + "invalid implicit conversion: R^2x2 -> R^1x1"); + } + + SECTION("R^3x3 -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)}), + "invalid implicit conversion: R^3x3 -> R^1x1"); + } + + SECTION("R^1x1 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}), + "invalid implicit conversion: R^1x1 -> R^2x2"); + } + + SECTION("R^3x3 -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)}), + "invalid implicit conversion: R^3x3 -> R^2x2"); + } + + SECTION("R^1x1 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}), + "invalid implicit conversion: R^1x1 -> R^3x3"); + } + + SECTION("R^2x2 -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("list1 -> R^dxd") + { + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "incompatible dimensions in affectation: expecting 4, but provided 1"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "incompatible dimensions in affectation: expecting 9, but provided 1"); + } + } + + SECTION("list2 -> R^dxd") + { + data_node->m_data_type = list_dt; + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + } + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "incompatible dimensions in affectation: expecting 1, but provided 2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "incompatible dimensions in affectation: expecting 9, but provided 2"); + } + } + + SECTION("list3 -> R^dxd") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); + { + std::unique_ptr list0_node = std::make_unique<ASTNode>(); + list0_node->m_data_type = double_dt; + data_node->emplace_back(std::move(list0_node)); + + std::unique_ptr list1_node = std::make_unique<ASTNode>(); + list1_node->m_data_type = unsigned_int_dt; + data_node->emplace_back(std::move(list1_node)); + + std::unique_ptr list2_node = std::make_unique<ASTNode>(); + list2_node->m_data_type = int_dt; + data_node->emplace_back(std::move(list2_node)); + } + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "incompatible dimensions in affectation: expecting 1, but provided 3"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "incompatible dimensions in affectation: expecting 4, but provided 3"); + } + } + + SECTION("tuple -> R^dxd") + { + SECTION("tuple(N) -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: tuple(N) -> R^1x1"); + } + + SECTION("tuple(R) -> R^1x1") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: tuple(R) -> R^1x1"); + } + + SECTION("tuple(R) -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: tuple(R) -> R^2x2"); + } + + SECTION("tuple(B) -> R^2x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: tuple(B) -> R^2x2"); + } + + SECTION("tuple(Z) -> R^3x2") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: tuple(Z) -> R^3x3"); + } + + SECTION("tuple(R) -> R^3x3") + { + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: tuple(R) -> R^3x3"); + } + + SECTION("tuple(R^1) -> tuple(R^3x3)") + { + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^1 -> R^3x3"); + } + + SECTION("tuple(R^2) -> tuple(R^3x3)") + { + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^2 -> R^3x3"); + } + + SECTION("tuple(R^2) -> tuple(R^1x1)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + data_node->m_data_type = tuple_R2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1x1}), + "invalid implicit conversion: R^2 -> R^1x1"); + } + + SECTION("tuple(R^1x1) -> tuple(R^3x3)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^1x1 -> R^3x3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^3x3)") + { + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + auto tuple_R3x3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3x3}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^1x1)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1x1}), + "invalid implicit conversion: R^2x2 -> R^1x1"); + } + } + + SECTION("R -> R^dxd") + { + data_node->m_data_type = double_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: R -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: R -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: R -> R^3x3"); + } + } + + SECTION("Z -> R^dxd (non-zero)") + { + data_node->m_data_type = int_dt; + data_node->set_type<language::integer>(); + data_node->source = "1"; + auto& source = data_node->source; + data_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]}; + data_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]}; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: Z -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: Z -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: Z -> R^3x3"); + } + } + + SECTION("N -> R^dxd") + { + data_node->m_data_type = unsigned_int_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: N -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: N -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: N -> R^3x3"); + } + } + + SECTION("B -> R^dxd") + { + data_node->m_data_type = bool_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: B -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: B -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: B -> R^3x3"); + } + } + + SECTION("string -> R^dxd") + { + data_node->m_data_type = string_dt; + + SECTION("d=1") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, + 1)}), + "invalid implicit conversion: string -> R^1x1"); + } + + SECTION("d=2") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, + 2)}), + "invalid implicit conversion: string -> R^2x2"); + } + + SECTION("d=3") + { + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, + 3)}), + "invalid implicit conversion: string -> R^3x3"); + } + } + } + SECTION("-> R^d") { SECTION("R^2 -> R^1") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), + "invalid implicit conversion: R^2 -> R^1"); } SECTION("R^3 -> R^1") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), + "invalid implicit conversion: R^3 -> R^1"); } SECTION("R^1 -> R^2") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), + "invalid implicit conversion: R^1 -> R^2"); } SECTION("R^3 -> R^2") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), + "invalid implicit conversion: R^3 -> R^2"); } SECTION("R^1 -> R^3") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), + "invalid implicit conversion: R^1 -> R^3"); } SECTION("R^2 -> R^3") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), + "invalid implicit conversion: R^2 -> R^3"); } SECTION("list1 -> R^d") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = + ASTNodeDataType::build<ASTNodeDataType::list_t>({std::make_shared<const ASTNodeDataType>(double_dt)}); { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), + "incompatible dimensions in affectation: expecting 2, but provided 1"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), + "incompatible dimensions in affectation: expecting 3, but provided 1"); } } SECTION("list2 -> R^d") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); } SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), + "incompatible dimensions in affectation: expecting 1, but provided 2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), + "incompatible dimensions in affectation: expecting 3, but provided 2"); } } SECTION("list3 -> R^d") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::list_t>( + {std::make_shared<const ASTNodeDataType>(double_dt), std::make_shared<const ASTNodeDataType>(unsigned_int_dt), + std::make_shared<const ASTNodeDataType>(int_dt)}); { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), + "incompatible dimensions in affectation: expecting 1, but provided 3"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), + "incompatible dimensions in affectation: expecting 2, but provided 3"); } } @@ -839,60 +1415,91 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("tuple(N) -> R^1") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: tuple(N) -> R^1"); } SECTION("tuple(R) -> R^1") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: tuple(R) -> R^1"); } SECTION("tuple(R) -> R^2") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: tuple(R) -> R^2"); } SECTION("tuple(B) -> R^2") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: tuple(B) -> R^2"); } SECTION("tuple(Z) -> R^3") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: tuple(Z) -> R^3"); } SECTION("tuple(R) -> R^3") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: tuple(R) -> R^3"); } + SECTION("tuple(R^1x1) -> tuple(R^3)") + { + auto tuple_R1x1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + data_node->m_data_type = tuple_R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), + "invalid implicit conversion: R^1x1 -> R^3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^3)") + { + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), + "invalid implicit conversion: R^2x2 -> R^3"); + } + + SECTION("tuple(R^2x2) -> tuple(R^1)") + { + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R2x2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)); + data_node->m_data_type = tuple_R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1}), + "invalid implicit conversion: R^2x2 -> R^1"); + } + SECTION("tuple(R^1) -> tuple(R^3)") { - auto tuple_R1 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}; - auto tuple_R3 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3}}; + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); data_node->m_data_type = tuple_R1; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), "invalid implicit conversion: R^1 -> R^3"); @@ -900,8 +1507,10 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(R^2) -> tuple(R^3)") { - auto tuple_R2 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2}}; - auto tuple_R3 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3}}; + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + auto tuple_R3 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)); data_node->m_data_type = tuple_R2; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R3}), "invalid implicit conversion: R^2 -> R^3"); @@ -909,8 +1518,10 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(R^2) -> tuple(R^1)") { - auto tuple_R1 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}; - auto tuple_R2 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2}}; + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); + auto tuple_R2 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); data_node->m_data_type = tuple_R2; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1}), "invalid implicit conversion: R^2 -> R^1"); @@ -918,8 +1529,8 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(R) -> tuple(Z)") { - auto tuple_R = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; - auto tuple_Z = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + auto tuple_R = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + auto tuple_Z = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); data_node->m_data_type = tuple_R; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_Z}), "invalid implicit conversion: R -> Z"); @@ -927,8 +1538,9 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(R) -> tuple(R^1)") { - auto tuple_R = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; - auto tuple_R1 = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1}}; + auto tuple_R = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + auto tuple_R1 = + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)); data_node->m_data_type = tuple_R; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R1}), "invalid implicit conversion: R -> R^1"); @@ -936,8 +1548,8 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(string) -> tuple(R)") { - auto tuple_string = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::string_t}}; - auto tuple_R = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + auto tuple_string = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt); + auto tuple_R = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); data_node->m_data_type = tuple_string; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R}), "invalid implicit conversion: string -> R"); @@ -945,10 +1557,10 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(type_id) -> tuple(R)") { - auto type_id = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; + auto type_id = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); - auto tuple_type_id = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id}; - auto tuple_R = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; + auto tuple_type_id = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id); + auto tuple_R = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); data_node->m_data_type = tuple_type_id; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_R}), "invalid implicit conversion: foo -> R"); @@ -956,11 +1568,11 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("tuple(type_id) -> tuple(R)") { - auto type_id0 = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; - auto type_id1 = ASTNodeDataType{ASTNodeDataType::type_id_t, "bar"}; + auto type_id0 = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); + auto type_id1 = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("bar"); - auto tuple_type_id0 = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id0}; - auto tuple_type_id1 = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id1}; + auto tuple_type_id0 = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id0); + auto tuple_type_id1 = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id1); data_node->m_data_type = tuple_type_id0; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_type_id1}), "invalid implicit conversion: foo -> bar"); @@ -969,33 +1581,33 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("R -> R^d") { - data_node->m_data_type = ASTNodeDataType::double_t; + data_node->m_data_type = double_dt; SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: R -> R^1"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: R -> R^2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: R -> R^3"); } } SECTION("Z -> R^d (non-zero)") { - data_node->m_data_type = ASTNodeDataType::int_t; + data_node->m_data_type = int_dt; data_node->set_type<language::integer>(); data_node->source = "1"; auto& source = data_node->source; @@ -1005,99 +1617,99 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: Z -> R^1"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: Z -> R^2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: Z -> R^3"); } } SECTION("N -> R^d") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; + data_node->m_data_type = unsigned_int_dt; SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: N -> R^1"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: N -> R^2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: N -> R^3"); } } SECTION("B -> R^d") { - data_node->m_data_type = ASTNodeDataType::bool_t; + data_node->m_data_type = bool_dt; SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: B -> R^1"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: B -> R^2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: B -> R^3"); } } SECTION("string -> R^d") { - data_node->m_data_type = ASTNodeDataType::string_t; + data_node->m_data_type = string_dt; SECTION("d=1") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 1}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1)}), "invalid implicit conversion: string -> R^1"); } SECTION("d=2") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 2}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)}), "invalid implicit conversion: string -> R^2"); } SECTION("d=3") { REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::vector_t, 3}}), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)}), "invalid implicit conversion: string -> R^3"); } } @@ -1107,44 +1719,43 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("string -> R") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = string_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: string -> R"); } SECTION("R^1 -> R") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: R^1 -> R"); } SECTION("R^2 -> R") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: R^2 -> R"); } SECTION("R^3 -> R") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: R^3 -> R"); } SECTION("tuple(N) -> R") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: tuple(N) -> R"); } SECTION("tuple(R) -> R") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::double_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, double_dt}), "invalid implicit conversion: tuple(R) -> R"); } } @@ -1153,51 +1764,50 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("string -> Z") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = string_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: string -> Z"); } SECTION("R^1 -> Z") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: R^1 -> Z"); } SECTION("R^2 -> Z") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: R^2 -> Z"); } SECTION("R^3 -> Z") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: R^3 -> Z"); } SECTION("R -> Z") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = double_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: R -> Z"); } SECTION("tuple(N) -> Z") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: tuple(N) -> Z"); } SECTION("tuple(Z) -> Z") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, int_dt}), "invalid implicit conversion: tuple(Z) -> Z"); } } @@ -1206,51 +1816,50 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("string -> N") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = string_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: string -> N"); } SECTION("R^1 -> N") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: R^1 -> N"); } SECTION("R^2 -> N") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: R^2 -> N"); } SECTION("R^3 -> N") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: R^3 -> N"); } SECTION("R -> N") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = double_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: R -> N"); } SECTION("tuple(Z) -> N") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: tuple(Z) -> N"); } SECTION("tuple(N) -> N") { - data_node->m_data_type = - ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::unsigned_int_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, unsigned_int_dt}), "invalid implicit conversion: tuple(N) -> N"); } } @@ -1259,64 +1868,64 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("string -> B") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = string_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: string -> B"); } SECTION("R^1 -> B") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: R^1 -> B"); } SECTION("R^2 -> B") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: R^2 -> B"); } SECTION("R^3 -> B") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: R^3 -> B"); } SECTION("R -> B") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = double_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: R -> B"); } SECTION("Z -> B") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = int_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: Z -> B"); } SECTION("N -> B") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = unsigned_int_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: N -> B"); } SECTION("tuple(Z) -> B") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: tuple(Z) -> B"); } SECTION("tuple(B) -> B") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType::bool_t}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, bool_dt}), "invalid implicit conversion: tuple(B) -> B"); } } @@ -1325,166 +1934,195 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") { SECTION("N -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::unsigned_int_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = unsigned_int_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: N -> B"); } SECTION("Z -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::int_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = int_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: Z -> B"); } SECTION("R -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::double_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = double_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: R -> B"); } SECTION("string -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType::string_t; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = string_dt; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: string -> B"); } SECTION("R^1 -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: R^1 -> B"); } SECTION("R^2 -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: R^2 -> B"); } SECTION("R^3 -> tuple(B)") { - data_node->m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::bool_t}}}), + data_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + bool_dt)}), "invalid implicit conversion: R^3 -> B"); } SECTION("R -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::double_t; + data_node->m_data_type = double_dt; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, - ASTNodeDataType{ - ASTNodeDataType::unsigned_int_t}}}), + ASTNodeDataType::build<ASTNodeDataType::tuple_t>( + unsigned_int_dt)}), "invalid implicit conversion: R -> N"); } + SECTION("R^1x1 -> tuple(R^2x2)") + { + auto R1x1 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1); + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + data_node->m_data_type = R1x1; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2x2)}), + "invalid implicit conversion: R^1x1 -> R^2x2"); + } + + SECTION("R^2x2 -> tuple(R^3x3)") + { + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + auto R3x3 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + data_node->m_data_type = R2x2; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3x3)}), + "invalid implicit conversion: R^2x2 -> R^3x3"); + } + + SECTION("R^3x3 -> tuple(R^2x2)") + { + auto R3x3 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3); + auto R2x2 = ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2); + data_node->m_data_type = R3x3; + REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2x2)}), + "invalid implicit conversion: R^3x3 -> R^2x2"); + } + SECTION("R^1 -> tuple(R^2)") { - auto R1 = ASTNodeDataType{ASTNodeDataType::vector_t, 1}; - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + auto R1 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(1); + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); data_node->m_data_type = R1; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, R2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2)}), + "invalid implicit conversion: R^1 -> R^2"); } SECTION("R^2 -> tuple(R^3)") { - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; - auto R3 = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); + auto R3 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); data_node->m_data_type = R2; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, R3}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R3)}), + "invalid implicit conversion: R^2 -> R^3"); } SECTION("R^3 -> tuple(R^2)") { - auto R3 = ASTNodeDataType{ASTNodeDataType::vector_t, 3}; - auto R2 = ASTNodeDataType{ASTNodeDataType::vector_t, 2}; + auto R3 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(3); + auto R2 = ASTNodeDataType::build<ASTNodeDataType::vector_t>(2); data_node->m_data_type = R3; REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, - ASTNodeDataType{ASTNodeDataType::tuple_t, R2}}), - "incompatible dimensions in affectation"); + ASTNodeDataType::build<ASTNodeDataType::tuple_t>(R2)}), + "invalid implicit conversion: R^3 -> R^2"); } SECTION("(B, R, Z) -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::bool_t; + list0_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::double_t; + list1_node->m_data_type = double_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_t}), "invalid implicit conversion: R -> N"); } SECTION("(R, N, Z) -> tuple(Z)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::double_t; + list0_node->m_data_type = double_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::int_t; + list2_node->m_data_type = int_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_t}), "invalid implicit conversion: R -> Z"); } SECTION("(B, N, R) -> tuple(N)") { - data_node->m_data_type = ASTNodeDataType::list_t; + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); - list0_node->m_data_type = ASTNodeDataType::bool_t; + list0_node->m_data_type = bool_dt; data_node->emplace_back(std::move(list0_node)); std::unique_ptr list1_node = std::make_unique<ASTNode>(); - list1_node->m_data_type = ASTNodeDataType::unsigned_int_t; + list1_node->m_data_type = unsigned_int_dt; data_node->emplace_back(std::move(list1_node)); std::unique_ptr list2_node = std::make_unique<ASTNode>(); - list2_node->m_data_type = ASTNodeDataType::double_t; + list2_node->m_data_type = double_dt; data_node->emplace_back(std::move(list2_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_t}), "invalid implicit conversion: R -> N"); } @@ -1492,9 +2130,9 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") SECTION("(type_id_t, type_id_t) -> tuple(type_id_t)") { - auto type_id1 = ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}; - auto type_id2 = ASTNodeDataType{ASTNodeDataType::type_id_t, "bar"}; - data_node->m_data_type = ASTNodeDataType::list_t; + auto type_id1 = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo"); + auto type_id2 = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("bar"); + data_node->m_data_type = list_dt; { std::unique_ptr list0_node = std::make_unique<ASTNode>(); list0_node->m_data_type = type_id1; @@ -1504,7 +2142,7 @@ TEST_CASE("ASTNodeNaturalConversionChecker", "[language]") list1_node->m_data_type = type_id2; data_node->emplace_back(std::move(list1_node)); } - auto tuple_t = ASTNodeDataType{ASTNodeDataType::tuple_t, type_id2}; + auto tuple_t = ASTNodeDataType::build<ASTNodeDataType::tuple_t>(type_id2); REQUIRE_THROWS_WITH((ASTNodeNaturalConversionChecker{*data_node, tuple_t}), "invalid implicit conversion: foo -> bar"); } diff --git a/tests/test_ASTNodeUnaryOperatorExpressionBuilder.cpp b/tests/test_ASTNodeUnaryOperatorExpressionBuilder.cpp index 6c30d3f0a9e8470f7156fe6aa8239002f065aee5..2ebe22cab68dc735836054b58ca4382340e09030 100644 --- a/tests/test_ASTNodeUnaryOperatorExpressionBuilder.cpp +++ b/tests/test_ASTNodeUnaryOperatorExpressionBuilder.cpp @@ -191,35 +191,27 @@ not b; SECTION("Errors") { - SECTION("Invalid unary operator") - { - auto ast = std::make_unique<ASTNode>(); - - REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "unexpected error: undefined unary operator"); - } - - SECTION("Invalid unary operator") - { - auto ast = std::make_unique<ASTNode>(); - - REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "unexpected error: undefined unary operator"); - } - SECTION("Invalid value type for unary minus") { auto ast = std::make_unique<ASTNode>(); ast->set_type<language::unary_minus>(); ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined value type for unary operator"); + REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined unary operator type: - undefined"); } SECTION("errors") { - CHECK_AST_THROWS_WITH(R"(let n:N; not n;)", "invalid implicit conversion: N -> B"); - CHECK_AST_THROWS_WITH(R"(not 2;)", "invalid implicit conversion: Z -> B"); - CHECK_AST_THROWS_WITH(R"(not -2.3;)", "invalid implicit conversion: R -> B"); - CHECK_AST_THROWS_WITH(R"(not "foo";)", "invalid implicit conversion: string -> B"); + auto error_message = [](std::string type_name) { + return std::string{R"(undefined unary operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_AST_THROWS_WITH(R"(let n:N; not n;)", error_message("N")); + CHECK_AST_THROWS_WITH(R"(not 2;)", error_message("Z")); + CHECK_AST_THROWS_WITH(R"(not -2.3;)", error_message("R")); + CHECK_AST_THROWS_WITH(R"(not "foo";)", error_message("string")); } SECTION("Invalid value type for unary not") @@ -228,18 +220,7 @@ not b; ast->set_type<language::unary_not>(); ast->children.emplace_back(std::make_unique<ASTNode>()); - REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined value type for unary operator"); - } - - SECTION("Invalid data type for unary operator") - { - auto ast = std::make_unique<ASTNode>(); - ast->set_type<language::unary_minus>(); - ast->m_data_type = ASTNodeDataType::int_t; - ast->children.emplace_back(std::make_unique<ASTNode>()); - - REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, - "unexpected error: invalid operand type for unary operator"); + REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined unary operator type: not undefined"); } } } diff --git a/tests/test_ASTSymbolInitializationChecker.cpp b/tests/test_ASTSymbolInitializationChecker.cpp index fe59815acf5c65c642a81155fa30abf422a5b471..fb0c8598beb79861f6cbbb66ebe37461c66b61e5 100644 --- a/tests/test_ASTSymbolInitializationChecker.cpp +++ b/tests/test_ASTSymbolInitializationChecker.cpp @@ -301,5 +301,18 @@ let f : R->R, x->x+y; ASTSymbolTableBuilder{*ast}; REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, std::string{"uninitialized symbol 'y'"}); } + + SECTION("expecting a list of identifiers") + { + std::string_view data = R"( +let (x,y,z):R*R*R, x = 3; +)"; + + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + + ASTSymbolTableBuilder{*ast}; + REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, std::string{"expecting a list of identifiers"}); + } } } diff --git a/tests/test_AffectationProcessor.cpp b/tests/test_AffectationProcessor.cpp index c94ecdb7d9a1d320dce9f00a7bd70755fd425749..331046feb18edffb9e110a27a9050cdb94795bea 100644 --- a/tests/test_AffectationProcessor.cpp +++ b/tests/test_AffectationProcessor.cpp @@ -102,7 +102,6 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^1; x[0] = -2.3;", "x", (TinyVector<1>{-2.3})); CHECK_AFFECTATION_RESULT("let x : R^1, x = 0;", "x", (TinyVector<1>{zero})); - CHECK_AFFECTATION_RESULT("let x : R^1; x = 0;", "x", (TinyVector<1>{zero})); } SECTION("R^2") @@ -115,7 +114,6 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^2; x[0] = -0.3; x[1] = 12;", "x", (TinyVector<2>{-0.3, 12})); CHECK_AFFECTATION_RESULT("let x : R^2, x = 0;", "x", (TinyVector<2>{zero})); - CHECK_AFFECTATION_RESULT("let x : R^2; x = 0;", "x", (TinyVector<2>{zero})); } SECTION("R^3") @@ -126,9 +124,51 @@ TEST_CASE("AffectationProcessor", "[language]") (TinyVector<3>{-1, true, false})); CHECK_AFFECTATION_RESULT("let x : R^3; x[0] = -0.3; x[1] = 12; x[2] = 6.2;", "x", (TinyVector<3>{-0.3, 12, 6.2})); - CHECK_AFFECTATION_RESULT("let x : R^3, x = 0;", "x", (TinyVector<3>{zero})); CHECK_AFFECTATION_RESULT("let x : R^3; x = 0;", "x", (TinyVector<3>{zero})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = -1;", "x", (TinyMatrix<1>{-1})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = true;", "x", (TinyMatrix<1>{true})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = false;", "x", (TinyMatrix<1>{false})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = -2.3;", "x", (TinyMatrix<1>{-2.3})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = -1;", "x", (TinyMatrix<1>{-1})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = true;", "x", (TinyMatrix<1>{true})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = false;", "x", (TinyMatrix<1>{false})); + CHECK_AFFECTATION_RESULT("let x : R^1x1; x[0,0] = -2.3;", "x", (TinyMatrix<1>{-2.3})); + + CHECK_AFFECTATION_RESULT("let x : R^1x1; x = 0;", "x", (TinyMatrix<1>{zero})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 5);", "x", (TinyMatrix<2>{-1, true, 3, 5})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (true, false, 1==2, 2==2);", "x", + (TinyMatrix<2>{true, false, false, true})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-0.3, 12, 2, -3);", "x", (TinyMatrix<2>{-0.3, 12, 2, -3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = -1; x[0,1] = true; x[1,0] = 2; x[1,1] = 3.3;", "x", + (TinyMatrix<2>{-1, true, 2, 3.3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = true; x[0,1] = false; x[1,0] = 2.1; x[1,1] = -1;", "x", + (TinyMatrix<2>{true, false, 2.1, -1})); + CHECK_AFFECTATION_RESULT("let x : R^2x2; x[0,0] = -0.3; x[0,1] = 12; x[1,0] = 1.3; x[1,1] = 7;", "x", + (TinyMatrix<2>{-0.3, 12, 1.3, 7})); + + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = 0;", "x", (TinyMatrix<2>{zero})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-1, true, false, 2, 3.1, 4, -1, true, 2);", "x", + (TinyMatrix<3>{-1, true, false, 2, 3.1, 4, -1, true, 2})); + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 7.1, 3.2, 2-3, 2, -1, 0);", "x", + (TinyMatrix<3>{-0.3, 12, 6.2, 7.1, 3.2, 2 - 3, 2, -1, 0})); + CHECK_AFFECTATION_RESULT("let x : R^3x3; x[0,0] = -1; x[0,1] = true; x[0,2] = false; x[1,0] = -11; x[1,1] = 4; " + "x[1,2] = 3; x[2,0] = 6; x[2,1] = -3; x[2,2] = 5;", + "x", (TinyMatrix<3>{-1, true, false, -11, 4, 3, 6, -3, 5})); + + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = 0;", "x", (TinyMatrix<3>{zero})); + } } SECTION("+=") @@ -281,6 +321,29 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^3, x = (-0.3, 12, 6.2); x[0] *= -1; x[1] *= -3; x[2] *= 2;", "x", (TinyVector<3>{-0.3 * -1, 12 * -3, 6.2 * 2})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x *= 2;", "x", (TinyMatrix<1>{TinyMatrix<1>{2} *= 2})); + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x[0,0] *= 1.3;", "x", (TinyMatrix<1>{2 * 1.3})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 6); x *= 3;", "x", + (TinyMatrix<2>{TinyMatrix<2>{-1, true, 3, 6} *= 3})); + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 6); x[0,0] *= 2; x[1,1] *= 3;", "x", + (TinyMatrix<2>{-1 * 2, true, 3, 6 * 3})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-1, true, false, 2, -3, 11, 5, -4, 2); x*=5.2;", "x", + (TinyMatrix<3>{TinyMatrix<3>{-1, true, false, 2, -3, 11, 5, -4, 2} *= 5.2})); + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 2, -3, 11, 5, -4, 2); x[0,0] *= -1; x[0,1] *= -3; " + "x[0,2] *= 2; x[1,1] *= 2; x[2,1] *= 6; x[2,2] *= 2;", + "x", (TinyMatrix<3>{-0.3 * -1, 12 * -3, 6.2 * 2, 2, -3 * 2, 11, 5, (-4) * 6, 2 * 2})); + } } SECTION("/=") @@ -323,6 +386,28 @@ TEST_CASE("AffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT("let x : R^3, x = (-0.3, 12, 6.2); x[0] /= -1.2; x[1] /= -3.1; x[2] /= 2.4;", "x", (TinyVector<3>{-0.3 / -1.2, 12 / -3.1, 6.2 / 2.4})); } + + SECTION("R^1x1") + { + CHECK_AFFECTATION_RESULT("let x : R^1x1, x = 2; x[0,0] /= 1.3;", "x", (TinyMatrix<1>{2 / 1.3})); + } + + SECTION("R^2x2") + { + CHECK_AFFECTATION_RESULT("let x : R^2x2, x = (-1, true, 3, 1); x[0,0] /= 2; x[0,1] /= 3; x[1,0] /= 0.5; x[1,1] " + "/= 4;", + "x", (TinyMatrix<2>{-1. / 2., true / 3., 3 / 0.5, 1. / 4})); + } + + SECTION("R^3x3") + { + CHECK_AFFECTATION_RESULT("let x : R^3x3, x = (-0.3, 12, 6.2, 1.2, 3, 5, 1, 11, 2); x[0,0] /= -1.2; x[0,1] /= " + "-3.1; x[0,2] /= 2.4; x[1,0] /= -1.6; x[1,1] /= -3.1; x[1,2] /= 2.4; x[2,0] /= 0.4; " + "x[2,1] /= -1.7; x[2,2] /= 1.2;", + "x", + (TinyMatrix<3>{-0.3 / -1.2, 12 / -3.1, 6.2 / 2.4, 1.2 / -1.6, 3 / -3.1, 5 / 2.4, 1 / 0.4, + 11 / -1.7, 2 / 1.2})); + } } SECTION("errors") @@ -331,83 +416,83 @@ TEST_CASE("AffectationProcessor", "[language]") { SECTION("-> B") { - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 1; let b : B; b = n;", "invalid implicit conversion: N -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 1;", "invalid implicit conversion: Z -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 2.3;", "invalid implicit conversion: R -> B"); - CHECK_AFFECTATION_THROWS_WITH("let b : B; b = \"foo\";", "invalid implicit conversion: string -> B"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 1; let b : B; b = n;", "undefined affectation type: B = N"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 1;", "undefined affectation type: B = Z"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = 2.3;", "undefined affectation type: B = R"); + CHECK_AFFECTATION_THROWS_WITH("let b : B; b = \"foo\";", "undefined affectation type: B = string"); } SECTION("-> N") { - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2.3;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = \"bar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2.3;", "undefined affectation type: N = R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = \"bar\";", "undefined affectation type: N = string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += 1.1;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += \"foo\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += 1.1;", "undefined affectation type: N += R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n += \"foo\";", "undefined affectation type: N += string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= 1.1;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= \"bar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= 1.1;", "undefined affectation type: N -= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n -= \"bar\";", "undefined affectation type: N -= string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= 2.51;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= \"foobar\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= 2.51;", "undefined affectation type: N *= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n *= \"foobar\";", "undefined affectation type: N *= string"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= 2.51;", "invalid implicit conversion: R -> N"); - CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= \"foo\";", "invalid implicit conversion: string -> N"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= 2.51;", "undefined affectation type: N /= R"); + CHECK_AFFECTATION_THROWS_WITH("let n : N, n = 2; n /= \"foo\";", "undefined affectation type: N /= string"); } SECTION("-> Z") { - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = -2.3;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = \"foobar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = -2.3;", "undefined affectation type: Z = R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = \"foobar\";", "undefined affectation type: Z = string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += 1.1;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += \"foo\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += 1.1;", "undefined affectation type: Z += R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z += \"foo\";", "undefined affectation type: Z += string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= 2.1;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= \"bar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= 2.1;", "undefined affectation type: Z -= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z -= \"bar\";", "undefined affectation type: Z -= string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= -2.51;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= \"foobar\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= -2.51;", "undefined affectation type: Z *= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z *= \"foobar\";", "undefined affectation type: Z *= string"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 4; z /= -2.;", "invalid implicit conversion: R -> Z"); - CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z /= \"foo\";", "invalid implicit conversion: string -> Z"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 4; z /= -2.;", "undefined affectation type: Z /= R"); + CHECK_AFFECTATION_THROWS_WITH("let z : Z, z = 2; z /= \"foo\";", "undefined affectation type: Z /= string"); } SECTION("-> R") { - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = \"foobar\";", "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.3; x += \"foo\";", "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.1; x -= \"bar\";", "invalid implicit conversion: string -> R"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = \"foobar\";", "undefined affectation type: R = string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.3; x += \"foo\";", "undefined affectation type: R += string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 2.1; x -= \"bar\";", "undefined affectation type: R -= string"); CHECK_AFFECTATION_THROWS_WITH("let x : R, x = 1.2; x *= \"foobar\";", - "invalid implicit conversion: string -> R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R, x =-2.3; x /= \"foo\";", "invalid implicit conversion: string -> R"); + "undefined affectation type: R *= string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R, x =-2.3; x /= \"foo\";", "undefined affectation type: R /= string"); } SECTION("-> R^n") { - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = \"foobar\";", "invalid implicit conversion: string -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = \"foobar\";", "invalid implicit conversion: string -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = \"foobar\";", "undefined affectation type: R^2 = string"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = \"foobar\";", "undefined affectation type: R^3 = string"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 3.2;", "invalid implicit conversion: R -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 2.3;", "invalid implicit conversion: R -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 3.2;", "undefined affectation type: R^2 = R"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 2.3;", "undefined affectation type: R^3 = R"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 4;", "invalid implicit conversion: Z -> R^2"); - CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 3;", "invalid implicit conversion: Z -> R^3"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 4;", "invalid integral value (0 is the solely valid value)"); + CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 3;", "invalid integral value (0 is the solely valid value)"); CHECK_AFFECTATION_THROWS_WITH("let x : R^1, x = 0; let y : R^2, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^2 = R^1"); CHECK_AFFECTATION_THROWS_WITH("let x : R^1, x = 0; let y : R^3, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^3 = R^1"); CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 0; let y : R^1, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^1 = R^2"); CHECK_AFFECTATION_THROWS_WITH("let x : R^2, x = 0; let y : R^3, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^3 = R^2"); CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 0; let y : R^1, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^1 = R^3"); CHECK_AFFECTATION_THROWS_WITH("let x : R^3, x = 0; let y : R^2, y = x;", - "incompatible dimensions in affectation"); + "undefined affectation type: R^2 = R^3"); } } } diff --git a/tests/test_AffectationToStringProcessor.cpp b/tests/test_AffectationToStringProcessor.cpp index 4053e023187944c87333f0dd605404104d039aa2..c432cdd00ddcd41213720feb9217ad4f20dfac5e 100644 --- a/tests/test_AffectationToStringProcessor.cpp +++ b/tests/test_AffectationToStringProcessor.cpp @@ -55,17 +55,17 @@ TEST_CASE("ASTAffectationToStringProcessor", "[language]") CHECK_AFFECTATION_RESULT(R"(let s : string; s = 2.3;)", "s", std::to_string(2.3)); { std::ostringstream os; - os << TinyVector<1>{13} << std::ends; + os << TinyVector<1>{13}; CHECK_AFFECTATION_RESULT(R"(let x : R^1, x = 13; let s : string; s = x;)", "s", os.str()); } { std::ostringstream os; - os << TinyVector<2>{2, 3} << std::ends; + os << TinyVector<2>{2, 3}; CHECK_AFFECTATION_RESULT(R"(let x : R^2, x = (2,3); let s : string; s = x;)", "s", os.str()); } { std::ostringstream os; - os << TinyVector<3>{1, 2, 3} << std::ends; + os << TinyVector<3>{1, 2, 3}; CHECK_AFFECTATION_RESULT(R"(let x : R^3, x = (1,2,3); let s : string; s = x;)", "s", os.str()); } } @@ -82,17 +82,17 @@ TEST_CASE("ASTAffectationToStringProcessor", "[language]") (std::string("foo") + std::to_string(2.3))); { std::ostringstream os; - os << "foo" << TinyVector<1>{13} << std::ends; + os << "foo" << TinyVector<1>{13}; CHECK_AFFECTATION_RESULT(R"(let x : R^1, x = 13; let s : string, s="foo"; s += x;)", "s", os.str()); } { std::ostringstream os; - os << "foo" << TinyVector<2>{2, 3} << std::ends; + os << "foo" << TinyVector<2>{2, 3}; CHECK_AFFECTATION_RESULT(R"(let x : R^2, x = (2,3); let s : string, s="foo"; s += x;)", "s", os.str()); } { std::ostringstream os; - os << "foo" << TinyVector<3>{1, 2, 3} << std::ends; + os << "foo" << TinyVector<3>{1, 2, 3}; CHECK_AFFECTATION_RESULT(R"(let x : R^3, x = (1,2,3); let s : string, s="foo"; s += x;)", "s", os.str()); } } diff --git a/tests/test_AffectationToTupleProcessor.cpp b/tests/test_AffectationToTupleProcessor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fccd584c6eb7c164e9b9389d8395f56bc3803ca0 --- /dev/null +++ b/tests/test_AffectationToTupleProcessor.cpp @@ -0,0 +1,243 @@ +#include <catch2/catch.hpp> + +#include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTNodeAffectationExpressionBuilder.hpp> +#include <language/ast/ASTNodeDataTypeBuilder.hpp> +#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp> +#include <language/ast/ASTNodeExpressionBuilder.hpp> +#include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/utils/ASTPrinter.hpp> +#include <utils/Demangle.hpp> + +#include <pegtl/string_input.hpp> + +#include <sstream> + +#define CHECK_AFFECTATION_RESULT(data, variable_name, expected_value) \ + { \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + ASTSymbolTableBuilder{*ast}; \ + ASTNodeDataTypeBuilder{*ast}; \ + \ + ASTNodeDeclarationToAffectationConverter{*ast}; \ + ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ + \ + ASTNodeExpressionBuilder{*ast}; \ + ExecutionPolicy exec_policy; \ + ast->execute(exec_policy); \ + \ + auto symbol_table = ast->m_symbol_table; \ + \ + using namespace TAO_PEGTL_NAMESPACE; \ + position use_position{internal::iterator{"fixture"}, "fixture"}; \ + use_position.byte = 10000; \ + auto [symbol, found] = symbol_table->find(variable_name, use_position); \ + \ + auto attributes = symbol->attributes(); \ + auto value = std::get<decltype(expected_value)>(attributes.value()); \ + \ + REQUIRE(value == expected_value); \ + } + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("ASTAffectationToTupleProcessor", "[language]") +{ + SECTION("Affectations from value") + { + CHECK_AFFECTATION_RESULT(R"( +let s :(R); s = 2.; +)", + "s", (std::vector<double>{2.})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(R); s = 2; +)", + "s", (std::vector<double>{2})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(string); s = 2.; +)", + "s", (std::vector<std::string>{std::to_string(2.)})); + + const std::string x_string = []() -> std::string { + std::ostringstream os; + os << TinyVector<3, double>{1, 2, 3}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let x :R^3, x = (1,2,3); +let s :(string); s = x; +)", + "s", (std::vector<std::string>{x_string})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(R^1); s = 1.3; +)", + "s", (std::vector<TinyVector<1>>{TinyVector<1>{1.3}})); + + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<2, double>{1, 2, 3, 4}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A :R^2x2, A = (1,2,3,4); +let s :(string); s = A; +)", + "s", (std::vector<std::string>{A_string})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(R^1x1); s = 1.3; +)", + "s", (std::vector<TinyMatrix<1>>{TinyMatrix<1>{1.3}})); + } + + SECTION("Affectations from list") + { + CHECK_AFFECTATION_RESULT(R"( +let t :(R); t = (2.,3); +)", + "t", (std::vector<double>{2., 3})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(string); s = (2.,3); +)", + "s", (std::vector<std::string>{std::to_string(2.), std::to_string(3)})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(string); s = (2.,3,"foo"); +)", + "s", + (std::vector<std::string>{std::to_string(2.), std::to_string(3), std::string{"foo"}})); + + const std::string x_string = []() -> std::string { + std::ostringstream os; + os << TinyVector<2, double>{1, 2}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let x : R^2, x = (1,2); +let s : (string); s = (2.,3, x); +)", + "s", (std::vector<std::string>{std::to_string(2.), std::to_string(3), x_string})); + + CHECK_AFFECTATION_RESULT(R"( +let x : R^2, x = (1,2); +let t :(R^2); t = (x,0); +)", + "t", (std::vector<TinyVector<2>>{TinyVector<2>{1, 2}, TinyVector<2>{0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2); t = ((1,2),0); +)", + "t", (std::vector<TinyVector<2>>{TinyVector<2>{1, 2}, TinyVector<2>{0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2); t = (0); +)", + "t", (std::vector<TinyVector<2>>{TinyVector<2>{0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^3); t = (0); +)", + "t", (std::vector<TinyVector<3>>{TinyVector<3>{0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let x : R^1, x = 1; +let t :(R^1); t = (x,2); +)", + "t", (std::vector<TinyVector<1>>{TinyVector<1>{1}, TinyVector<1>{2}})); + + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<2, double>{1, 2, 3, 4}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A : R^2x2, A = (1,2,3,4); +let s : (string); s = (2.,3, A); +)", + "s", (std::vector<std::string>{std::to_string(2.), std::to_string(3), A_string})); + + CHECK_AFFECTATION_RESULT(R"( +let A : R^2x2, A = (1,2,3,4); +let t :(R^2x2); t = (A,0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{1, 2, 3, 4}, TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2x2); t = ((1,2,3,4),0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{1, 2, 3, 4}, TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^2x2); t = (0); +)", + "t", (std::vector<TinyMatrix<2>>{TinyMatrix<2>{0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let t :(R^3x3); t = 0; +)", + "t", (std::vector<TinyMatrix<3>>{TinyMatrix<3>{0, 0, 0, 0, 0, 0, 0, 0, 0}})); + + CHECK_AFFECTATION_RESULT(R"( +let x : R^1x1, x = 1; +let t :(R^1x1); t = (x,2); +)", + "t", (std::vector<TinyMatrix<1>>{TinyMatrix<1>{1}, TinyMatrix<1>{2}})); + } + + SECTION("Affectations from tuple") + { + const std::string x_string = []() -> std::string { + std::ostringstream os; + os << TinyVector<3, double>{1, 2, 3}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let x :(R^3), x = ((1,2,3)); +let s :(string); s = x; +)", + "s", (std::vector<std::string>{x_string})); + + const std::string A_string = []() -> std::string { + std::ostringstream os; + os << TinyMatrix<3, double>{1, 2, 3, 4, 5, 6, 7, 8, 9}; + return os.str(); + }(); + + CHECK_AFFECTATION_RESULT(R"( +let A :(R^3x3), A = ((1,2,3,4,5,6,7,8,9)); +let s :(string); s = A; +)", + "s", (std::vector<std::string>{A_string})); + + CHECK_AFFECTATION_RESULT(R"( +let x :(R), x = (1,2,3); +let s :(string); s = x; +)", + "s", + (std::vector<std::string>{std::to_string(1.), std::to_string(2.), std::to_string(3.)})); + + CHECK_AFFECTATION_RESULT(R"( +let n :(N), n = (1,2,3); +let t :(R); t = n; +)", + "t", (std::vector<double>{1, 2, 3})); + + CHECK_AFFECTATION_RESULT(R"( +let s :(N), s = (1,2,3); +let t :(N); t = s; +)", + "t", (std::vector<uint64_t>{1, 2, 3})); + } +} diff --git a/tests/test_Array.cpp b/tests/test_Array.cpp index 3b8dd45878d07747cae3b4e761266c8d9bf14293..0071da34de3f6c5ae3173cb282a5a995462b47c5 100644 --- a/tests/test_Array.cpp +++ b/tests/test_Array.cpp @@ -197,6 +197,23 @@ TEST_CASE("Array", "[utils]") } } + SECTION("checking for Kokkos::View encaspulation") + { + { + Kokkos::View<double*> kokkos_view("anonymous", 10); + for (size_t i = 0; i < kokkos_view.size(); ++i) { + kokkos_view[i] = i; + } + + Array array = encapsulate(kokkos_view); + + REQUIRE(array.size() == kokkos_view.size()); + for (size_t i = 0; i < array.size(); ++i) { + REQUIRE(&array[i] == &kokkos_view[i]); + } + } + } + #ifndef NDEBUG SECTION("checking for bounds violation") { diff --git a/tests/test_ArraySubscriptProcessor.cpp b/tests/test_ArraySubscriptProcessor.cpp index a48390214ac9cfd38bfa063613c0dc2880fe1f6a..43386ea66a5a207d319c8dd681dad4bb4a779b28 100644 --- a/tests/test_ArraySubscriptProcessor.cpp +++ b/tests/test_ArraySubscriptProcessor.cpp @@ -108,6 +108,55 @@ let x2 : R, x2 = x[2]; CHECK_EVALUATION_RESULT(data, "x2", double{3}); } + SECTION("R^1x1 component access") + { + std::string_view data = R"( +let x : R^1x1, x = 1; +let x00: R, x00 = x[0,0]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + } + + SECTION("R^2x2 component access") + { + std::string_view data = R"( +let x : R^2x2, x = (1,2,3,4); +let x00: R, x00 = x[0,0]; +let x01: R, x01 = x[0,1]; +let x10: R, x10 = x[1,0]; +let x11: R, x11 = x[1,1]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "x10", double{3}); + CHECK_EVALUATION_RESULT(data, "x11", double{4}); + } + + SECTION("R^3x3 component access") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let x00 : R, x00 = x[0,0]; +let x01 : R, x01 = x[0,1]; +let x02 : R, x02 = x[0,2]; +let x10 : R, x10 = x[1,0]; +let x11 : R, x11 = x[1,1]; +let x12 : R, x12 = x[1,2]; +let x20 : R, x20 = x[2,0]; +let x21 : R, x21 = x[2,1]; +let x22 : R, x22 = x[2,2]; +)"; + CHECK_EVALUATION_RESULT(data, "x00", double{1}); + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "x02", double{3}); + CHECK_EVALUATION_RESULT(data, "x10", double{4}); + CHECK_EVALUATION_RESULT(data, "x11", double{5}); + CHECK_EVALUATION_RESULT(data, "x12", double{6}); + CHECK_EVALUATION_RESULT(data, "x20", double{7}); + CHECK_EVALUATION_RESULT(data, "x21", double{8}); + CHECK_EVALUATION_RESULT(data, "x22", double{9}); + } + SECTION("R^d component access from integer expression") { std::string_view data = R"( @@ -125,6 +174,23 @@ let z0: R, z0 = z[(2-2)*1]; CHECK_EVALUATION_RESULT(data, "z0", double{8}); } + SECTION("R^dxd component access from integer expression") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let x01: R, x01 = x[3-2-1,2+3-4]; + +let y : R^2x2, y = (2,7,6,-2); +let y11: R, y11 = y[2/2, 3/1-2]; + +let z : R^1x1, z = 8; +let z00: R, z00 = z[(2-2)*1, (3-1)*2-4]; +)"; + CHECK_EVALUATION_RESULT(data, "x01", double{2}); + CHECK_EVALUATION_RESULT(data, "y11", double{-2}); + CHECK_EVALUATION_RESULT(data, "z00", double{8}); + } + SECTION("error invalid index type") { SECTION("R index type") diff --git a/tests/test_BiCGStab.cpp b/tests/test_BiCGStab.cpp index 99b94515ba7795903434db761d509b9052fcaf71..49a6955be4ae1da7c45e036977a3036be0ccf6d3 100644 --- a/tests/test_BiCGStab.cpp +++ b/tests/test_BiCGStab.cpp @@ -45,7 +45,7 @@ TEST_CASE("BiCGStab", "[algebra]") Vector<double> x{5}; x = 0; - BiCGStab<false>{b, A, x, 10, 1e-12}; + BiCGStab{A, x, b, 1e-12, 10, false}; Vector error = x - x_exact; REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x, x))); } @@ -88,7 +88,7 @@ TEST_CASE("BiCGStab", "[algebra]") Vector<double> x{5}; x = 0; - BiCGStab{b, A, x, 1, 1e-12}; + BiCGStab{A, x, b, 1e-12, 1, true}; Vector error = x - x_exact; REQUIRE(std::sqrt((error, error)) > 1E-5 * std::sqrt((x, x))); } diff --git a/tests/test_BinaryExpressionProcessor_logic.cpp b/tests/test_BinaryExpressionProcessor_logic.cpp index f53d8557cc48551a9abe630c3bbda43aed9ae366..a0d6c12e7cbef0e5e90dc1ed0ec6bb57a6a7d406 100644 --- a/tests/test_BinaryExpressionProcessor_logic.cpp +++ b/tests/test_BinaryExpressionProcessor_logic.cpp @@ -42,32 +42,121 @@ TEST_CASE("BinaryExpressionProcessor logic", "[language]") { SECTION("and") { - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n and true;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false and n;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 and true;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false and 2;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 and true;)", "invalid implicit conversion: R -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false and 2e-2;)", "invalid implicit conversion: R -> B"); + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types N and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n and true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and N)"; + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false and n;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 and true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and Z)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false and 2;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types R and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 and true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and R)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false and 2e-2;)", error_message); + } } SECTION("or") { - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n or true;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false or n;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 or true;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false or 2;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 or true;)", "invalid implicit conversion: R -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false or 2e-2;)", "invalid implicit conversion: R -> B"); + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types N and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n or true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and N)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false or n;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 or true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and Z)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false or 2;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types R and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 or true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and R)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false or 2e-2;)", error_message); + } } SECTION("xor") { - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n xor true;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false xor n;)", "invalid implicit conversion: N -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 xor true;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false xor 2;)", "invalid implicit conversion: Z -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 xor true;)", "invalid implicit conversion: R -> B"); - CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false xor 2e-2;)", "invalid implicit conversion: R -> B"); + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types N and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=1; n xor true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and N)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(let n:N, n=2; false xor n;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types Z and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1 xor true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and Z)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false xor 2;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types R and B)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(1.1 xor true;)", error_message); + } + { + const std::string error_message = R"(undefined binary operator +note: incompatible operand types B and R)"; + + CHECK_BINARY_EXPRESSION_THROWS_WITH(R"(false xor 2e-2;)", error_message); + } } } } diff --git a/tests/test_BuildInfo.cpp b/tests/test_BuildInfo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b63d0eb9b22beab6d810b806284d6fc85f3c0df9 --- /dev/null +++ b/tests/test_BuildInfo.cpp @@ -0,0 +1,65 @@ +#include <catch2/catch.hpp> + +#include <utils/BuildInfo.hpp> +#include <utils/pugs_build_info.hpp> +#include <utils/pugs_config.hpp> + +#include <sstream> + +#ifdef PUGS_HAS_MPI +#include <mpi.h> +#endif // PUGS_HAS_MPI + +#ifdef PUGS_HAS_PETSC +#include <petsc.h> +#endif // PUGS_HAS_PETSC + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("BuildInfo", "[utils]") +{ + SECTION("type") + { + REQUIRE(BuildInfo::type() == PUGS_BUILD_TYPE); + } + + SECTION("compiler") + { + std::stringstream compiler_info; + compiler_info << PUGS_BUILD_COMPILER << " (" << PUGS_BUILD_COMPILER_VERSION << ")"; + REQUIRE(BuildInfo::compiler() == compiler_info.str()); + } + + SECTION("kokkos") + { + REQUIRE(BuildInfo::kokkosDevices() == PUGS_BUILD_KOKKOS_DEVICES); + } + + SECTION("mpi") + { +#ifdef PUGS_HAS_MPI + const std::string mpi_library = []() { + int length; + char mpi_version[MPI_MAX_LIBRARY_VERSION_STRING]; + MPI_Get_library_version(mpi_version, &length); + return std::string(mpi_version); + }(); + + REQUIRE(BuildInfo::mpiLibrary() == mpi_library); +#else + REQUIRE(BuildInfo::mpiLibrary() == "none"); +#endif // PUGS_HAS_MPI + } + + SECTION("petsc") + { +#ifdef PUGS_HAS_PETSC + const std::string petsc_library = std::to_string(PETSC_VERSION_MAJOR) + "." + std::to_string(PETSC_VERSION_MINOR) + + "." + std::to_string(PETSC_VERSION_SUBMINOR); + + REQUIRE(BuildInfo::petscLibrary() == petsc_library); +#else + REQUIRE(BuildInfo::petscLibrary() == "none"); +#endif // PUGS_HAS_PETSC + } +} diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index bf0445d6aadf5cfdbefed9134dafe0e89dd8d290..b00a9649b6cf8a5f3b57c27d9d63e74ad5ed68b6 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -5,15 +5,16 @@ // clazy:excludeall=non-pod-global-static template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = {ASTNodeDataType::type_id_t, - "shared_const_double"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_const_double"); template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<double>> = {ASTNodeDataType::type_id_t, "shared_double"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<double>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_double"); template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<uint64_t>> = {ASTNodeDataType::type_id_t, - "shared_uint64_t"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<uint64_t>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_uint64_t"); TEST_CASE("BuiltinFunctionEmbedder", "[language]") { @@ -58,6 +59,52 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t); } + SECTION("R*R^2 -> R^2") + { + std::function c = [](double a, TinyVector<2> x) -> TinyVector<2> { return a * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(double, TinyVector<2>)> embedded_c{c}; + + double a_arg = 2.3; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + + SECTION("R^2x2*R^2 -> R^2") + { + std::function c = [](TinyMatrix<2> A, TinyVector<2> x) -> TinyVector<2> { return A * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(TinyMatrix<2>, TinyVector<2>)> embedded_c{c}; + + TinyMatrix<2> a_arg = {2.3, 1, -2, 3}; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::matrix_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + SECTION("POD BuiltinFunctionEmbedder") { std::function c = [](double x, uint64_t i) -> bool { return x > i; }; diff --git a/tests/test_BuiltinFunctionRegister.hpp b/tests/test_BuiltinFunctionRegister.hpp index 1583f1df99d7eddefc3337529611a5d5720912a0..bc36e2ea29d37cbcd2464ddc1e080bff9e0dd591 100644 --- a/tests/test_BuiltinFunctionRegister.hpp +++ b/tests/test_BuiltinFunctionRegister.hpp @@ -7,8 +7,8 @@ #include <utils/Exceptions.hpp> template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = {ASTNodeDataType::type_id_t, - "builtin_t"}; +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t"); const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const double>>; namespace test_only @@ -68,6 +68,28 @@ class test_BuiltinFunctionRegister std::make_shared<BuiltinFunctionEmbedder<double(TinyVector<3>, TinyVector<2>)>>( [](TinyVector<3> x, TinyVector<2> y) -> double { return x[0] * y[1] + (y[0] - x[2]) * x[1]; }))); + m_name_builtin_function_map.insert( + std::make_pair("RtoR11", std::make_shared<BuiltinFunctionEmbedder<TinyMatrix<1>(double)>>( + [](double r) -> TinyMatrix<1> { return {r}; }))); + + m_name_builtin_function_map.insert( + std::make_pair("R11toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<1>)>>( + [](TinyMatrix<1> x) -> double { return x(0, 0); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<2>)>>( + [](TinyMatrix<2> x) -> double { return x(0, 0) + x(0, 1) + x(1, 0) + x(1, 1); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33toR", std::make_shared<BuiltinFunctionEmbedder<double(const TinyMatrix<3>&)>>( + [](const TinyMatrix<3>& x) -> double { return x(0, 0) + x(1, 1) + x(2, 2); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<3>, TinyMatrix<2>)>>( + [](TinyMatrix<3> x, TinyMatrix<2> y) -> double { + return (x(0, 0) + x(1, 1) + x(2, 2)) * (y(0, 0) + y(0, 1) + y(1, 0) + y(1, 1)); + }))); + m_name_builtin_function_map.insert( std::make_pair("fidToR", std::make_shared<BuiltinFunctionEmbedder<double(const FunctionSymbolId&)>>( [](const FunctionSymbolId&) -> double { return 0; }))); @@ -116,6 +138,21 @@ class test_BuiltinFunctionRegister m_name_builtin_function_map.insert( std::make_pair("tuple_R3ToR", std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyVector<3>>)>>( [](const std::vector<TinyVector<3>>&) -> double { return 0; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R11ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<1>>&)>>( + [](const std::vector<TinyMatrix<1>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R22ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<2>>&)>>( + [](const std::vector<TinyMatrix<2>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R33ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<3>>)>>( + [](const std::vector<TinyMatrix<3>>&) -> double { return 0; }))); } void @@ -127,7 +164,7 @@ class test_BuiltinFunctionRegister throw UnexpectedError("cannot add '" + builtin_data_type.nameOfTypeId() + "' type for testing"); } - i_symbol->attributes().setDataType(ASTNodeDataType::type_name_id_t); + i_symbol->attributes().setDataType(ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>()); i_symbol->attributes().setIsInitialized(); i_symbol->attributes().value() = symbol_table.typeEmbedderTable().size(); symbol_table.typeEmbedderTable().add(std::make_shared<TypeDescriptor>(builtin_data_type.nameOfTypeId())); @@ -165,7 +202,7 @@ class test_BuiltinFunctionRegister throw ParseError(error_message.str(), root_node.begin()); } - i_symbol->attributes().setDataType(ASTNodeDataType::builtin_function_t); + i_symbol->attributes().setDataType(ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>()); i_symbol->attributes().setIsInitialized(); i_symbol->attributes().value() = builtin_function_embedder_table.size(); diff --git a/tests/test_PCG.cpp b/tests/test_CG.cpp similarity index 91% rename from tests/test_PCG.cpp rename to tests/test_CG.cpp index d4564761d8c564082e4668360e5a2e2a8ecce008..75ab83063b9c19121e598ec68a5260ce95b02ae8 100644 --- a/tests/test_PCG.cpp +++ b/tests/test_CG.cpp @@ -1,11 +1,11 @@ #include <catch2/catch.hpp> +#include <algebra/CG.hpp> #include <algebra/CRSMatrix.hpp> -#include <algebra/PCG.hpp> // clazy:excludeall=non-pod-global-static -TEST_CASE("PCG", "[algebra]") +TEST_CASE("CG", "[algebra]") { SECTION("no preconditionner") { @@ -45,7 +45,7 @@ TEST_CASE("PCG", "[algebra]") Vector<double> x{5}; x = 0; - PCG{b, A, A, x, 10, 1e-12}; + CG{A, x, b, 1e-12, 10, true}; Vector error = x - x_exact; REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x, x))); } @@ -62,7 +62,7 @@ TEST_CASE("PCG", "[algebra]") Vector<double> x{5}; x = 0; - PCG<false>{b, A, A, x, 10, 1e-12}; + CG{A, x, b, 1e-12, 10}; REQUIRE(std::sqrt((x, x)) == 0); } @@ -104,7 +104,7 @@ TEST_CASE("PCG", "[algebra]") Vector<double> x{5}; x = 0; - PCG<false>{b, A, A, x, 1, 1e-12}; + CG{A, x, b, 1e-12, 1, false}; Vector error = x - x_exact; REQUIRE(std::sqrt((error, error)) > 1E-10 * std::sqrt((x, x))); } diff --git a/tests/test_CRSGraph.cpp b/tests/test_CRSGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c2fc7f287249468886f557d30bbb5740216f114 --- /dev/null +++ b/tests/test_CRSGraph.cpp @@ -0,0 +1,39 @@ +#include <catch2/catch.hpp> + +#include <utils/CRSGraph.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("CRSGraph", "[utils]") +{ + Array<int> entries{5}; + Array<int> neighbors{9}; + + entries[0] = 0; + neighbors[0] = 0; + neighbors[1] = 1; + + entries[1] = 2; + neighbors[2] = 1; + neighbors[3] = 3; + + entries[2] = 4; + neighbors[4] = 2; + neighbors[5] = 1; + neighbors[6] = 3; + + entries[3] = 7; + neighbors[7] = 0; + neighbors[8] = 1; + + entries[4] = 9; + + CRSGraph graph(entries, neighbors); + + REQUIRE(graph.numberOfNodes() == 4); + + REQUIRE(entries.size() == graph.entries().size()); + REQUIRE(&entries[0] == &graph.entries()[0]); + REQUIRE(neighbors.size() == graph.neighbors().size()); + REQUIRE(&neighbors[0] == &graph.neighbors()[0]); +} diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp index 7a7a943d14c943f5a315544fc6eab947ffc28463..6c8584796b8c71991941870c723cc23617c0c2ac 100644 --- a/tests/test_CRSMatrix.cpp +++ b/tests/test_CRSMatrix.cpp @@ -94,7 +94,7 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[4] == -2); } - SECTION("matrix vector product (complet)") + SECTION("matrix vector product (complete)") { SparseMatrixDescriptor<int, uint8_t> S{4}; S(0, 0) = 1; @@ -129,6 +129,60 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[3] == 150); } + SECTION("check values") + { + SparseMatrixDescriptor<int, uint8_t> S{4}; + S(3, 0) = 13; + S(0, 0) = 1; + S(0, 1) = 2; + S(1, 1) = 6; + S(1, 2) = 7; + S(2, 2) = 11; + S(3, 2) = 15; + S(2, 0) = 9; + S(3, 3) = 16; + S(2, 3) = 12; + S(0, 3) = 4; + S(1, 0) = 5; + S(2, 1) = 10; + + CRSMatrix<int, uint8_t> A{S}; + + auto values = A.values(); + REQUIRE(values.size() == 13); + REQUIRE(values[0] == 1); + REQUIRE(values[1] == 2); + REQUIRE(values[2] == 4); + REQUIRE(values[3] == 5); + REQUIRE(values[4] == 6); + REQUIRE(values[5] == 7); + REQUIRE(values[6] == 9); + REQUIRE(values[7] == 10); + REQUIRE(values[8] == 11); + REQUIRE(values[9] == 12); + REQUIRE(values[10] == 13); + REQUIRE(values[11] == 15); + REQUIRE(values[12] == 16); + + auto row_indices = A.rowIndices(); + + REQUIRE(row_indices.size() == 5); + + REQUIRE(A.row(0).colidx(0) == 0); + REQUIRE(A.row(0).colidx(1) == 1); + REQUIRE(A.row(0).colidx(2) == 3); + REQUIRE(A.row(1).colidx(0) == 0); + REQUIRE(A.row(1).colidx(1) == 1); + REQUIRE(A.row(1).colidx(2) == 2); + REQUIRE(A.row(2).colidx(0) == 0); + REQUIRE(A.row(2).colidx(1) == 1); + REQUIRE(A.row(2).colidx(2) == 2); + REQUIRE(A.row(2).colidx(3) == 3); + REQUIRE(A.row(3).colidx(0) == 0); + REQUIRE(A.row(3).colidx(1) == 2); + REQUIRE(A.row(3).colidx(2) == 3); + } + #ifndef NDEBUG SECTION("incompatible runtime matrix/vector product") { diff --git a/tests/test_CastArray.cpp b/tests/test_CastArray.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a014045cfaf19a039dcaf0be97edaa81bea8522 --- /dev/null +++ b/tests/test_CastArray.cpp @@ -0,0 +1,88 @@ +#include <catch2/catch.hpp> + +#include <utils/ArrayUtils.hpp> +#include <utils/CastArray.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("CastArray", "[utils]") +{ + SECTION("explicit cast Array -> CastArray") + { + Array<double> x_double{5}; + x_double[0] = 1; + x_double[1] = 2; + x_double[2] = 3; + x_double[3] = 4; + x_double[4] = 5; + + CastArray<double, char> x_char{x_double}; + + REQUIRE(x_char.size() * sizeof(char) == x_double.size() * sizeof(double)); + + Array<char> y_char{x_char.size()}; + for (size_t i = 0; i < x_char.size(); ++i) { + y_char[i] = x_char[i]; + } + + CastArray<char, double> y_double{y_char}; + REQUIRE(y_char.size() * sizeof(char) == y_double.size() * sizeof(double)); + + REQUIRE(&(y_double[0]) != &(x_double[0])); + + for (size_t i = 0; i < y_double.size(); ++i) { + REQUIRE(y_double[i] == x_double[i]); + } + } + + SECTION("explicit cast value -> CastArray") + { + double x = 3; + + CastArray<double, char> x_char(x); + + REQUIRE(x_char.size() * sizeof(char) == sizeof(double)); + } + + SECTION("invalid cast array") + { + Array<char> x_char{13}; + + REQUIRE_THROWS_WITH((CastArray<char, double>{x_char}), + "unexpected error: cannot cast array to the chosen data type"); + } + + SECTION("cast array utilities") + { + SECTION("Array -> CastArray") + { + Array<double> x_double{5}; + x_double[0] = 1.3; + x_double[1] = 3.2; + x_double[2] = -4; + x_double[3] = 6.2; + x_double[4] = -1.6; + + CastArray<double, short> x_short{x_double}; + auto x_short_from = cast_array_to<short>::from(x_double); + + REQUIRE(x_short_from.size() == x_short.size()); + for (size_t i = 0; i < x_short_from.size(); ++i) { + REQUIRE(x_short_from[i] == x_short[i]); + } + } + + SECTION("Value -> CastArray") + { + double x = 3.14; + + CastArray<double, short> x_short{x}; + auto x_short_from = cast_value_to<short>::from(x); + + REQUIRE(x_short_from.size() == x_short.size()); + for (size_t i = 0; i < x_short_from.size(); ++i) { + REQUIRE(x_short_from[i] == x_short[i]); + } + } + } +} diff --git a/tests/test_ConcatExpressionProcessor.cpp b/tests/test_ConcatExpressionProcessor.cpp index adc5f49e0360b2e1ea3e784ea83fae60cd157d69..57b68c4e7206c9eebd59942cddf2829f2fe236ab 100644 --- a/tests/test_ConcatExpressionProcessor.cpp +++ b/tests/test_ConcatExpressionProcessor.cpp @@ -71,4 +71,28 @@ TEST_CASE("ConcatExpressionProcessor", "[language]") { CHECK_CONCAT_EXPRESSION_RESULT(R"(let s:string, s = "foo_"; s = s+true;)", "s", std::string{"foo_1"}); } + + SECTION("string + R^1") + { + std::ostringstream os; + os << "foo_" << TinyVector<1>{1}; + + CHECK_CONCAT_EXPRESSION_RESULT(R"(let x:R^1, x = 1; let s:string, s = "foo_"; s = s+x;)", "s", os.str()); + } + + SECTION("string + R^2") + { + std::ostringstream os; + os << "foo_" << TinyVector<2>{1, 2}; + + CHECK_CONCAT_EXPRESSION_RESULT(R"(let x:R^2, x = (1,2); let s:string, s = "foo_"; s = s+x;)", "s", os.str()); + } + + SECTION("string + R^3") + { + std::ostringstream os; + os << "foo_" << TinyVector<3>{1, 2, 3}; + + CHECK_CONCAT_EXPRESSION_RESULT(R"(let x:R^3, x = (1,2,3); let s:string, s = "foo_"; s = s+x;)", "s", os.str()); + } } diff --git a/tests/test_ConsoleManager.cpp b/tests/test_ConsoleManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4387d7042065df80406ba3ba1c06ba44ec3f9e97 --- /dev/null +++ b/tests/test_ConsoleManager.cpp @@ -0,0 +1,32 @@ +#include <catch2/catch.hpp> + +#include <utils/ConsoleManager.hpp> + +#include <rang.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("ConsoleManager", "[utils]") +{ + SECTION("is terminal") + { + const bool is_terminal = rang::rang_implementation::isTerminal(std::cout.rdbuf()); + + REQUIRE(is_terminal == ConsoleManager::isTerminal(std::cout)); + } + + SECTION("control settings") + { + const rang::control saved_control = rang::rang_implementation::controlMode(); + + ConsoleManager::init(true); + + REQUIRE(rang::rang_implementation::controlMode() == rang::control::Force); + + ConsoleManager::init(false); + + REQUIRE(rang::rang_implementation::controlMode() == rang::control::Off); + + rang::setControlMode(saved_control); + } +} diff --git a/tests/test_DataVariant.cpp b/tests/test_DataVariant.cpp index 1fa8cf72be599782c25fd88f22613e465df818e8..92ca82b9aeb7ee68c9698f973ce48d457fd27696 100644 --- a/tests/test_DataVariant.cpp +++ b/tests/test_DataVariant.cpp @@ -34,7 +34,7 @@ TEST_CASE("DataVariant", "[language]") REQUIRE(std::get<std::vector<double>>(aggregate[2]) == std::vector<double>{1, 2.7}); } - SECTION("Copy") + SECTION("copy") { AggregateDataVariant aggregate_copy{aggregate}; @@ -48,5 +48,40 @@ TEST_CASE("DataVariant", "[language]") REQUIRE(std::get<int64_t>(aggregate[1]) == std::get<int64_t>(aggregate_copy[1])); REQUIRE(std::get<std::vector<double>>(aggregate[2]) == std::get<std::vector<double>>(aggregate_copy[2])); } + + SECTION("affectation") + { + AggregateDataVariant aggregate_copy; + aggregate_copy = aggregate; + + REQUIRE(aggregate.size() == aggregate_copy.size()); + + for (size_t i = 0; i < aggregate.size(); ++i) { + REQUIRE(aggregate[i].index() == aggregate_copy[i].index()); + } + + REQUIRE(std::get<double>(aggregate[0]) == std::get<double>(aggregate_copy[0])); + REQUIRE(std::get<int64_t>(aggregate[1]) == std::get<int64_t>(aggregate_copy[1])); + REQUIRE(std::get<std::vector<double>>(aggregate[2]) == std::get<std::vector<double>>(aggregate_copy[2])); + } + + SECTION("move affectation") + { + AggregateDataVariant aggregate_move_copy; + { + AggregateDataVariant aggregate_copy{aggregate}; + aggregate_move_copy = std::move(aggregate_copy); + } + + REQUIRE(aggregate.size() == aggregate_move_copy.size()); + + for (size_t i = 0; i < aggregate.size(); ++i) { + REQUIRE(aggregate[i].index() == aggregate_move_copy[i].index()); + } + + REQUIRE(std::get<double>(aggregate[0]) == std::get<double>(aggregate_move_copy[0])); + REQUIRE(std::get<int64_t>(aggregate[1]) == std::get<int64_t>(aggregate_move_copy[1])); + REQUIRE(std::get<std::vector<double>>(aggregate[2]) == std::get<std::vector<double>>(aggregate_move_copy[2])); + } } } diff --git a/tests/test_Demangle.cpp b/tests/test_Demangle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3489fd2d629fe835e3ce247a8ac8008bb06ff9d --- /dev/null +++ b/tests/test_Demangle.cpp @@ -0,0 +1,37 @@ +#include <catch2/catch.hpp> + +#include <utils/Demangle.hpp> + +#include <cxxabi.h> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("Demangle", "[utils]") +{ + SECTION("demangle success") + { + const std::string mangled = typeid(std::string).name(); + + int status = -1; + char* cxa_demangled = abi::__cxa_demangle(mangled.data(), NULL, NULL, &status); + + REQUIRE(status == 0); + + std::string demangled{cxa_demangled}; + free(cxa_demangled); + + REQUIRE(demangled == demangle<std::string>()); + } + + SECTION("demangle failed") + { + const std::string mangled = "not_mangled"; + + int status = -1; + abi::__cxa_demangle(mangled.data(), NULL, NULL, &status); + + REQUIRE(status != 0); + + REQUIRE((std::string{"not_mangled"} == demangle("not_mangled"))); + } +} diff --git a/tests/test_EscapedString.cpp b/tests/test_EscapedString.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3dca01496813ca51e6b52afd1be3c446bc7df54b --- /dev/null +++ b/tests/test_EscapedString.cpp @@ -0,0 +1,22 @@ +#include <catch2/catch.hpp> + +#include <utils/EscapedString.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("EscapedString", "[utils]") +{ + SECTION("escape string") + { + const std::string s = "foo\'\\\"\?\a\b\f\n\r\t\vbar"; + + REQUIRE(escapeString(s) == R"(foo\'\\\"\?\a\b\f\n\r\t\vbar)"); + } + + SECTION("unescape string") + { + const std::string s = R"(foo\'\\\"\?\a\b\f\n\r\t\vbar)"; + + REQUIRE(unescapeString(s) == std::string{"foo\'\\\"\?\a\b\f\n\r\t\vbar"}); + } +} diff --git a/tests/test_Exceptions.cpp b/tests/test_Exceptions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25b181ae165aef26cd25942cc52741a2606b50ac --- /dev/null +++ b/tests/test_Exceptions.cpp @@ -0,0 +1,20 @@ +#include <catch2/catch.hpp> + +#include <utils/Exceptions.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("Exceptions", "[utils]") +{ + SECTION("exceptions message") + { + RawError raw_error{"a raw error"}; + REQUIRE(std::string{raw_error.what()} == "a raw error"); + + UnexpectedError unexpected_error{"an unexpected error"}; + REQUIRE(std::string{unexpected_error.what()} == "unexpected error: an unexpected error"); + + NotImplementedError not_implemented_error{"not implemented error"}; + REQUIRE(std::string{not_implemented_error.what()} == "not implemented yet: not implemented error"); + } +} diff --git a/tests/test_FunctionArgumentConverter.cpp b/tests/test_FunctionArgumentConverter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0a191fc5c2474a65c34c8c5e5b822b5046cb8a2 --- /dev/null +++ b/tests/test_FunctionArgumentConverter.cpp @@ -0,0 +1,186 @@ +#include <catch2/catch.hpp> + +#include <language/node_processor/FunctionArgumentConverter.hpp> +#include <language/utils/SymbolTable.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("FunctionArgumentConverter", "[language]") +{ + ExecutionPolicy::Context context(0, std::make_shared<ExecutionPolicy::Context::Values>(3)); + ExecutionPolicy execution_policy(ExecutionPolicy{}, context); + + SECTION("FunctionArgumentToStringConverter") + { + const std::string s{"foo"}; + FunctionArgumentToStringConverter converter0{0}; + converter0.convert(execution_policy, s); + + const TinyVector<3> X{1, 3.2, 4}; + FunctionArgumentToStringConverter converter1{1}; + converter1.convert(execution_policy, X); + std::ostringstream os_X; + os_X << X; + + const double x = 3.2; + FunctionArgumentToStringConverter converter2{2}; + converter2.convert(execution_policy, x); + + REQUIRE(std::get<std::string>(execution_policy.currentContext()[0]) == s); + REQUIRE(std::get<std::string>(execution_policy.currentContext()[1]) == os_X.str()); + REQUIRE(std::get<std::string>(execution_policy.currentContext()[2]) == std::to_string(x)); + } + + SECTION("FunctionArgumentConverter") + { + const double double_value = 1.7; + FunctionArgumentConverter<double, double> converter0{0}; + converter0.convert(execution_policy, double{double_value}); + + const uint64_t uint64_value = 3; + FunctionArgumentConverter<double, uint64_t> converter1{1}; + converter1.convert(execution_policy, uint64_value); + + const bool bool_value = false; + FunctionArgumentConverter<uint64_t, bool> converter2{2}; + converter2.convert(execution_policy, bool_value); + + REQUIRE(std::get<double>(execution_policy.currentContext()[0]) == double_value); + REQUIRE(std::get<double>(execution_policy.currentContext()[1]) == static_cast<double>(uint64_value)); + REQUIRE(std::get<uint64_t>(execution_policy.currentContext()[2]) == static_cast<uint64_t>(bool_value)); + } + + SECTION("FunctionTinyVectorArgumentConverter") + { + const TinyVector<3> x3{1.7, 2.9, -3}; + FunctionTinyVectorArgumentConverter<TinyVector<3>, TinyVector<3>> converter0{0}; + converter0.convert(execution_policy, TinyVector{x3}); + + const double x1 = 6.3; + FunctionTinyVectorArgumentConverter<TinyVector<1>, double> converter1{1}; + converter1.convert(execution_policy, double{x1}); + + AggregateDataVariant values{std::vector<DataVariant>{6.3, 3.2, 4ul}}; + FunctionTinyVectorArgumentConverter<TinyVector<3>, TinyVector<3>> converter2{2}; + converter2.convert(execution_policy, values); + + REQUIRE(std::get<TinyVector<3>>(execution_policy.currentContext()[0]) == x3); + REQUIRE(std::get<TinyVector<1>>(execution_policy.currentContext()[1]) == TinyVector<1>{x1}); + REQUIRE(std::get<TinyVector<3>>(execution_policy.currentContext()[2]) == TinyVector<3>{6.3, 3.2, 4ul}); + + AggregateDataVariant bad_values{std::vector<DataVariant>{6.3, 3.2, std::string{"bar"}}}; + + REQUIRE_THROWS_WITH(converter2.convert(execution_policy, bad_values), std::string{"unexpected error: "} + + demangle<std::string>() + + " unexpected aggregate value type"); + } + + SECTION("FunctionTinyMatrixArgumentConverter") + { + const TinyMatrix<3> x3{1.7, 2.9, -3, 4, 5.2, 6.1, -7, 8.3, 9.05}; + FunctionTinyMatrixArgumentConverter<TinyMatrix<3>, TinyMatrix<3>> converter0{0}; + converter0.convert(execution_policy, TinyMatrix{x3}); + + const double x1 = 6.3; + FunctionTinyMatrixArgumentConverter<TinyMatrix<1>, double> converter1{1}; + converter1.convert(execution_policy, double{x1}); + + AggregateDataVariant values{std::vector<DataVariant>{6.3, 3.2, 4ul, 2.3, -3.1, 6.7, 3.6, 2ul, 1.1}}; + FunctionTinyMatrixArgumentConverter<TinyMatrix<3>, TinyMatrix<3>> converter2{2}; + converter2.convert(execution_policy, values); + + REQUIRE(std::get<TinyMatrix<3>>(execution_policy.currentContext()[0]) == x3); + REQUIRE(std::get<TinyMatrix<1>>(execution_policy.currentContext()[1]) == TinyMatrix<1>{x1}); + REQUIRE(std::get<TinyMatrix<3>>(execution_policy.currentContext()[2]) == + TinyMatrix<3>{6.3, 3.2, 4ul, 2.3, -3.1, 6.7, 3.6, 2ul, 1.1}); + + AggregateDataVariant bad_values{std::vector<DataVariant>{6.3, 3.2, std::string{"bar"}, true}}; + + REQUIRE_THROWS_WITH(converter2.convert(execution_policy, bad_values), std::string{"unexpected error: "} + + demangle<std::string>() + + " unexpected aggregate value type"); + } + + SECTION("FunctionTupleArgumentConverter") + { + const TinyVector<3> x3{1.7, 2.9, -3}; + FunctionTupleArgumentConverter<TinyVector<3>, TinyVector<3>> converter0{0}; + converter0.convert(execution_policy, TinyVector{x3}); + + const double a = 1.2; + const double b = -3.5; + const double c = 2.6; + FunctionTupleArgumentConverter<double, double> converter1{1}; + converter1.convert(execution_policy, std::vector{a, b, c}); + + const uint64_t i = 1; + const uint64_t j = 3; + const uint64_t k = 6; + FunctionTupleArgumentConverter<double, uint64_t> converter2{2}; + converter2.convert(execution_policy, std::vector<uint64_t>{i, j, k}); + + REQUIRE(std::get<std::vector<TinyVector<3>>>(execution_policy.currentContext()[0]) == + std::vector<TinyVector<3>>{x3}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a, b, c}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[2]) == std::vector<double>{i, j, k}); + + converter1.convert(execution_policy, a); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a}); + + converter1.convert(execution_policy, j); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{j}); + + // Errors + REQUIRE_THROWS_WITH(converter0.convert(execution_policy, j), + "unexpected error: cannot convert 'unsigned long' to 'TinyVector<3ul, double>'"); + } + + SECTION("FunctionListArgumentConverter") + { + const uint64_t i = 3; + FunctionListArgumentConverter<double, double> converter0{0}; + converter0.convert(execution_policy, i); + + const double a = 6.3; + const double b = -1.3; + const double c = 3.6; + FunctionListArgumentConverter<double, double> converter1{1}; + converter1.convert(execution_policy, std::vector<double>{a, b, c}); + + AggregateDataVariant v{std::vector<DataVariant>{1ul, 2.3, -3l}}; + FunctionListArgumentConverter<double, double> converter2{2}; + converter2.convert(execution_policy, v); + + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[0]) == std::vector<double>{i}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[1]) == std::vector<double>{a, b, c}); + REQUIRE(std::get<std::vector<double>>(execution_policy.currentContext()[2]) == std::vector<double>{1ul, 2.3, -3l}); + + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_0{0}; + converterR2_0.convert(execution_policy, TinyVector<2>{1, 3.2}); + + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_1{1}; + converterR2_1.convert(execution_policy, std::vector{TinyVector<2>{1, 3.2}, TinyVector<2>{-1, 0.2}}); + + AggregateDataVariant v_R2{std::vector<DataVariant>{TinyVector<2>{-3, 12.2}, TinyVector<2>{2, 1.2}}}; + FunctionListArgumentConverter<TinyVector<2>, TinyVector<2>> converterR2_2{2}; + converterR2_2.convert(execution_policy, v_R2); + + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[0]) == + std::vector<TinyVector<2>>{TinyVector<2>{1, 3.2}}); + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[1]) == + std::vector<TinyVector<2>>{TinyVector<2>{1, 3.2}, TinyVector<2>{-1, 0.2}}); + REQUIRE(std::get<std::vector<TinyVector<2>>>(execution_policy.currentContext()[2]) == + std::vector<TinyVector<2>>{TinyVector<2>{-3, 12.2}, TinyVector<2>{2, 1.2}}); + } + + SECTION("FunctionArgumentToFunctionSymbolIdConverter") + { + std::shared_ptr symbol_table = std::make_shared<SymbolTable>(); + + const uint64_t f_id = 3; + FunctionArgumentToFunctionSymbolIdConverter converter0{0, symbol_table}; + converter0.convert(execution_policy, f_id); + + REQUIRE(std::get<FunctionSymbolId>(execution_policy.currentContext()[0]).id() == f_id); + } +} diff --git a/tests/test_FunctionProcessor.cpp b/tests/test_FunctionProcessor.cpp index 59386bbae65a807841ead561899e0a61719a82d3..64fe6776ee427db8b2e3641ee952cff2477279ca 100644 --- a/tests/test_FunctionProcessor.cpp +++ b/tests/test_FunctionProcessor.cpp @@ -406,6 +406,79 @@ let fx:R^3, fx = f(3); } } + SECTION("R^dxd functions (single value)") + { + SECTION(" R^1x1 -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let x:R^1x1, x = 3; + +let fx:R^1x1, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); + } + + SECTION(" R^2x2 -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> 2*x; +let x:R^2x2, x = (3, 7, 6, -2); + +let fx:R^2x2, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<2>{3, 7, 6, -2})); + } + + SECTION(" R^3x3 -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> 2*x; +let x:R^3x3, x = (2, 4, 7, 1, 3, 5, -6, 2, -3); + +let fx:R^3x3, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<3>{2, 4, 7, 1, 3, 5, -6, 2, -3})); + } + + SECTION(" R -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> 2*x; +let x:R, x = 3; + +let fx:R^1x1, fx = f(x); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); + } + + SECTION(" R*R -> R^2x2") + { + std::string_view data = R"( +let f : R*R -> R^2x2, (x,y) -> (2*x, 3*y, 5*(x-y), 2*x-y); +let fx:R^2x2, fx = f(2, 3); +)"; + + const double x = 2; + const double y = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (TinyMatrix<2>{2 * x, 3 * y, 5 * (x - y), 2 * x - y})); + } + + SECTION(" R -> R^3x3") + { + std::string_view data = R"( +let f : R -> R^3x3, x -> (x, 2*x, x*x, 3*x, 2+x, x-1, x+0.5, 2*x-1, 1/x); + +let fx:R^3x3, fx = f(3); +)"; + + const double x = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", + (TinyMatrix<3>{x, 2 * x, x * x, 3 * x, 2 + x, x - 1, x + 0.5, 2 * x - 1, + 1 / x})); + } + } + SECTION("multi-expression functions (using R^d)") { SECTION(" R -> R*R^1*R^2*R^3") @@ -479,6 +552,89 @@ let (x, x1, x2, x3):R*R^1*R^2*R^3, (x, x1, x2, x3) = f(y2, 0); } } + SECTION("multi-expression functions (using R^dxd)") + { + SECTION(" R -> R*R^1x1*R^2x2*R^3x3") + { + std::string_view data = R"( +let f : R -> R*R^1x1*R^2x2*R^3x3, x -> (x+1, 2*x, (x-2, x+2, 3, 2), (1, 0.5*x, x*x, x+1, 1/x, 2, x*x, 2*x-1, 3*x)); + +let (x, x11, x22, x33):R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(3); +)"; + + const double x = 3; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x + 1})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{2 * x})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x - 2, x + 2, 3, 2})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", + (TinyMatrix<3>{1, 0.5 * x, x * x, x + 1, 1 / x, 2, x * x, 2 * x - 1, 3 * x})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,0], x33[1,2], (x33[0,1], x22[1,1], x22[0,0], x33[2,2]), x22[0,0]*x33); + +let y22:R^2x2, y22 = (2.3, 4.1, 6, -3); +let y33:R^3x3, y33 = (1.2, 1.3, 2.1, 3.2, -1.5, 2.3, -0.2, 3.1, -2.6); +let(x, x11, x22, x33) : R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, y33); +)"; + + const TinyMatrix<2> x22{2.3, 4.1, 6, -3}; + const TinyMatrix<3> x33{1.2, 1.3, 2.1, 3.2, -1.5, 2.3, -0.2, 3.1, -2.6}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 0)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x33(0, 1), x22(1, 1), x22(0, 0), x33(2, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", (TinyMatrix<3>{x22(0, 0) * x33})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3 [with 0 as argument]") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,1], x33[1,2], (x33[0,1], x22[1,0], x22[0,1], x33[2,2]), + (x22[1,0], x33[0,2]+x22[1,1], x33[2,2], + x33[2,0], x33[2,0]+x22[0,0], x33[1,1], + x33[2,1], x33[1,2]+x22[1,1], x33[0,0])); + +let y22:R^2x2, y22 = (2.3, 4.1, 3.1, 1.7); +let y33:R^3x3, y33 = (2.7, 3.1, 2.1, + 0.3, 1.2, 1.6, + 1.7, 2.2, 1.4); +let (x, x11, x22, x33) : R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, y33); +)"; + + TinyMatrix<2> x22{2.3, 4.1, 3.1, 1.7}; + TinyMatrix<3> x33{2.7, 3.1, 2.1, 0.3, 1.2, 1.6, 1.7, 2.2, 1.4}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 1)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{x33(0, 1), x22(1, 0), x22(0, 1), x33(2, 2)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", + (TinyMatrix<3>{x22(1, 0), x33(0, 2) + x22(1, 1), x33(2, 2), x33(2, 0), + x33(2, 0) + x22(0, 0), x33(1, 1), x33(2, 1), + x33(1, 2) + x22(1, 1), x33(0, 0)})); + } + + SECTION(" R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3 [with 0 in result]") + { + std::string_view data = R"( +let f : R^2x2*R^3x3 -> R*R^1x1*R^2x2*R^3x3, + (x22, x33) -> (x22[0,0]+x33[2,0], x33[1,1], 0, 0); + +let y22:R^2x2, y22 = (2.3, 4.1, 3.1, 1.7); +let (x, x11, x22, x33):R*R^1x1*R^2x2*R^3x3, (x, x11, x22, x33) = f(y22, 0); +)"; + + TinyMatrix<2> x22{2.3, 4.1, 3.1, 1.7}; + TinyMatrix<3> x33{zero}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", (double{x22(0, 0) + x33(2, 0)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x11", (TinyMatrix<1>{x33(1, 1)})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x22", (TinyMatrix<2>{zero})); + CHECK_FUNCTION_EVALUATION_RESULT(data, "x33", (TinyMatrix<3>{zero})); + } + } + SECTION("function composition") { SECTION("N -> N -> R") @@ -518,7 +674,6 @@ let x:R, x = pow(f(2)); SECTION("R -> R^2 -> R") { std::string_view data = R"( -import math; let f : R -> R^2, x -> (x+1, x*2); let g : R^2 -> R, x -> x[0] + x[1]; @@ -532,7 +687,6 @@ let x:R, x = g(f(3)); SECTION("R -> R^2*R^3 -> R") { std::string_view data = R"( -import math; let f : R -> R^2*R^3, x -> ((x+1, x*2), (6*x, 7-x, x/2.3)); let g : R^2*R^3 -> R, (x, y) -> x[0]*x[1] + y[0]*y[1]-y[2]; @@ -542,5 +696,37 @@ let x:R, x = g(f(3)); double x0 = 3; CHECK_FUNCTION_EVALUATION_RESULT(data, "x", double{(x0 + 1) * x0 * 2 + 6 * x0 * (7 - x0) - x0 / 2.3}); } + + SECTION("R -> R^2x2 -> R") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> (x+1, x*2, x-1, x); +let g : R^2x2 -> R, A -> A[0,0] + 2*A[1,1] + 3*A[0,1]+ A[1, 0]; + +let x:R, x = g(f(3)); +)"; + + const double x = 3; + const TinyMatrix<2> A{x + 1, x * 2, x - 1, x}; + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", double{A(0, 0) + 2 * A(1, 1) + 3 * A(0, 1) + A(1, 0)}); + } + + SECTION("R -> R^2x2*R^3x3 -> R") + { + std::string_view data = R"( +let f : R -> R^2x2*R^3x3, x -> ((x+1, x*2, x-1, x), (6*x, 7-x, x/2.3, -x, 2*x, x/2.5, x*x, 2*x, x)); +let g : R^2x2*R^3x3 -> R, (A22, A33) -> A22[0,0]*A22[1,1] + (A33[0,0]*A33[1,0]-A33[2,2])*A22[0,1]-A33[2,0]*A33[0,2]-A22[1,1]; + +let x:R, x = g(f(3)); +)"; + + const double x = 3; + const TinyMatrix<2> A22{x + 1, x * 2, x - 1, x}; + const TinyMatrix<3> A33{6 * x, 7 - x, x / 2.3, -x, 2 * x, x / 2.5, x * x, 2 * x, x}; + + CHECK_FUNCTION_EVALUATION_RESULT(data, "x", + double{A22(0, 0) * A22(1, 1) + (A33(0, 0) * A33(1, 0) - A33(2, 2)) * A22(0, 1) - + A33(2, 0) * A33(0, 2) - A22(1, 1)}); + } } } diff --git a/tests/test_FunctionTable.cpp b/tests/test_FunctionTable.cpp index c6ae9ab2f856457aaff70aaf2e35146661fde1b0..eae87ca79ac45e18da301fef72e1bd2fd35f90cf 100644 --- a/tests/test_FunctionTable.cpp +++ b/tests/test_FunctionTable.cpp @@ -11,14 +11,14 @@ TEST_CASE("FunctionTable", "[language]") SECTION("FunctionDescriptor") { std::unique_ptr domain_mapping_node = std::make_unique<ASTNode>(); - domain_mapping_node->m_data_type = ASTNodeDataType::unsigned_int_t; + domain_mapping_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); std::unique_ptr definition_node = std::make_unique<ASTNode>(); - definition_node->m_data_type = ASTNodeDataType::double_t; + definition_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); FunctionDescriptor f{"f", std::move(domain_mapping_node), std::move(definition_node)}; - REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::unsigned_int_t); - REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::double_t); + REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>()); + REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::double_t>()); REQUIRE(domain_mapping_node == nullptr); REQUIRE(definition_node == nullptr); @@ -28,10 +28,10 @@ TEST_CASE("FunctionTable", "[language]") SECTION("uninitialized FunctionDescriptor") { std::unique_ptr domain_mapping_node = std::make_unique<ASTNode>(); - domain_mapping_node->m_data_type = ASTNodeDataType::unsigned_int_t; + domain_mapping_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); std::unique_ptr definition_node = std::make_unique<ASTNode>(); - definition_node->m_data_type = ASTNodeDataType::double_t; + definition_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); SECTION("nothing initialized") { @@ -44,7 +44,7 @@ TEST_CASE("FunctionTable", "[language]") { FunctionDescriptor f{"function", nullptr, std::move(definition_node)}; REQUIRE_THROWS_AS(f.domainMappingNode(), AssertError); - REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::double_t); + REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::double_t>()); REQUIRE(definition_node == nullptr); } @@ -52,7 +52,7 @@ TEST_CASE("FunctionTable", "[language]") { FunctionDescriptor f{"function", std::move(domain_mapping_node), nullptr}; REQUIRE_THROWS_AS(f.definitionNode(), AssertError); - REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::unsigned_int_t); + REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>()); REQUIRE(domain_mapping_node == nullptr); } } @@ -63,10 +63,10 @@ TEST_CASE("FunctionTable", "[language]") REQUIRE(table.size() == 0); std::unique_ptr domain_mapping_node = std::make_unique<ASTNode>(); - domain_mapping_node->m_data_type = ASTNodeDataType::unsigned_int_t; + domain_mapping_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>(); std::unique_ptr definition_node = std::make_unique<ASTNode>(); - definition_node->m_data_type = ASTNodeDataType::double_t; + definition_node->m_data_type = ASTNodeDataType::build<ASTNodeDataType::double_t>(); size_t function_id = table.add(FunctionDescriptor{"function", std::move(domain_mapping_node), std::move(definition_node)}); @@ -82,8 +82,8 @@ TEST_CASE("FunctionTable", "[language]") auto& f = table[function_id]; REQUIRE(f.name() == "function"); - REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::unsigned_int_t); - REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::double_t); + REQUIRE(f.domainMappingNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::unsigned_int_t>()); + REQUIRE(f.definitionNode().m_data_type == ASTNodeDataType::build<ASTNodeDataType::double_t>()); const auto& const_f = const_table[function_id]; REQUIRE(const_f.name() == "function"); diff --git a/tests/test_IfProcessor.cpp b/tests/test_IfProcessor.cpp index ccade16253ae5b1024a97cc5ba038ec222bf5156..71fc766845d22090e319ea02f071e8707d72eed5 100644 --- a/tests/test_IfProcessor.cpp +++ b/tests/test_IfProcessor.cpp @@ -102,6 +102,30 @@ if(false) { CHECK_IF_PROCESSOR_RESULT(data, "i", 2ul); } + SECTION("simple if(true) with local variable") + { + std::string_view data = R"( +let i:N, i = 0; +if(true) { + let j:N, j = 1; + i = j; +} +)"; + CHECK_IF_PROCESSOR_RESULT(data, "i", 1ul); + } + + SECTION("simple if(false) with else local variable") + { + std::string_view data = R"( +let i:N, i = 0; +if(false) {} else { + let j:N, j = 1; + i = j; +} +)"; + CHECK_IF_PROCESSOR_RESULT(data, "i", 1ul); + } + SECTION("errors") { SECTION("bad test type") diff --git a/tests/test_IncDecExpressionProcessor.cpp b/tests/test_IncDecExpressionProcessor.cpp index ca13dbde2c33330515300b1fa19ba163a771c5f2..7525ca9ff5cf796ecbe3a183f570793b4ed614fb 100644 --- a/tests/test_IncDecExpressionProcessor.cpp +++ b/tests/test_IncDecExpressionProcessor.cpp @@ -40,6 +40,16 @@ REQUIRE(value == expected_value); \ } +#define CHECK_INCDEC_EXPRESSION_THROWS_WITH(data, error_message) \ + { \ + string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + ASTSymbolTableBuilder{*ast}; \ + \ + REQUIRE_THROWS_WITH(ASTNodeDataTypeBuilder{*ast}, error_message); \ + } + // clazy:excludeall=non-pod-global-static TEST_CASE("IncDecExpressionProcessor", "[language]") @@ -127,4 +137,51 @@ TEST_CASE("IncDecExpressionProcessor", "[language]") CHECK_INC_DEC_RESULT(R"(let r:R, r = 2; let s:R, s = r--;)", "s", 2.); } } + + SECTION("errors") + { + SECTION("undefined pre -- operator") + { + auto error_message = [](std::string type_name) { + return std::string{R"(undefined increment/decrement operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_INCDEC_EXPRESSION_THROWS_WITH(R"(--"foo";)", error_message("string")); + } + + SECTION("undefined pre ++ operator") + { + auto error_message = [](std::string type_name) { + return std::string{R"(undefined increment/decrement operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_INCDEC_EXPRESSION_THROWS_WITH(R"(++true;)", error_message("B")); + } + + SECTION("undefined post -- operator") + { + auto error_message = [](std::string type_name) { + return std::string{R"(undefined increment/decrement operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_INCDEC_EXPRESSION_THROWS_WITH(R"(true--;)", error_message("B")); + } + + SECTION("undefined post ++ operator") + { + auto error_message = [](std::string type_name) { + return std::string{R"(undefined increment/decrement operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_INCDEC_EXPRESSION_THROWS_WITH(R"("bar"++;)", error_message("string")); + } + } } diff --git a/tests/test_LinearSolver.cpp b/tests/test_LinearSolver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f56bb59ed1d6b7df2e81d8e4076f0ee23bb27ec6 --- /dev/null +++ b/tests/test_LinearSolver.cpp @@ -0,0 +1,527 @@ +#include <catch2/catch.hpp> + +#include <utils/pugs_config.hpp> + +#include <algebra/LinearSolver.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("LinearSolver", "[algebra]") +{ + SECTION("check has library") + { + LinearSolver linear_solver; + + REQUIRE(linear_solver.hasLibrary(LSLibrary::builtin) == true); + +#ifdef PUGS_HAS_PETSC + REQUIRE(linear_solver.hasLibrary(LSLibrary::petsc) == true); +#else // PUGS_HAS_PETSC + REQUIRE(linear_solver.hasLibrary(LSLibrary::petsc) == false); +#endif // PUGS_HAS_PETSC + } + + SECTION("check linear solver building") + { + LinearSolverOptions options; + + SECTION("builtin") + { + SECTION("builtin methods") + { + options.library() = LSLibrary::builtin; + options.precond() = LSPrecond::none; + + options.method() = LSMethod::cg; + REQUIRE_NOTHROW(LinearSolver{options}); + + options.method() = LSMethod::bicgstab; + REQUIRE_NOTHROW(LinearSolver{options}); + + options.method() = LSMethod::bicgstab2; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: BICGStab2 is not a builtin linear solver!"); + + options.method() = LSMethod::lu; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: LU is not a builtin linear solver!"); + + options.method() = LSMethod::choleski; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: Choleski is not a builtin linear solver!"); + + options.method() = LSMethod::gmres; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: GMRES is not a builtin linear solver!"); + } + + SECTION("builtin precond") + { + options.library() = LSLibrary::builtin; + options.method() = LSMethod::cg; + + options.precond() = LSPrecond::none; + REQUIRE_NOTHROW(LinearSolver{options}); + + options.precond() = LSPrecond::diagonal; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: diagonal is not a builtin preconditioner!"); + + options.precond() = LSPrecond::incomplete_LU; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: ILU is not a builtin preconditioner!"); + + options.precond() = LSPrecond::incomplete_choleski; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: ICholeski is not a builtin preconditioner!"); + + options.precond() = LSPrecond::amg; + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: AMG is not a builtin preconditioner!"); + } + } + + SECTION("PETSc") + { + LinearSolverOptions always_valid; + always_valid.library() = LSLibrary::builtin; + always_valid.method() = LSMethod::cg; + always_valid.precond() = LSPrecond::none; + + LinearSolver linear_solver{always_valid}; + + SECTION("PETSc methods") + { + options.library() = LSLibrary::petsc; + options.precond() = LSPrecond::none; + + options.method() = LSMethod::cg; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.method() = LSMethod::bicgstab; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.method() = LSMethod::bicgstab2; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.method() = LSMethod::lu; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.method() = LSMethod::choleski; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.method() = LSMethod::gmres; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + } + + SECTION("builtin precond") + { + options.library() = LSLibrary::petsc; + options.method() = LSMethod::cg; + + options.precond() = LSPrecond::none; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.precond() = LSPrecond::diagonal; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.precond() = LSPrecond::incomplete_LU; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.precond() = LSPrecond::incomplete_choleski; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + + options.precond() = LSPrecond::amg; + REQUIRE_NOTHROW(linear_solver.checkOptions(options)); + } + } + +#ifndef PUGS_HAS_PETSC + SECTION("not linked PETSc") + { + options.library() = LSLibrary::petsc; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::none; + + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: PETSc is not linked to pugs. Cannot use it!"); + } +#endif // PUGS_HAS_PETSC + } + + SECTION("check linear solvers") + { + SECTION("symmetric system") + { + SparseMatrixDescriptor S{5}; + S(0, 0) = 2; + S(0, 1) = -1; + + S(1, 0) = -1; + S(1, 1) = 2; + S(1, 2) = -1; + + S(2, 1) = -1; + S(2, 2) = 2; + S(2, 3) = -1; + + S(3, 2) = -1; + S(3, 3) = 2; + S(3, 4) = -1; + + S(4, 3) = -1; + S(4, 4) = 2; + + CRSMatrix A{S}; + + Vector<const double> x_exact = [] { + Vector<double> y{5}; + y[0] = 1; + y[1] = 3; + y[2] = 2; + y[3] = 4; + y[4] = 5; + return y; + }(); + + Vector<double> b = A * x_exact; + + SECTION("builtin") + { + SECTION("CG no preconditioner") + { + LinearSolverOptions options; + options.library() = LSLibrary::builtin; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("PETSc") + { +#ifdef PUGS_HAS_PETSC + + SECTION("CG") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::none; + options.verbose() = true; + + SECTION("CG no preconditioner") + { + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("CG Diagonal") + { + options.precond() = LSPrecond::diagonal; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("CG ICholeski") + { + options.precond() = LSPrecond::incomplete_choleski; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("CG AMG") + { + options.precond() = LSPrecond::amg; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("Choleski") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::choleski; + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + +#else // PUGS_HAS_PETSC + SECTION("PETSc not linked") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::none; + + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: PETSc is not linked to pugs. Cannot use it!"); + } +#endif // PUGS_HAS_PETSC + } + } + + SECTION("none symmetric system") + { + SparseMatrixDescriptor S{5}; + S(0, 0) = 2; + S(0, 1) = -1; + + S(1, 0) = -0.2; + S(1, 1) = 2; + S(1, 2) = -1; + + S(2, 1) = -1; + S(2, 2) = 4; + S(2, 3) = -2; + + S(3, 2) = -1; + S(3, 3) = 2; + S(3, 4) = -0.1; + + S(4, 3) = 1; + S(4, 4) = 3; + + CRSMatrix A{S}; + + Vector<const double> x_exact = [] { + Vector<double> y{5}; + y[0] = 1; + y[1] = 3; + y[2] = 2; + y[3] = 4; + y[4] = 5; + return y; + }(); + + Vector<double> b = A * x_exact; + + SECTION("builtin") + { + SECTION("BICGStab no preconditioner") + { + LinearSolverOptions options; + options.library() = LSLibrary::builtin; + options.method() = LSMethod::bicgstab; + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("PETSc") + { +#ifdef PUGS_HAS_PETSC + + SECTION("BICGStab") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::bicgstab; + options.precond() = LSPrecond::none; + options.verbose() = true; + + SECTION("BICGStab no preconditioner") + { + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("BICGStab Diagonal") + { + options.precond() = LSPrecond::diagonal; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("BICGStab ILU") + { + options.precond() = LSPrecond::incomplete_LU; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("BICGStab2") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::bicgstab2; + options.precond() = LSPrecond::none; + + SECTION("BICGStab2 no preconditioner") + { + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("BICGStab2 Diagonal") + { + options.precond() = LSPrecond::diagonal; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("GMRES") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::gmres; + options.precond() = LSPrecond::none; + + SECTION("GMRES no preconditioner") + { + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("GMRES Diagonal") + { + options.precond() = LSPrecond::diagonal; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + + SECTION("GMRES ILU") + { + options.precond() = LSPrecond::incomplete_LU; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + } + + SECTION("LU") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::lu; + options.precond() = LSPrecond::none; + + Vector<double> x{5}; + x = 0; + + LinearSolver solver{options}; + + solver.solveLocalSystem(A, x, b); + Vector error = x - x_exact; + REQUIRE(std::sqrt((error, error)) < 1E-10 * std::sqrt((x_exact, x_exact))); + } + +#else // PUGS_HAS_PETSC + SECTION("PETSc not linked") + { + LinearSolverOptions options; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::none; + + REQUIRE_THROWS_WITH(LinearSolver{options}, "error: PETSc is not linked to pugs. Cannot use it!"); + } +#endif // PUGS_HAS_PETSC + } + } + } +} diff --git a/tests/test_LinearSolverOptions.cpp b/tests/test_LinearSolverOptions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..494c36760d7b48c6e40c569423a100fed669d0d6 --- /dev/null +++ b/tests/test_LinearSolverOptions.cpp @@ -0,0 +1,177 @@ +#include <catch2/catch.hpp> + +#include <algebra/LinearSolverOptions.hpp> + +#include <sstream> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("LinearSolverOptions", "[algebra]") +{ + SECTION("print options") + { + SECTION("first set") + { + LinearSolverOptions options; + options.verbose() = true; + options.library() = LSLibrary::builtin; + options.method() = LSMethod::cg; + options.precond() = LSPrecond::incomplete_choleski; + options.epsilon() = 1E-3; + options.maximumIteration() = 100; + + std::stringstream os; + os << '\n' << options; + + std::stringstream expected_output; + expected_output << R"( + library: builtin + method : CG + precond: ICholeski + epsilon: )" << 1E-3 << R"( + maxiter: 100 + verbose: true +)"; + + REQUIRE(os.str() == expected_output.str()); + } + + SECTION("second set") + { + LinearSolverOptions options; + + options.verbose() = false; + options.library() = LSLibrary::petsc; + options.method() = LSMethod::gmres; + options.precond() = LSPrecond::incomplete_LU; + options.epsilon() = 1E-6; + options.maximumIteration() = 200; + + std::stringstream os; + os << '\n' << options; + + std::stringstream expected_output; + expected_output << R"( + library: PETSc + method : GMRES + precond: ILU + epsilon: )" << 1E-6 << R"( + maxiter: 200 + verbose: false +)"; + + REQUIRE(os.str() == expected_output.str()); + } + } + + SECTION("library name") + { + REQUIRE(name(LSLibrary::builtin) == "builtin"); + REQUIRE(name(LSLibrary::petsc) == "PETSc"); + REQUIRE_THROWS_WITH(name(LSLibrary::LS__end), "unexpected error: Linear system library name is not defined!"); + } + + SECTION("method name") + { + REQUIRE(name(LSMethod::cg) == "CG"); + REQUIRE(name(LSMethod::bicgstab) == "BICGStab"); + REQUIRE(name(LSMethod::bicgstab2) == "BICGStab2"); + REQUIRE(name(LSMethod::gmres) == "GMRES"); + REQUIRE(name(LSMethod::lu) == "LU"); + REQUIRE(name(LSMethod::choleski) == "Choleski"); + REQUIRE_THROWS_WITH(name(LSMethod::LS__end), "unexpected error: Linear system method name is not defined!"); + } + + SECTION("precond name") + { + REQUIRE(name(LSPrecond::none) == "none"); + REQUIRE(name(LSPrecond::diagonal) == "diagonal"); + REQUIRE(name(LSPrecond::incomplete_choleski) == "ICholeski"); + REQUIRE(name(LSPrecond::incomplete_LU) == "ILU"); + REQUIRE(name(LSPrecond::amg) == "AMG"); + REQUIRE_THROWS_WITH(name(LSPrecond::LS__end), + "unexpected error: Linear system preconditioner name is not defined!"); + } + + SECTION("library from name") + { + REQUIRE(LSLibrary::builtin == getLSEnumFromName<LSLibrary>("builtin")); + REQUIRE(LSLibrary::petsc == getLSEnumFromName<LSLibrary>("PETSc")); + + REQUIRE_THROWS_WITH(getLSEnumFromName<LSLibrary>("__invalid_lib"), + "error: could not find '__invalid_lib' associate type!"); + } + + SECTION("method from name") + { + REQUIRE(LSMethod::cg == getLSEnumFromName<LSMethod>("CG")); + REQUIRE(LSMethod::bicgstab == getLSEnumFromName<LSMethod>("BICGStab")); + REQUIRE(LSMethod::bicgstab2 == getLSEnumFromName<LSMethod>("BICGStab2")); + REQUIRE(LSMethod::lu == getLSEnumFromName<LSMethod>("LU")); + REQUIRE(LSMethod::choleski == getLSEnumFromName<LSMethod>("Choleski")); + REQUIRE(LSMethod::gmres == getLSEnumFromName<LSMethod>("GMRES")); + + REQUIRE_THROWS_WITH(getLSEnumFromName<LSMethod>("__invalid_method"), + "error: could not find '__invalid_method' associate type!"); + } + + SECTION("precond from name") + { + REQUIRE(LSPrecond::none == getLSEnumFromName<LSPrecond>("none")); + REQUIRE(LSPrecond::diagonal == getLSEnumFromName<LSPrecond>("diagonal")); + REQUIRE(LSPrecond::incomplete_choleski == getLSEnumFromName<LSPrecond>("ICholeski")); + REQUIRE(LSPrecond::incomplete_LU == getLSEnumFromName<LSPrecond>("ILU")); + + REQUIRE_THROWS_WITH(getLSEnumFromName<LSPrecond>("__invalid_precond"), + "error: could not find '__invalid_precond' associate type!"); + } + + SECTION("library list") + { + std::stringstream os; + os << '\n'; + printLSEnumListNames<LSLibrary>(os); + + const std::string library_list = R"( + - builtin + - PETSc +)"; + + REQUIRE(os.str() == library_list); + } + + SECTION("method list") + { + std::stringstream os; + os << '\n'; + printLSEnumListNames<LSMethod>(os); + + const std::string library_list = R"( + - CG + - BICGStab + - BICGStab2 + - GMRES + - LU + - Choleski +)"; + + REQUIRE(os.str() == library_list); + } + + SECTION("precond list") + { + std::stringstream os; + os << '\n'; + printLSEnumListNames<LSPrecond>(os); + + const std::string library_list = R"( + - none + - diagonal + - ICholeski + - ILU + - AMG +)"; + + REQUIRE(os.str() == library_list); + } +} diff --git a/tests/test_ListAffectationProcessor.cpp b/tests/test_ListAffectationProcessor.cpp index e0f17e7f593f98a15457686e9e38b0bccfa34938..0dc1b8d77c0e6dc23cb2681c30e8006d9e939f86 100644 --- a/tests/test_ListAffectationProcessor.cpp +++ b/tests/test_ListAffectationProcessor.cpp @@ -75,12 +75,16 @@ TEST_CASE("ListAffectationProcessor", "[language]") { SECTION("ListAffectations") { - SECTION("R*R^2*string") + SECTION("R*R^2*R^2x2*string") { - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "x", double{1.2}); - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "u", + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "x", + double{1.2}); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "u", (TinyVector<2>{2, 3})); - CHECK_AFFECTATION_RESULT(R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)", "s", std::string{"foo"}); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "A", + (TinyMatrix<2>{4, 3, 2, 1})); + CHECK_AFFECTATION_RESULT(R"(let (x,u,A,s): R*R^2*R^2x2*string, (x,u,A,s) = (1.2, (2,3), (4,3,2,1), "foo");)", "s", + std::string{"foo"}); } SECTION("compound with string conversion") @@ -89,19 +93,19 @@ TEST_CASE("ListAffectationProcessor", "[language]") std::to_string(double{3})); { std::ostringstream os; - os << TinyVector<1>{7} << std::ends; + os << TinyVector<1>{7}; CHECK_AFFECTATION_RESULT(R"(let v:R^1, v = 7; let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s", os.str()); } { std::ostringstream os; - os << TinyVector<2>{6, 3} << std::ends; + os << TinyVector<2>{6, 3}; CHECK_AFFECTATION_RESULT(R"(let v: R^2, v = (6,3); let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s", os.str()); } { std::ostringstream os; - os << TinyVector<3>{1, 2, 3} << std::ends; + os << TinyVector<3>{1, 2, 3}; CHECK_AFFECTATION_RESULT(R"(let v:R^3, v = (1,2,3); let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s", os.str()); } @@ -114,11 +118,28 @@ TEST_CASE("ListAffectationProcessor", "[language]") CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3*R^2*R^1, (x,y,z) = (0,0,0);)", "z", (TinyVector<1>{zero})); } + SECTION("compound R^dxd from '0'") + { + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "x", (TinyMatrix<3>{zero})); + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "y", (TinyMatrix<2>{zero})); + CHECK_AFFECTATION_RESULT(R"(let (x,y,z):R^3x3*R^2x2*R^1x1, (x,y,z) = (0,0,0);)", "z", (TinyMatrix<1>{zero})); + } + SECTION("compound with subscript values") { CHECK_AFFECTATION_RESULT(R"(let x:R^3; (x[0], x[2], x[1]) = (4, 6, 5);)", "x", (TinyVector<3>{4, 5, 6})); CHECK_AFFECTATION_RESULT(R"(let x:R^2; (x[1], x[0]) = (3, 6);)", "x", (TinyVector<2>{6, 3})); CHECK_AFFECTATION_RESULT(R"(let x:R^1; let y:R; (y, x[0]) = (4, 2.3);)", "x", (TinyVector<1>{2.3})); } + + SECTION("compound with subscript values") + { + CHECK_AFFECTATION_RESULT( + R"(let x:R^3x3; (x[0,0], x[1,0], x[1,2], x[2,0], x[0,1], x[0,2], x[1,1], x[2,1], x[2,2]) = (1, 4, 6, 7, 2, 3, 5, 8, 9);)", + "x", (TinyMatrix<3>{1, 2, 3, 4, 5, 6, 7, 8, 9})); + CHECK_AFFECTATION_RESULT(R"(let x:R^2x2; (x[1,1], x[0,0], x[1,0], x[0,1]) = (3, 6, 2, 4);)", "x", + (TinyMatrix<2>{6, 4, 2, 3})); + CHECK_AFFECTATION_RESULT(R"(let x:R^1x1; let y:R; (y, x[0,0]) = (4, 2.3);)", "x", (TinyMatrix<1>{2.3})); + } } } diff --git a/tests/mpi_test_Messenger.cpp b/tests/test_Messenger.cpp similarity index 98% rename from tests/mpi_test_Messenger.cpp rename to tests/test_Messenger.cpp index 19d89851c82020c0f10e4eaeac919595059e8fab..022a5c07e4b16db81dc69ec236a1482d1138d6d1 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/test_Messenger.cpp @@ -471,4 +471,11 @@ TEST_CASE("Messenger", "[mpi]") std::remove("barrier_test"); } + + SECTION("errors") + { + int argc = 0; + char** argv = nullptr; + REQUIRE_THROWS_WITH((parallel::Messenger::create(argc, argv)), "unexpected error: Messenger already created"); + } } diff --git a/tests/test_ParseError.cpp b/tests/test_ParseError.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fb8eb5dca61f1098592eafd1d84d51f2e953bc0 --- /dev/null +++ b/tests/test_ParseError.cpp @@ -0,0 +1,39 @@ +#include <catch2/catch.hpp> + +#include <language/utils/ParseError.hpp> + +#include <string> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("ParseError", "[language]") +{ + SECTION("single position") + { + const std::string source = R"( +a first line +a second line +)"; + TAO_PEGTL_NAMESPACE::internal::iterator i(&source[0], 3, 1, 2); + TAO_PEGTL_NAMESPACE::position p{i, source}; + ParseError parse_error("error message", p); + REQUIRE(parse_error.positions() == std::vector{p}); + REQUIRE(parse_error.what() == std::string{"error message"}); + } + + SECTION("position list") + { + const std::string source = R"( +a first line +a second line +)"; + TAO_PEGTL_NAMESPACE::internal::iterator i0(&source[0], 3, 1, 2); + TAO_PEGTL_NAMESPACE::position p0{i0, source}; + TAO_PEGTL_NAMESPACE::internal::iterator i1(&source[0], 4, 1, 3); + TAO_PEGTL_NAMESPACE::position p1{i1, source}; + + ParseError parse_error("error message", std::vector{p0, p1}); + REQUIRE(parse_error.positions() == std::vector{p0, p1}); + REQUIRE(parse_error.what() == std::string{"error message"}); + } +} diff --git a/tests/test_Partitioner.cpp b/tests/test_Partitioner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9af1fc6c6f24be33aa5363696f36cab5909e6e7d --- /dev/null +++ b/tests/test_Partitioner.cpp @@ -0,0 +1,86 @@ +#include <catch2/catch.hpp> + +#include <utils/Messenger.hpp> +#include <utils/Partitioner.hpp> + +#include <set> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("Partitioner", "[utils]") +{ + SECTION("one graph split to all") + { + Partitioner partitioner; + + std::vector<int> entries_vector; + std::vector<int> neighbors_vector; + + entries_vector.push_back(neighbors_vector.size()); + + if (parallel::rank() == 0) { + neighbors_vector.push_back(1); + neighbors_vector.push_back(2); + neighbors_vector.push_back(4); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(0); + neighbors_vector.push_back(3); + neighbors_vector.push_back(5); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(0); + neighbors_vector.push_back(2); + neighbors_vector.push_back(5); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(0); + neighbors_vector.push_back(2); + neighbors_vector.push_back(5); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(3); + neighbors_vector.push_back(5); + neighbors_vector.push_back(7); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(3); + neighbors_vector.push_back(5); + neighbors_vector.push_back(6); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(5); + neighbors_vector.push_back(6); + neighbors_vector.push_back(7); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(3); + neighbors_vector.push_back(5); + neighbors_vector.push_back(7); + entries_vector.push_back(neighbors_vector.size()); + + neighbors_vector.push_back(4); + neighbors_vector.push_back(6); + neighbors_vector.push_back(7); + entries_vector.push_back(neighbors_vector.size()); + } + + Array<int> entries = convert_to_array(entries_vector); + Array<int> neighbors = convert_to_array(neighbors_vector); + + CRSGraph graph{entries, neighbors}; + + Array<int> partitioned = partitioner.partition(graph); + + REQUIRE((partitioned.size() + 1) == entries.size()); + + if (parallel::rank() == 0) { + std::set<int> assigned_ranks; + for (size_t i = 0; i < partitioned.size(); ++i) { + assigned_ranks.insert(partitioned[i]); + } + + REQUIRE(assigned_ranks.size() == parallel::size()); + } + } +} diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp index 45229f8b599a1517fcd4c2182de563601643250d..4b04c7f5e088c1bdec314c91414fac4924b8bd6b 100644 --- a/tests/test_PugsFunctionAdapter.cpp +++ b/tests/test_PugsFunctionAdapter.cpp @@ -77,11 +77,17 @@ let ZplusZ: Z*Z -> Z, (x,y) -> x+y; let RplusR: R*R -> R, (x,y) -> x+y; let RRtoR2: R*R -> R^2, (x,y) -> (x+y, x-y); let R3times2: R^3 -> R^3, x -> 2*x; +let R33times2: R^3x3 -> R^3x3, x -> 2*x; let BtoR1: B -> R^1, b -> not b; +let BtoR11: B -> R^1x1, b -> not b; let NtoR1: N -> R^1, n -> n*n; +let NtoR11: N -> R^1x1, n -> n*n; let ZtoR1: Z -> R^1, z -> -z; +let ZtoR11: Z -> R^1x1, z -> -z; let RtoR1: R -> R^1, x -> x*x; +let RtoR11: R -> R^1x1, x -> x*x; let R3toR3zero: R^3 -> R^3, x -> 0; +let R33toR33zero: R^3x3 -> R^3x3, x -> 0; )"; string_input input{data, "test.pgs"}; @@ -194,6 +200,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == 2 * x); } + { + auto [i_symbol, found] = symbol_table->find("R33times2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{2, 3, 4, 1, 6, 5, 9, 7, 8}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == 2 * x); + } + { auto [i_symbol, found] = symbol_table->find("BtoR1", position); REQUIRE(found); @@ -218,6 +237,30 @@ let R3toR3zero: R^3 -> R^3, x -> 0; } } + { + auto [i_symbol, found] = symbol_table->find("BtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + { + const bool b = true; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + + { + const bool b = false; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + } + { auto [i_symbol, found] = symbol_table->find("NtoR1", position); REQUIRE(found); @@ -231,6 +274,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == n * n); } + { + auto [i_symbol, found] = symbol_table->find("NtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const uint64_t n = 4; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(uint64_t)>::one_arg(function_symbol_id, n); + + REQUIRE(result == n * n); + } + { auto [i_symbol, found] = symbol_table->find("ZtoR1", position); REQUIRE(found); @@ -244,6 +300,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == -z); } + { + auto [i_symbol, found] = symbol_table->find("ZtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const int64_t z = 3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(int64_t)>::one_arg(function_symbol_id, z); + + REQUIRE(result == -z); + } + { auto [i_symbol, found] = symbol_table->find("RtoR1", position); REQUIRE(found); @@ -257,6 +326,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == x * x); } + { + auto [i_symbol, found] = symbol_table->find("RtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 3.3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(double)>::one_arg(function_symbol_id, x); + + REQUIRE(result == x * x); + } + { auto [i_symbol, found] = symbol_table->find("R3toR3zero", position); REQUIRE(found); @@ -269,17 +351,31 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == TinyVector<3>{0, 0, 0}); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR33zero", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{1, 0, 0, 0, 1, 0, 0, 0, 1}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == TinyMatrix<3>{0, 0, 0, 0, 0, 0, 0, 0, 0}); + } } SECTION("Errors calls") { std::string_view data = R"( let R1toR1: R^1 -> R^1, x -> x; -let R3toR3: R^3 -> R^3, x -> 1; +let R3toR3: R^3 -> R^3, x -> 0; let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); let RtoNS: R -> N*string, x -> (1, "foo"); let RtoR: R -> R, x -> 2*x; +let R33toR22: R^3x3 -> R^2x2, x -> (x[0,0], x[0,1]+x[0,2], x[2,0]*x[1,1], x[2,1]+x[2,2]); )"; string_input input{data, "test.pgs"}; @@ -322,9 +418,9 @@ let RtoR: R -> R, x -> 2*x; const TinyVector<3> x{2, 1, 3}; FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); - REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x), + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<2>(TinyVector<3>)>::one_arg(function_symbol_id, x), "error: invalid function type\n" - "note: expecting R^3 -> R^3\n" + "note: expecting R^3 -> R^2\n" "note: provided function R3toR3: R^3 -> R^3"); } @@ -385,5 +481,19 @@ let RtoR: R -> R, x -> 2*x; "note: expecting R -> R^3\n" "note: provided function RtoR: R -> R"); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR22", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<double(double)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R -> R\n" + "note: provided function R33toR22: R^3x3 -> R^2x2"); + } } } diff --git a/tests/test_PugsUtils.cpp b/tests/test_PugsUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b966ae8f0c0bab490c9f7e3e4441ffdd329aeec1 --- /dev/null +++ b/tests/test_PugsUtils.cpp @@ -0,0 +1,110 @@ +#include <catch2/catch.hpp> + +#include <utils/BuildInfo.hpp> +#include <utils/PugsUtils.hpp> +#include <utils/RevisionInfo.hpp> +#include <utils/pugs_build_info.hpp> + +#include <rang.hpp> + +#include <string> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("PugsUtils", "[utils]") +{ + SECTION("checking infos") + { + const std::string pugs_version = [] { + std::stringstream os; + + os << "pugs version: " << rang::style::bold << RevisionInfo::version() << rang::style::reset << '\n'; + + os << "-------------------- " << rang::fg::green << "git info" << rang::fg::reset << " -------------------------" + << '\n'; + os << "tag: " << rang::style::bold << RevisionInfo::gitTag() << rang::style::reset << '\n'; + os << "HEAD: " << rang::style::bold << RevisionInfo::gitHead() << rang::style::reset << '\n'; + os << "hash: " << rang::style::bold << RevisionInfo::gitHash() << rang::style::reset << " ("; + + if (RevisionInfo::gitIsClean()) { + os << rang::fgB::green << "clean" << rang::fg::reset; + } else { + os << rang::fgB::red << "dirty" << rang::fg::reset; + } + os << ")\n"; + os << "-------------------------------------------------------"; + + return os.str(); + }(); + + REQUIRE(pugsVersion() == pugs_version); + + const std::string build_info = [] { + std::ostringstream os; + + os << "-------------------- " << rang::fg::green << "build info" << rang::fg::reset << " -----------------------" + << '\n'; + os << "type: " << rang::style::bold << BuildInfo::type() << rang::style::reset << '\n'; + os << "compiler: " << rang::style::bold << BuildInfo::compiler() << rang::style::reset << '\n'; + os << "kokkos: " << rang::style::bold << BuildInfo::kokkosDevices() << rang::style::reset << '\n'; + os << "MPI: " << rang::style::bold << BuildInfo::mpiLibrary() << rang::style::reset << '\n'; + os << "PETSc: " << rang::style::bold << BuildInfo::petscLibrary() << rang::style::reset << '\n'; + os << "-------------------------------------------------------"; + + return os.str(); + }(); + + REQUIRE(pugsBuildInfo() == build_info); + } + + SECTION("checking OMP environment setting") + { + if constexpr (std::string_view{PUGS_BUILD_KOKKOS_DEVICES} == std::string_view{"OpenMP"}) { + const std::string saved_omp_proc_bind = []() { + char* value = getenv("OMP_PROC_BIND"); + if (value != nullptr) { + return std::string{value}; + } else { + return std::string{}; + } + }(); + + const std::string saved_omp_places = []() { + char* value = getenv("OMP_PLACES"); + if (value != nullptr) { + return std::string{value}; + } else { + return std::string{}; + } + }(); + + unsetenv("OMP_PROC_BIND"); + unsetenv("OMP_PLACES"); + + setDefaultOMPEnvironment(); + REQUIRE(std::string{getenv("OMP_PROC_BIND")} == std::string{"spread"}); + REQUIRE(std::string{getenv("OMP_PLACES")} == std::string{"threads"}); + + unsetenv("OMP_PROC_BIND"); + unsetenv("OMP_PLACES"); + + setenv("OMP_PROC_BIND", "foo", 1); + setenv("OMP_PLACES", "bar", 1); + + setDefaultOMPEnvironment(); + REQUIRE(std::string{getenv("OMP_PROC_BIND")} == std::string{"foo"}); + REQUIRE(std::string{getenv("OMP_PLACES")} == std::string{"bar"}); + + unsetenv("OMP_PROC_BIND"); + unsetenv("OMP_PLACES"); + + if (saved_omp_proc_bind.size() != 0) { + setenv("OMP_PROC_BIND", saved_omp_proc_bind.c_str(), 1); + } + + if (saved_omp_places.size() != 0) { + setenv("OMP_PLACES", saved_omp_places.c_str(), 1); + } + } + } +} diff --git a/tests/test_SparseMatrixDescriptor.cpp b/tests/test_SparseMatrixDescriptor.cpp index cd66380b1d1beb6bcb9ba1a9c0d9c495947a3235..8616a41da0ff8a60857ef314c74e04becf6c651f 100644 --- a/tests/test_SparseMatrixDescriptor.cpp +++ b/tests/test_SparseMatrixDescriptor.cpp @@ -68,18 +68,18 @@ TEST_CASE("SparseMatrixDescriptor", "[algebra]") S(4, 1) = 1; S(4, 4) = -2; - REQUIRE(S.row(0).numberOfValues() == 1); + REQUIRE(S.row(0).numberOfValues() == 2); REQUIRE(S.row(1).numberOfValues() == 2); REQUIRE(S.row(2).numberOfValues() == 1); - REQUIRE(S.row(3).numberOfValues() == 1); + REQUIRE(S.row(3).numberOfValues() == 2); REQUIRE(S.row(4).numberOfValues() == 2); const auto& const_S = S; - REQUIRE(const_S.row(0).numberOfValues() == 1); + REQUIRE(const_S.row(0).numberOfValues() == 2); REQUIRE(const_S.row(1).numberOfValues() == 2); REQUIRE(const_S.row(2).numberOfValues() == 1); - REQUIRE(const_S.row(3).numberOfValues() == 1); + REQUIRE(const_S.row(3).numberOfValues() == 2); REQUIRE(const_S.row(4).numberOfValues() == 2); #ifndef NDEBUG @@ -126,13 +126,14 @@ TEST_CASE("SparseMatrixDescriptor", "[algebra]") const auto graph = S.graphVector(); REQUIRE(graph.size() == S.numberOfRows()); - REQUIRE(graph[0].size() == 1); + REQUIRE(graph[0].size() == 2); REQUIRE(graph[1].size() == 2); REQUIRE(graph[2].size() == 1); REQUIRE(graph[3].size() == 2); REQUIRE(graph[4].size() == 2); - REQUIRE(graph[0][0] == 2); + REQUIRE(graph[0][0] == 0); + REQUIRE(graph[0][1] == 2); REQUIRE(graph[1][0] == 1); REQUIRE(graph[1][1] == 2); REQUIRE(graph[2][0] == 2); @@ -157,16 +158,17 @@ TEST_CASE("SparseMatrixDescriptor", "[algebra]") const auto value_array = S.valueArray(); - REQUIRE(value_array.size() == 8); - - REQUIRE(value_array[0] == 5); - REQUIRE(value_array[1] == 1); - REQUIRE(value_array[2] == 11); - REQUIRE(value_array[3] == 4); - REQUIRE(value_array[4] == -3); - REQUIRE(value_array[5] == 5); - REQUIRE(value_array[6] == 1); - REQUIRE(value_array[7] == -2); + REQUIRE(value_array.size() == 9); + + REQUIRE(value_array[0] == 0); + REQUIRE(value_array[1] == 5); + REQUIRE(value_array[2] == 1); + REQUIRE(value_array[3] == 11); + REQUIRE(value_array[4] == 4); + REQUIRE(value_array[5] == -3); + REQUIRE(value_array[6] == 5); + REQUIRE(value_array[7] == 1); + REQUIRE(value_array[8] == -2); } SECTION("output") @@ -186,7 +188,7 @@ TEST_CASE("SparseMatrixDescriptor", "[algebra]") output << '\n' << S; std::string expected_output = R"( -0 | 2:5 +0 | 0:0 2:5 1 | 1:1 2:11 2 | 2:4 3 | 1:-3 3:5 diff --git a/tests/test_SymbolTable.cpp b/tests/test_SymbolTable.cpp index bb80c4999781e59f3585f1ff64eba3ac3b71bcac..fed694f752f4b87ce79a88ae97547800bb4d35c2 100644 --- a/tests/test_SymbolTable.cpp +++ b/tests/test_SymbolTable.cpp @@ -67,8 +67,8 @@ TEST_CASE("SymbolTable", "[language]") attributes_a.setIsInitialized(); REQUIRE(attributes_a.isInitialized()); - attributes_a.setDataType(ASTNodeDataType::double_t); - REQUIRE(attributes_a.dataType() == ASTNodeDataType::double_t); + attributes_a.setDataType(ASTNodeDataType::build<ASTNodeDataType::double_t>()); + REQUIRE(attributes_a.dataType() == ASTNodeDataType::build<ASTNodeDataType::double_t>()); attributes_a.value() = 2.3; @@ -166,8 +166,8 @@ TEST_CASE("SymbolTable", "[language]") attributes_a.setIsInitialized(); REQUIRE(attributes_a.isInitialized()); - attributes_a.setDataType(ASTNodeDataType::function_t); - REQUIRE(attributes_a.dataType() == ASTNodeDataType::function_t); + attributes_a.setDataType(ASTNodeDataType::build<ASTNodeDataType::function_t>()); + REQUIRE(attributes_a.dataType() == ASTNodeDataType::build<ASTNodeDataType::function_t>()); attributes_a.value() = static_cast<uint64_t>(2); diff --git a/tests/test_Timer.cpp b/tests/test_Timer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf85f15653eec2ba4bd6914873e6305cedaf4829 --- /dev/null +++ b/tests/test_Timer.cpp @@ -0,0 +1,82 @@ +#include <catch2/catch.hpp> + +#include <utils/Timer.hpp> + +#include <chrono> +#include <sstream> +#include <thread> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("Timer", "[utils]") +{ + SECTION("auto start") + { + Timer t; + + REQUIRE(t.status() == Timer::Status::running); + + double seconds = t.seconds(); + std::this_thread::sleep_for(std::chrono::microseconds(5)); + REQUIRE(t.seconds() > seconds); + + t.start(); + seconds = t.seconds(); + std::this_thread::sleep_for(std::chrono::microseconds(5)); + REQUIRE(t.status() == Timer::Status::running); + + REQUIRE(t.seconds() > seconds); + } + + SECTION("pause/start") + { + Timer t1; + + REQUIRE(t1.status() == Timer::Status::running); + t1.pause(); + + const double seconds = t1.seconds(); + REQUIRE(t1.status() == Timer::Status::paused); + std::this_thread::sleep_for(std::chrono::microseconds(5)); + REQUIRE(t1.seconds() == seconds); + + std::stringstream os1; + os1 << t1; + + Timer t2 = t1; + std::stringstream os2; + os2 << t2.seconds() << 's'; + + REQUIRE(os1.str() == os2.str()); + + REQUIRE(t1.seconds() == t2.seconds()); + t1.start(); + std::this_thread::sleep_for(std::chrono::microseconds(5)); + REQUIRE(t1.status() == Timer::Status::running); + REQUIRE(t1.seconds() > t2.seconds()); + + t2.reset(); + REQUIRE(t2.status() == Timer::Status::paused); + REQUIRE(t2.seconds() == 0); + } + + SECTION("stop/start") + { + Timer t; + REQUIRE(t.status() == Timer::Status::running); + + std::this_thread::sleep_for(std::chrono::microseconds(5)); + const double seconds = t.seconds(); + + REQUIRE(seconds > 0); + + t.stop(); + REQUIRE(t.status() == Timer::Status::stopped); + REQUIRE(t.seconds() == 0); + + t.start(); + std::this_thread::sleep_for(std::chrono::microseconds(5)); + REQUIRE(t.status() == Timer::Status::running); + REQUIRE(t.seconds() > 0); + } +} diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp index e3245e79612532f0919d142fe0a349ce69828b4d..fb113626feb581a2bac7f65460eb3f3262cb9a50 100644 --- a/tests/test_TinyMatrix.cpp +++ b/tests/test_TinyMatrix.cpp @@ -161,9 +161,12 @@ TEST_CASE("TinyMatrix", "[algebra]") { REQUIRE(det(TinyMatrix<1, int>(6)) == 6); REQUIRE(det(TinyMatrix<2, int>(3, 1, -3, 6)) == 21); + REQUIRE(det(TinyMatrix<3, int>(1, 1, 1, 1, 2, 1, 2, 1, 3)) == 1); REQUIRE(det(B) == -1444); REQUIRE(det(TinyMatrix<4, double>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5)) == Approx(6661.455).epsilon(1E-14)); + + REQUIRE(det(TinyMatrix<4, double>(1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 0, 0, 2, 2)) == 0); } SECTION("checking for inverse calculations") @@ -220,10 +223,25 @@ TEST_CASE("TinyMatrix", "[algebra]") } } + SECTION("checking for sizes") + { + REQUIRE(TinyMatrix<1>{}.nbRows() == 1); + REQUIRE(TinyMatrix<1>{}.nbColumns() == 1); + REQUIRE(TinyMatrix<1>{}.dimension() == 1); + + REQUIRE(TinyMatrix<2>{}.nbRows() == 2); + REQUIRE(TinyMatrix<2>{}.nbColumns() == 2); + REQUIRE(TinyMatrix<2>{}.dimension() == 4); + + REQUIRE(TinyMatrix<3>{}.nbRows() == 3); + REQUIRE(TinyMatrix<3>{}.nbColumns() == 3); + REQUIRE(TinyMatrix<3>{}.dimension() == 9); + } + SECTION("checking for matrices output") { REQUIRE(Catch::Detail::stringify(A) == "[(1,2,3)(4,5,6)(7,8,9)]"); - REQUIRE(Catch::Detail::stringify(TinyMatrix<1, int>(7)) == "7"); + REQUIRE(Catch::Detail::stringify(TinyMatrix<1, int>(7)) == "[(7)]"); } #ifndef NDEBUG diff --git a/tests/test_UnaryExpressionProcessor.cpp b/tests/test_UnaryExpressionProcessor.cpp index e933902adaefd968cc3036c8bbf007b29d054f39..b309d6dd13c0791a909ab30e164325ac7d365e00 100644 --- a/tests/test_UnaryExpressionProcessor.cpp +++ b/tests/test_UnaryExpressionProcessor.cpp @@ -68,12 +68,29 @@ TEST_CASE("UnaryExpressionProcessor", "[language]") SECTION("errors") { - SECTION("bad implicit conversions") + SECTION("undefined not operator") { - CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(let n:N, n = 0; not n;)", "invalid implicit conversion: N -> B"); - CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not 1;)", "invalid implicit conversion: Z -> B"); - CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not 1.3;)", "invalid implicit conversion: R -> B"); - CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not "foo";)", "invalid implicit conversion: string -> B"); + auto error_message = [](std::string type_name) { + return std::string{R"(undefined unary operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(let n:N, n = 0; not n;)", error_message("N")); + CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not 1;)", error_message("Z")); + CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not 1.3;)", error_message("R")); + CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(not "foo";)", error_message("string")); + } + + SECTION("undefined unary minus operator") + { + auto error_message = [](std::string type_name) { + return std::string{R"(undefined unary operator +note: unexpected operand type )"} + + type_name; + }; + + CHECK_UNARY_EXPRESSION_THROWS_WITH(R"(-"foo";)", error_message("string")); } } } diff --git a/tests/test_main.cpp b/tests/test_main.cpp index eba6d05d0ea0cc29b222ec92f4815a3820118c20..b2703a5950eeabe8eea33ba61cab4fc2e2f9157a 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -3,21 +3,61 @@ #include <Kokkos_Core.hpp> +#include <algebra/PETScWrapper.hpp> +#include <language/utils/OperatorRepository.hpp> +#include <mesh/DiamondDualConnectivityManager.hpp> +#include <mesh/DiamondDualMeshManager.hpp> +#include <mesh/MeshDataManager.hpp> +#include <mesh/SynchronizerManager.hpp> +#include <utils/Messenger.hpp> + +#include <MeshDataBaseForTests.hpp> + int main(int argc, char* argv[]) { + parallel::Messenger::create(argc, argv); Kokkos::initialize({4, -1, -1, true}); + PETScWrapper::initialize(argc, argv); + Catch::Session session; int result = session.applyCommandLine(argc, argv); if (result == 0) { - // Disable outputs from tested classes to the standard output - std::cout.setstate(std::ios::badbit); - result = session.run(); + const auto& config = session.config(); + if (config.listReporters() or config.listTags() or config.listTestNamesOnly() or config.listTests()) { + result = session.run(); + } else { + // Disable outputs from tested classes to the standard output + std::cout.setstate(std::ios::badbit); + + SynchronizerManager::create(); + MeshDataManager::create(); + DiamondDualConnectivityManager::create(); + DiamondDualMeshManager::create(); + + MeshDataBaseForTests::create(); + + OperatorRepository::create(); + + result = session.run(); + + OperatorRepository::destroy(); + + MeshDataBaseForTests::destroy(); + + DiamondDualMeshManager::destroy(); + DiamondDualConnectivityManager::destroy(); + MeshDataManager::destroy(); + SynchronizerManager::destroy(); + } } + PETScWrapper::finalize(); + Kokkos::finalize(); + parallel::Messenger::destroy(); return result; }