diff --git a/src/mesh/MeshFlatNodeBoundary.cpp b/src/mesh/MeshFlatNodeBoundary.cpp
index b7de42715eff88510119b54507ed32323365c817..459a827c84ead76ed23acdf28cbd8525923694da 100644
--- a/src/mesh/MeshFlatNodeBoundary.cpp
+++ b/src/mesh/MeshFlatNodeBoundary.cpp
@@ -87,57 +87,108 @@ MeshFlatNodeBoundary<2>::_getNormal(const Mesh<Connectivity<2>>& mesh)
 
 template <>
 TinyVector<3, double>
-MeshFlatNodeBoundary<3>::_getNormal(const Mesh<Connectivity<3>>& mesh)
+MeshFlatNodeBoundary<3>::_getFarestNode(const Mesh<Connectivity<3>>& mesh, const Rd& x0, const Rd& x1)
 {
-  using R3 = TinyVector<3, double>;
-
-  std::array<R3, 6> bounds = this->_getBounds(mesh);
+  const NodeValue<const Rd>& xr = mesh.xr();
+  const auto node_number        = mesh.connectivity().nodeNumber();
 
-  const R3& xmin = bounds[0];
-  const R3& ymin = bounds[1];
-  const R3& zmin = bounds[2];
-  const R3& xmax = bounds[3];
-  const R3& ymax = bounds[4];
-  const R3& zmax = bounds[5];
+  using NodeNumberType = std::remove_const_t<typename decltype(node_number)::data_type>;
 
-  const R3 u = xmax - xmin;
-  const R3 v = ymax - ymin;
-  const R3 w = zmax - zmin;
+  Rd t = x1 - x0;
+  t *= 1. / l2Norm(t);
 
-  const R3 uv        = crossProduct(u, v);
-  const double uv_l2 = dot(uv, uv);
+  double farest_distance       = 0;
+  Rd farest_x                  = zero;
+  NodeNumberType farest_number = std::numeric_limits<NodeNumberType>::max();
 
-  R3 normal        = uv;
-  double normal_l2 = uv_l2;
+  auto node_list = this->m_ref_node_list.list();
 
-  const R3 uw        = crossProduct(u, w);
-  const double uw_l2 = dot(uw, uw);
+  for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
+    const NodeId& node_id = node_list[i_node];
+    const Rd& x           = xr[node_id];
+    const double distance = l2Norm(crossProduct(t, x - x0));
 
-  if (uw_l2 > uv_l2) {
-    normal    = uw;
-    normal_l2 = uw_l2;
+    if ((distance > farest_distance) or ((distance == farest_distance) and (node_number[node_id] < farest_number))) {
+      farest_distance = distance;
+      farest_number   = node_number[node_id];
+      farest_x        = x;
+    }
   }
 
-  const R3 vw        = crossProduct(v, w);
-  const double vw_l2 = dot(vw, vw);
+  if (parallel::size()) {
+    Array<double> farest_distance_array       = parallel::allGather(farest_distance);
+    Array<Rd> farest_x_array                  = parallel::allGather(farest_x);
+    Array<NodeNumberType> farest_number_array = parallel::allGather(farest_number);
 
-  if (vw_l2 > normal_l2) {
-    normal    = vw;
-    normal_l2 = vw_l2;
+    Assert(farest_distance_array.size() == farest_x_array.size());
+    Assert(farest_distance_array.size() == farest_number_array.size());
+
+    for (size_t i = 0; i < farest_distance_array.size(); ++i) {
+      if ((farest_distance_array[i] > farest_distance) or
+          ((farest_distance_array[i] == farest_distance) and (farest_number_array[i] < farest_number))) {
+        farest_distance = farest_distance_array[i];
+        farest_number   = farest_number_array[i];
+        farest_x        = farest_x_array[i];
+      }
+    }
   }
 
-  if (normal_l2 == 0) {
+  return farest_x;
+}
+
+template <>
+TinyVector<3, double>
+MeshFlatNodeBoundary<3>::_getNormal(const Mesh<Connectivity<3>>& mesh)
+{
+  using R3 = TinyVector<3, double>;
+
+  std::array<R3, 2> diagonal = [](const std::array<R3, 6>& bounds) {
+    size_t max_i      = 0;
+    size_t max_j      = 0;
+    double max_length = 0;
+
+    for (size_t i = 0; i < bounds.size(); ++i) {
+      for (size_t j = i + 1; j < bounds.size(); ++j) {
+        double length = l2Norm(bounds[i] - bounds[j]);
+        if (length > max_length) {
+          max_i      = i;
+          max_j      = j;
+          max_length = length;
+        }
+      }
+    }
+
+    return std::array<R3, 2>{bounds[max_i], bounds[max_j]};
+  }(this->_getBounds(mesh));
+
+  const R3& x0 = diagonal[0];
+  const R3& x1 = diagonal[1];
+
+  if (x0 == x1) {
     std::ostringstream ost;
     ost << "invalid boundary \"" << rang::fgB::yellow << m_ref_node_list.refId() << rang::style::reset
         << "\": unable to compute normal";
     throw NormalError(ost.str());
   }
 
-  const double length = sqrt(normal_l2);
+  const R3 x2 = this->_getFarestNode(mesh, x0, x1);
+
+  const R3 u = x1 - x0;
+  const R3 v = x2 - x0;
+
+  R3 normal                = crossProduct(u, v);
+  const double normal_norm = l2Norm(normal);
+
+  if (normal_norm == 0) {
+    std::ostringstream ost;
+    ost << "invalid boundary \"" << rang::fgB::yellow << m_ref_node_list.refId() << rang::style::reset
+        << "\": unable to compute normal";
+    throw NormalError(ost.str());
+  }
 
-  normal *= 1. / length;
+  normal *= (1. / normal_norm);
 
-  this->_checkBoundaryIsFlat(normal, 1. / 6. * (xmin + xmax + ymin + ymax + zmin + zmax), length, mesh);
+  this->_checkBoundaryIsFlat(normal, 1. / 3. * (x0 + x1 + x2), normal_norm, mesh);
 
   return normal;
 }
diff --git a/src/mesh/MeshFlatNodeBoundary.hpp b/src/mesh/MeshFlatNodeBoundary.hpp
index 8486f02c7377fe2987560d29f2ee7cc45c1b44cf..54a10ae1c1412005440cd9e6cc49973c017265e0 100644
--- a/src/mesh/MeshFlatNodeBoundary.hpp
+++ b/src/mesh/MeshFlatNodeBoundary.hpp
@@ -13,21 +13,26 @@ class [[nodiscard]] MeshFlatNodeBoundary final
  private:
   const Rd m_outgoing_normal;
 
+  Rd _getFarestNode(const Mesh<Connectivity<Dimension>>& mesh, const Rd& x0, const Rd& x1);
+
   Rd _getNormal(const Mesh<Connectivity<Dimension>>& mesh);
 
-  void _checkBoundaryIsFlat(const TinyVector<Dimension, double>& normal, const TinyVector<Dimension, double>& origin,
-                            const double length, const Mesh<Connectivity<Dimension>>& mesh) const;
+  void _checkBoundaryIsFlat(const TinyVector<Dimension, double>& normal,
+                            const TinyVector<Dimension, double>& origin,
+                            const double length,
+                            const Mesh<Connectivity<Dimension>>& mesh) const;
 
   Rd _getOutgoingNormal(const Mesh<Connectivity<Dimension>>& mesh);
 
  public:
-  const Rd& outgoingNormal() const
+  const Rd&
+  outgoingNormal() const
   {
     return m_outgoing_normal;
   }
 
   MeshFlatNodeBoundary& operator=(const MeshFlatNodeBoundary&) = default;
-  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&) = default;
+  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&)      = default;
 
   template <size_t MeshDimension>
   friend MeshFlatNodeBoundary<MeshDimension> getMeshFlatNodeBoundary(const Mesh<Connectivity<MeshDimension>>& mesh,
@@ -47,7 +52,7 @@ class [[nodiscard]] MeshFlatNodeBoundary final
  public:
   MeshFlatNodeBoundary()                            = default;
   MeshFlatNodeBoundary(const MeshFlatNodeBoundary&) = default;
-  MeshFlatNodeBoundary(MeshFlatNodeBoundary &&)     = default;
+  MeshFlatNodeBoundary(MeshFlatNodeBoundary&&)      = default;
   ~MeshFlatNodeBoundary()                           = default;
 };