diff --git a/src/mesh/MeshFlatNodeBoundary.cpp b/src/mesh/MeshFlatNodeBoundary.cpp
index b7de42715eff88510119b54507ed32323365c817..459a827c84ead76ed23acdf28cbd8525923694da 100644
--- a/src/mesh/MeshFlatNodeBoundary.cpp
+++ b/src/mesh/MeshFlatNodeBoundary.cpp
@@ -87,57 +87,108 @@ MeshFlatNodeBoundary<2>::_getNormal(const Mesh<Connectivity<2>>& mesh)
 
 template <>
 TinyVector<3, double>
-MeshFlatNodeBoundary<3>::_getNormal(const Mesh<Connectivity<3>>& mesh)
+MeshFlatNodeBoundary<3>::_getFarestNode(const Mesh<Connectivity<3>>& mesh, const Rd& x0, const Rd& x1)
 {
-  using R3 = TinyVector<3, double>;
-
-  std::array<R3, 6> bounds = this->_getBounds(mesh);
+  const NodeValue<const Rd>& xr = mesh.xr();
+  const auto node_number        = mesh.connectivity().nodeNumber();
 
-  const R3& xmin = bounds[0];
-  const R3& ymin = bounds[1];
-  const R3& zmin = bounds[2];
-  const R3& xmax = bounds[3];
-  const R3& ymax = bounds[4];
-  const R3& zmax = bounds[5];
+  using NodeNumberType = std::remove_const_t<typename decltype(node_number)::data_type>;
 
-  const R3 u = xmax - xmin;
-  const R3 v = ymax - ymin;
-  const R3 w = zmax - zmin;
+  Rd t = x1 - x0;
+  t *= 1. / l2Norm(t);
 
-  const R3 uv        = crossProduct(u, v);
-  const double uv_l2 = dot(uv, uv);
+  double farest_distance       = 0;
+  Rd farest_x                  = zero;
+  NodeNumberType farest_number = std::numeric_limits<NodeNumberType>::max();
 
-  R3 normal        = uv;
-  double normal_l2 = uv_l2;
+  auto node_list = this->m_ref_node_list.list();
 
-  const R3 uw        = crossProduct(u, w);
-  const double uw_l2 = dot(uw, uw);
+  for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
+    const NodeId& node_id = node_list[i_node];
+    const Rd& x           = xr[node_id];
+    const double distance = l2Norm(crossProduct(t, x - x0));
 
-  if (uw_l2 > uv_l2) {
-    normal    = uw;
-    normal_l2 = uw_l2;
+    if ((distance > farest_distance) or ((distance == farest_distance) and (node_number[node_id] < farest_number))) {
+      farest_distance = distance;
+      farest_number   = node_number[node_id];
+      farest_x        = x;
+    }
   }
 
-  const R3 vw        = crossProduct(v, w);
-  const double vw_l2 = dot(vw, vw);
+  if (parallel::size()) {
+    Array<double> farest_distance_array       = parallel::allGather(farest_distance);
+    Array<Rd> farest_x_array                  = parallel::allGather(farest_x);
+    Array<NodeNumberType> farest_number_array = parallel::allGather(farest_number);
 
-  if (vw_l2 > normal_l2) {
-    normal    = vw;
-    normal_l2 = vw_l2;
+    Assert(farest_distance_array.size() == farest_x_array.size());
+    Assert(farest_distance_array.size() == farest_number_array.size());
+
+    for (size_t i = 0; i < farest_distance_array.size(); ++i) {
+      if ((farest_distance_array[i] > farest_distance) or
+          ((farest_distance_array[i] == farest_distance) and (farest_number_array[i] < farest_number))) {
+        farest_distance = farest_distance_array[i];
+        farest_number   = farest_number_array[i];
+        farest_x        = farest_x_array[i];
+      }
+    }
   }
 
-  if (normal_l2 == 0) {
+  return farest_x;
+}
+
+template <>
+TinyVector<3, double>
+MeshFlatNodeBoundary<3>::_getNormal(const Mesh<Connectivity<3>>& mesh)
+{
+  using R3 = TinyVector<3, double>;
+
+  std::array<R3, 2> diagonal = [](const std::array<R3, 6>& bounds) {
+    size_t max_i      = 0;
+    size_t max_j      = 0;
+    double max_length = 0;
+
+    for (size_t i = 0; i < bounds.size(); ++i) {
+      for (size_t j = i + 1; j < bounds.size(); ++j) {
+        double length = l2Norm(bounds[i] - bounds[j]);
+        if (length > max_length) {
+          max_i      = i;
+          max_j      = j;
+          max_length = length;
+        }
+      }
+    }
+
+    return std::array<R3, 2>{bounds[max_i], bounds[max_j]};
+  }(this->_getBounds(mesh));
+
+  const R3& x0 = diagonal[0];
+  const R3& x1 = diagonal[1];
+
+  if (x0 == x1) {
     std::ostringstream ost;
     ost << "invalid boundary \"" << rang::fgB::yellow << m_ref_node_list.refId() << rang::style::reset
         << "\": unable to compute normal";
     throw NormalError(ost.str());
   }
 
-  const double length = sqrt(normal_l2);
+  const R3 x2 = this->_getFarestNode(mesh, x0, x1);
+
+  const R3 u = x1 - x0;
+  const R3 v = x2 - x0;
+
+  R3 normal                = crossProduct(u, v);
+  const double normal_norm = l2Norm(normal);
+
+  if (normal_norm == 0) {
+    std::ostringstream ost;
+    ost << "invalid boundary \"" << rang::fgB::yellow << m_ref_node_list.refId() << rang::style::reset
+        << "\": unable to compute normal";
+    throw NormalError(ost.str());
+  }
 
-  normal *= 1. / length;
+  normal *= (1. / normal_norm);
 
-  this->_checkBoundaryIsFlat(normal, 1. / 6. * (xmin + xmax + ymin + ymax + zmin + zmax), length, mesh);
+  this->_checkBoundaryIsFlat(normal, 1. / 3. * (x0 + x1 + x2), normal_norm, mesh);
 
   return normal;
 }
diff --git a/src/mesh/MeshFlatNodeBoundary.hpp b/src/mesh/MeshFlatNodeBoundary.hpp
index 8486f02c7377fe2987560d29f2ee7cc45c1b44cf..54a10ae1c1412005440cd9e6cc49973c017265e0 100644
--- a/src/mesh/MeshFlatNodeBoundary.hpp
+++ b/src/mesh/MeshFlatNodeBoundary.hpp
@@ -13,21 +13,26 @@ class [[nodiscard]] MeshFlatNodeBoundary final
  private:
   const Rd m_outgoing_normal;
 
+  Rd _getFarestNode(const Mesh<Connectivity<Dimension>>& mesh, const Rd& x0, const Rd& x1);
+
   Rd _getNormal(const Mesh<Connectivity<Dimension>>& mesh);
 
-  void _checkBoundaryIsFlat(const TinyVector<Dimension, double>& normal, const TinyVector<Dimension, double>& origin,
-                            const double length, const Mesh<Connectivity<Dimension>>& mesh) const;
+  void _checkBoundaryIsFlat(const TinyVector<Dimension, double>& normal,
+                            const TinyVector<Dimension, double>& origin,
+                            const double length,
+                            const Mesh<Connectivity<Dimension>>& mesh) const;
 
   Rd _getOutgoingNormal(const Mesh<Connectivity<Dimension>>& mesh);
 
  public:
-  const Rd& outgoingNormal() const
+  const Rd&
+  outgoingNormal() const
   {
     return m_outgoing_normal;
   }
 
   MeshFlatNodeBoundary& operator=(const MeshFlatNodeBoundary&) = default;
-  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&) = default;
+  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&)      = default;
 
   template <size_t MeshDimension>
   friend MeshFlatNodeBoundary<MeshDimension> getMeshFlatNodeBoundary(const Mesh<Connectivity<MeshDimension>>& mesh,
@@ -47,7 +52,7 @@ class [[nodiscard]] MeshFlatNodeBoundary final
  public:
   MeshFlatNodeBoundary()                            = default;
   MeshFlatNodeBoundary(const MeshFlatNodeBoundary&) = default;
-  MeshFlatNodeBoundary(MeshFlatNodeBoundary &&)     = default;
+  MeshFlatNodeBoundary(MeshFlatNodeBoundary&&)      = default;
   ~MeshFlatNodeBoundary()                           = default;
 };
 
diff --git a/tests/test_MeshFlatNodeBoundary.cpp b/tests/test_MeshFlatNodeBoundary.cpp
index 9b9169f73c4306f9591ba0a9c0f5c1c622f7ef69..760f5564226f77fa53ec8de15f3f85552f02cbc4 100644
--- a/tests/test_MeshFlatNodeBoundary.cpp
+++ b/tests/test_MeshFlatNodeBoundary.cpp
@@ -1133,6 +1133,211 @@ TEST_CASE("MeshFlatNodeBoundary", "[mesh]")
     }
   }
 
+  SECTION("rotated diamond")
+  {
+    SECTION("2D")
+    {
+      static constexpr size_t Dimension = 2;
+
+      using ConnectivityType = Connectivity<Dimension>;
+      using MeshType         = Mesh<ConnectivityType>;
+
+      using R2 = TinyVector<2>;
+
+      auto T = [](const R2& x) -> R2 { return R2{x[0] + 0.1 * x[1], x[1] + 0.1 * x[0]}; };
+
+      SECTION("cartesian 2d")
+      {
+        std::shared_ptr p_mesh = MeshDataBaseForTests::get().cartesian2DMesh();
+
+        const ConnectivityType& connectivity = p_mesh->connectivity();
+
+        auto xr = p_mesh->xr();
+
+        NodeValue<R2> rotated_xr{connectivity};
+
+        parallel_for(
+          connectivity.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) { rotated_xr[node_id] = T(xr[node_id]); });
+
+        MeshType mesh{p_mesh->shared_connectivity(), rotated_xr};
+
+        {
+          const std::set<size_t> tag_set = {0, 1, 2, 3};
+
+          for (auto tag : tag_set) {
+            NumberedBoundaryDescriptor numbered_boundary_descriptor(tag);
+            const auto& node_boundary = getMeshFlatNodeBoundary(mesh, numbered_boundary_descriptor);
+
+            auto node_list = get_node_list_from_tag(tag, connectivity);
+            REQUIRE(is_same(node_boundary.nodeList(), node_list));
+
+            R2 normal = zero;
+
+            switch (tag) {
+            case 0: {
+              normal = 1. / std::sqrt(1.01) * R2{-1, 0.1};
+              break;
+            }
+            case 1: {
+              normal = 1. / std::sqrt(1.01) * R2{1, -0.1};
+              break;
+            }
+            case 2: {
+              normal = 1. / std::sqrt(1.01) * R2{0.1, -1};
+              break;
+            }
+            case 3: {
+              normal = 1. / std::sqrt(1.01) * R2{-0.1, 1};
+              break;
+            }
+            default: {
+              FAIL("unexpected tag number");
+            }
+            }
+            REQUIRE(l2Norm(node_boundary.outgoingNormal() - normal) == Catch::Approx(0).margin(1E-13));
+          }
+        }
+
+        {
+          const std::set<std::string> name_set = {"XMIN", "XMAX", "YMIN", "YMAX"};
+
+          for (const auto& name : name_set) {
+            NamedBoundaryDescriptor named_boundary_descriptor(name);
+            const auto& node_boundary = getMeshFlatNodeBoundary(mesh, named_boundary_descriptor);
+
+            auto node_list = get_node_list_from_name(name, connectivity);
+            REQUIRE(is_same(node_boundary.nodeList(), node_list));
+
+            R2 normal = zero;
+
+            if (name == "XMIN") {
+              normal = 1. / std::sqrt(1.01) * R2{-1, 0.1};
+            } else if (name == "XMAX") {
+              normal = 1. / std::sqrt(1.01) * R2{1, -0.1};
+            } else if (name == "YMIN") {
+              normal = 1. / std::sqrt(1.01) * R2{0.1, -1};
+            } else if (name == "YMAX") {
+              normal = 1. / std::sqrt(1.01) * R2{-0.1, 1};
+            } else {
+              FAIL("unexpected name: " + name);
+            }
+
+            REQUIRE(l2Norm(node_boundary.outgoingNormal() - normal) == Catch::Approx(0).margin(1E-13));
+          }
+        }
+      }
+    }
+
+    SECTION("3D")
+    {
+      static constexpr size_t Dimension = 3;
+
+      using ConnectivityType = Connectivity<Dimension>;
+      using MeshType         = Mesh<ConnectivityType>;
+
+      using R3 = TinyVector<3>;
+
+      auto T = [](const R3& x) -> R3 {
+        return R3{x[0] + 0.1 * x[1] + 0.2 * x[2], x[1] + 0.1 * x[0] + 0.1 * x[2], x[2] + 0.1 * x[0]};
+      };
+
+      SECTION("cartesian 3d")
+      {
+        std::shared_ptr p_mesh = MeshDataBaseForTests::get().cartesian3DMesh();
+
+        const ConnectivityType& connectivity = p_mesh->connectivity();
+
+        auto xr = p_mesh->xr();
+
+        NodeValue<R3> rotated_xr{connectivity};
+
+        parallel_for(
+          connectivity.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) { rotated_xr[node_id] = T(xr[node_id]); });
+
+        MeshType mesh{p_mesh->shared_connectivity(), rotated_xr};
+
+        {
+          const std::set<size_t> tag_set = {0, 1, 2, 3, 4, 5};
+
+          for (auto tag : tag_set) {
+            NumberedBoundaryDescriptor numbered_boundary_descriptor(tag);
+            const auto& node_boundary = getMeshFlatNodeBoundary(mesh, numbered_boundary_descriptor);
+
+            auto node_list = get_node_list_from_tag(tag, connectivity);
+            REQUIRE(is_same(node_boundary.nodeList(), node_list));
+
+            R3 normal = zero;
+
+            switch (tag) {
+            case 0: {
+              normal = R3{-0.977717523265611, 0.0977717523265611, 0.185766329420466};
+              break;
+            }
+            case 1: {
+              normal = R3{0.977717523265611, -0.0977717523265612, -0.185766329420466};
+              break;
+            }
+            case 2: {
+              normal = R3{0.0911512175788074, -0.992535480302569, 0.0810233045144955};
+              break;
+            }
+            case 3: {
+              normal = R3{-0.0911512175788074, 0.992535480302569, -0.0810233045144955};
+              break;
+            }
+            case 4: {
+              normal = R3{0.100493631166705, -0.0100493631166705, -0.994886948550377};
+              break;
+            }
+            case 5: {
+              normal = R3{-0.100493631166705, 0.0100493631166705, 0.994886948550377};
+              break;
+            }
+            default: {
+              FAIL("unexpected tag number");
+            }
+            }
+
+            REQUIRE(l2Norm(node_boundary.outgoingNormal() - normal) == Catch::Approx(0).margin(1E-13));
+          }
+        }
+
+        {
+          const std::set<std::string> name_set = {"XMIN", "XMAX", "YMIN", "YMAX", "ZMIN", "ZMAX"};
+
+          for (const auto& name : name_set) {
+            NamedBoundaryDescriptor named_boundary_descriptor(name);
+            const auto& node_boundary = getMeshFlatNodeBoundary(mesh, named_boundary_descriptor);
+
+            auto node_list = get_node_list_from_name(name, connectivity);
+
+            REQUIRE(is_same(node_boundary.nodeList(), node_list));
+
+            R3 normal = zero;
+
+            if (name == "XMIN") {
+              normal = R3{-0.977717523265611, 0.0977717523265611, 0.185766329420466};
+            } else if (name == "XMAX") {
+              normal = R3{0.977717523265611, -0.0977717523265612, -0.185766329420466};
+            } else if (name == "YMIN") {
+              normal = R3{0.0911512175788074, -0.992535480302569, 0.0810233045144955};
+            } else if (name == "YMAX") {
+              normal = R3{-0.0911512175788074, 0.992535480302569, -0.0810233045144955};
+            } else if (name == "ZMIN") {
+              normal = R3{0.100493631166705, -0.0100493631166705, -0.994886948550377};
+            } else if (name == "ZMAX") {
+              normal = R3{-0.100493631166705, 0.0100493631166705, 0.994886948550377};
+            } else {
+              FAIL("unexpected name: " + name);
+            }
+
+            REQUIRE(l2Norm(node_boundary.outgoingNormal() - normal) == Catch::Approx(0).margin(1E-13));
+          }
+        }
+      }
+    }
+  }
+
   SECTION("curved mesh")
   {
     SECTION("2D")