#include <scheme/ScalarNodalScheme.hpp>

#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

// Type-erased interface over the dimension-specific ScalarNodalScheme<Dimension>, so the handler can
// own a scheme without knowing the mesh dimension at compile time.
class ScalarNodalSchemeHandler::IScalarNodalScheme
{
 public:
  IScalarNodalScheme() = default;

  // Virtual so that a ScalarNodalScheme<Dimension> is destroyed correctly through this interface.
  virtual ~IScalarNodalScheme() = default;

  // Returns the discrete solution computed by the concrete scheme.
  virtual std::shared_ptr<const IDiscreteFunction> getSolution() const = 0;
};

template <size_t Dimension>
class ScalarNodalSchemeHandler::ScalarNodalScheme : public ScalarNodalSchemeHandler::IScalarNodalScheme
{
  // Cell-centered scheme for a scalar diffusion (heat) problem. The constructor:
  //  - reads Dirichlet/Neumann boundary conditions (Fourier throws NotImplementedError),
  //  - builds nodal quantities from the corner vectors Cjr,
  //  - assembles a numberOfCells x numberOfCells sparse matrix S (with Vj added on the diagonal)
  //    and a right-hand side b (fj * Vj plus boundary contributions),
  //  - solves S T = b once and stores T as a P0 discrete function.
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  // Computed cell values, filled at the end of the constructor.
  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_solution;

  // Dirichlet data: one value per boundary face (asserted), plus the boundary node list.
  class DirichletBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;
    const Array<const NodeId> m_node_list;

   public:
    const Array<const NodeId>&
    nodeList() const
    {
      return m_node_list;
    }

    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    DirichletBoundaryCondition(const Array<const FaceId>& face_list,
                               const Array<const NodeId>& node_list,
                               const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}, m_node_list{node_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~DirichletBoundaryCondition() = default;
  };

  // Neumann data: one value per boundary face (asserted), plus the boundary node list.
  class NeumannBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;
    const Array<const NodeId> m_node_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const NodeId>&
    nodeList() const
    {
      return m_node_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    NeumannBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const NodeId>& node_list,
                             const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}, m_node_list{node_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~NeumannBoundaryCondition() = default;
  };

  // Fourier data: one coefficient and one value per boundary face (asserted). Not yet used by the
  // constructor, which throws NotImplementedError for Fourier descriptors.
  class FourierBoundaryCondition
  {
   private:
    const Array<const double> m_coef_list;
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    const Array<const double>&
    coefList() const
    {
      return m_coef_list;
    }

   public:
    FourierBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const double>& coef_list,
                             const Array<const double>& value_list)
      : m_coef_list{coef_list}, m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_coef_list.size() == m_face_list.size());
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~FourierBoundaryCondition() = default;
  };

 public:
  std::shared_ptr<const IDiscreteFunction>
  getSolution() const final
  {
    return m_solution;
  }

  // mesh        : mesh on which all discrete functions live
  // cell_k_b    : per-cell tensor used for the boundary right-hand-side terms
  // cell_k      : per-cell tensor used for the matrix terms
  // f           : per-cell source term
  // bc_descriptor_list : Dirichlet / Neumann descriptors (others are rejected)
  ScalarNodalScheme(const std::shared_ptr<const MeshType>& mesh,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k_b,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& f,
                    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    using BoundaryCondition =
      std::variant<DirichletBoundaryCondition, FourierBoundaryCondition, NeumannBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          // Boundary values are sampled at the face centers xl
          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());
          boundary_condition_list.push_back(
            DirichletBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Dirichlet BC in 1d");
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fourier: {
        throw NotImplementedError("NIY");
      }
      case IBoundaryConditionDescriptor::Type::neumann: {
        const NeumannBoundaryConditionDescriptor& neumann_bc_descriptor =
          dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = neumann_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          // Boundary values are sampled at the face centers xl
          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(
            NeumannBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Neumann BC in 1d");
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for heat equation";
        throw NormalError(error_msg.str());
      }
    }

    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();

    const NodeValuePerCell<const TinyVector<Dimension>>& Cjr = mesh_data.Cjr();

    const auto is_boundary_node = mesh->connectivity().isBoundaryNode();

    const auto& node_to_face_matrix               = mesh->connectivity().nodeToFaceMatrix();
    const auto& face_to_node_matrix               = mesh->connectivity().faceToNodeMatrix();
    const auto& cell_to_node_matrix               = mesh->connectivity().cellToNodeMatrix();
    const auto& node_local_numbers_in_their_cells = mesh->connectivity().nodeLocalNumbersInTheirCells();
    const CellValue<const double> Vj              = mesh_data.Vj();
    const auto& node_to_cell_matrix               = mesh->connectivity().nodeToCellMatrix();

    // Nodes belonging to at least one Neumann face
    const NodeValue<const bool> node_is_neumann = [&] {
      NodeValue<bool> compute_node_is_neumann{mesh->connectivity()};
      compute_node_is_neumann.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list = bc.faceList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];

                  compute_node_is_neumann[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_neumann;
    }();

    // Nodes belonging to at least one Dirichlet face
    const NodeValue<const bool> node_is_dirichlet = [&] {
      NodeValue<bool> compute_node_is_dirichlet{mesh->connectivity()};
      compute_node_is_dirichlet.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];

                  compute_node_is_dirichlet[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_dirichlet;
    }();

    // A boundary node is flagged as a corner when, in some adjacent cell, both its previous and next
    // nodes (in the cell's local node ordering) are boundary nodes.
    const NodeValue<const bool> node_is_corner = [&] {
      NodeValue<bool> compute_node_is_corner{mesh->connectivity()};
      compute_node_is_corner.fill(false);
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          if (is_boundary_node[node_id]) {
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const unsigned int i_node = node_local_number_in_its_cell[i_cell];
              const CellId cell_id      = node_to_cell[i_cell];
              const auto& cell_nodes    = cell_to_node_matrix[cell_id];
              const size_t nb_nodes     = cell_nodes.size();
              // Add nb_nodes before subtracting: (i_node - 1) on an unsigned 0 wraps to UINT_MAX, whose
              // remainder is only correct when nb_nodes is a power of two (wrong for triangles).
              const NodeId prev_node_id = cell_nodes[(i_node + nb_nodes - 1) % nb_nodes];
              const NodeId next_node_id = cell_nodes[(i_node + 1) % nb_nodes];
              if (is_boundary_node[prev_node_id] and is_boundary_node[next_node_id]) {
                compute_node_is_corner[node_id] = true;
              }
            }
          }
        });
      return compute_node_is_corner;
    }();

    // Unit exterior normal at boundary nodes: normalized sum of the corner vectors Cjr around the node.
    // Zero at interior nodes.
    const NodeValue<const TinyVector<Dimension>> exterior_normal = [&] {
      NodeValue<TinyVector<Dimension>> compute_exterior_normal{mesh->connectivity()};
      compute_exterior_normal.fill(zero);
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        if (is_boundary_node[node_id]) {
          const auto& node_to_cell                  = node_to_cell_matrix[node_id];
          const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
          for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
            const CellId cell_id      = node_to_cell[i_cell];
            const unsigned int i_node = node_local_number_in_its_cell[i_cell];
            compute_exterior_normal[node_id] += Cjr(cell_id, i_node);
          }
          const double norm_exterior_normal = l2Norm(compute_exterior_normal[node_id]);
          compute_exterior_normal[node_id] *= 1. / norm_exterior_normal;
        }
      }
      return compute_exterior_normal;
    }();

    // Face measures (1 in dimension 1)
    const FaceValue<const double> mes_l = [&] {
      if constexpr (Dimension == 1) {
        FaceValue<double> compute_mes_l{mesh->connectivity()};
        compute_mes_l.fill(1);
        return compute_mes_l;
      } else {
        return mesh_data.ll();
      }
    }();

    // Nodal boundary values: face-measure-weighted average of the adjacent face values for pure
    // Neumann or pure Dirichlet nodes; for nodes that are both, the Dirichlet face value is used.
    const NodeValue<const double> node_boundary_values = [&] {
      NodeValue<double> compute_node_boundary_values{mesh->connectivity()};
      NodeValue<double> sum_mes_l{mesh->connectivity()};
      compute_node_boundary_values.fill(0);
      sum_mes_l.fill(0);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (not node_is_dirichlet[node_id]) {
                    compute_node_boundary_values[node_id] += value_list[i_face] * mes_l[face_id];
                    sum_mes_l[node_id] += mes_l[face_id];
                  }
                }
              }

            } else if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (not node_is_neumann[node_id]) {
                    compute_node_boundary_values[node_id] += value_list[i_face] * mes_l[face_id];
                    sum_mes_l[node_id] += mes_l[face_id];
                  } else {
                    compute_node_boundary_values[node_id] = value_list[i_face];
                  }
                }
              }
            }
          },
          boundary_condition);
      }
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        if ((not node_is_dirichlet[node_id]) && (node_is_neumann[node_id])) {
          compute_node_boundary_values[node_id] /= sum_mes_l[node_id];
        } else if ((not node_is_neumann[node_id]) && (node_is_dirichlet[node_id])) {
          compute_node_boundary_values[node_id] /= sum_mes_l[node_id];
        }
      }
      return compute_node_boundary_values;
    }();

    {
      CellValue<const TinyMatrix<Dimension>> cell_kappaj  = cell_k->cellValues();
      CellValue<const TinyMatrix<Dimension>> cell_kappajb = cell_k_b->cellValues();

      // Nodal kappa: average of the surrounding cell tensors weighted by 1 / |xr - xj|
      const NodeValue<const TinyMatrix<Dimension>> node_kappar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa[node_id]           = zero;
            const auto& node_to_cell = node_to_cell_matrix[node_id];
            double weight            = 0;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              double local_weight  = 1. / l2Norm(xr[node_id] - xj[cell_id]);
              kappa[node_id] += local_weight * cell_kappaj[cell_id];
              weight += local_weight;
            }
            kappa[node_id] = 1. / weight * kappa[node_id];
          });
        return kappa;
      }();

      // Same weighted average for the boundary tensor cell_k_b
      const NodeValue<const TinyMatrix<Dimension>> node_kapparb = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa[node_id]           = zero;
            const auto& node_to_cell = node_to_cell_matrix[node_id];
            double weight            = 0;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              double local_weight  = 1. / l2Norm(xr[node_id] - xj[cell_id]);
              kappa[node_id] += local_weight * cell_kappajb[cell_id];
              weight += local_weight;
            }
            kappa[node_id] = 1. / weight * kappa[node_id];
          });
        return kappa;
      }();

      // beta_r = sum_j Cjr (x) (xr - xj)
      const NodeValue<const TinyMatrix<Dimension>> node_betar = [&] {
        NodeValue<TinyMatrix<Dimension>> beta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            beta[node_id]                             = zero;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id      = node_to_cell[i_cell];
              const unsigned int i_node = node_local_number_in_its_cell[i_cell];
              beta[node_id] += tensorProduct(Cjr(cell_id, i_node), xr[node_id] - xj[cell_id]);
            }
          });
        return beta;
      }();

      // kappa_r * beta_r^{-1}, only computed (and only used) at non-corner nodes
      const NodeValue<const TinyMatrix<Dimension>> kappar_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (not node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kappar[node_id] * inverse(node_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> kapparb_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (not node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kapparb[node_id] * inverse(node_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      // Replacement beta at corner nodes, built from two sub-triangles of the first adjacent cell.
      // Only formulated in 2D.
      const NodeValue<const TinyMatrix<Dimension>> corner_betar = [&] {
        NodeValue<TinyMatrix<Dimension>> beta{mesh->connectivity()};
        if constexpr (Dimension != 2) {
          // Checked serially: throwing from inside a parallel_for body would escape a parallel region.
          for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
            if (node_is_corner[node_id]) {
              throw NotImplementedError("The scheme is not implemented in this dimension.");
            }
          }
        } else {
          parallel_for(
            mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
              if (node_is_corner[node_id]) {
                const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
                const auto& node_to_cell                  = node_to_cell_matrix[node_id];

                const size_t i_cell       = 0;
                const CellId cell_id      = node_to_cell[i_cell];
                const unsigned int i_node = node_local_number_in_its_cell[i_cell];

                const auto& cell_nodes  = cell_to_node_matrix[cell_id];
                const size_t nb_nodes   = cell_nodes.size();
                const NodeId node_id_p1 = cell_nodes[(i_node + 1) % nb_nodes];
                const NodeId node_id_p2 = cell_nodes[(i_node + 2) % nb_nodes];
                // Add nb_nodes before subtracting to avoid unsigned wrap-around when i_node == 0
                const NodeId node_id_m1 = cell_nodes[(i_node + nb_nodes - 1) % nb_nodes];

                const TinyVector<Dimension> xj1 = 1. / 3. * (xr[node_id] + xr[node_id_m1] + xr[node_id_p2]);
                const TinyVector<Dimension> xj2 = 1. / 3. * (xr[node_id] + xr[node_id_p2] + xr[node_id_p1]);

                const TinyVector<Dimension> xrm1 = xr[node_id_m1];
                const TinyVector<Dimension> xrp1 = xr[node_id_p1];
                const TinyVector<Dimension> xrp2 = xr[node_id_p2];

                TinyVector<Dimension> Cjr1;
                TinyVector<Dimension> Cjr2;

                Cjr1[0] = -0.5 * (xrm1[1] - xrp2[1]);
                Cjr1[1] = 0.5 * (xrm1[0] - xrp2[0]);
                Cjr2[0] = -0.5 * (xrp2[1] - xrp1[1]);
                Cjr2[1] = 0.5 * (xrp2[0] - xrp1[0]);

                beta[node_id] = tensorProduct(Cjr1, (xr[node_id] - xj1)) + tensorProduct(Cjr2, (xr[node_id] - xj2));
              }
            });
        }
        return beta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> corner_kappar_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kappar[node_id] * inverse(corner_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> corner_kapparb_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kapparb[node_id] * inverse(corner_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      // sum_j Cjr around each node
      const NodeValue<const TinyVector<Dimension>> sum_Cjr = [&] {
        NodeValue<TinyVector<Dimension>> compute_sum_Cjr{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            compute_sum_Cjr[node_id]                  = zero;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              const size_t i_node  = node_local_number_in_its_cell[i_cell];
              compute_sum_Cjr[node_id] += Cjr(cell_id, i_node);
            }
          });
        return compute_sum_Cjr;
      }();

      // Normalized kappa_b * n at boundary nodes
      const NodeValue<const TinyVector<Dimension>> v = [&] {
        NodeValue<TinyVector<Dimension>> compute_v{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (is_boundary_node[node_id]) {
              compute_v[node_id] = 1. / l2Norm(node_kapparb[node_id] * exterior_normal[node_id]) *
                                   node_kapparb[node_id] * exterior_normal[node_id];
            }
          });
        return compute_v;
      }();

      // theta_jr = (beta_r^{-1} Cjr) . v_r at non-corner boundary nodes
      const NodeValuePerCell<const double> theta = [&] {
        NodeValuePerCell<double> compute_theta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
            const auto& cell_nodes = cell_to_node_matrix[cell_id];
            for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
              const NodeId node_id = cell_nodes[i_node];
              if (is_boundary_node[node_id] && not node_is_corner[node_id]) {
                compute_theta(cell_id, i_node) = dot(inverse(node_betar[node_id]) * Cjr(cell_id, i_node), v[node_id]);
              }
            }
          });
        return compute_theta;
      }();

      const NodeValue<const double> sum_theta = [&] {
        NodeValue<double> compute_sum_theta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if ((is_boundary_node[node_id]) && (not node_is_corner[node_id])) {
              const auto& node_to_cell                  = node_to_cell_matrix[node_id];
              const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
              compute_sum_theta[node_id]                = 0;
              for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
                const CellId cell_id = node_to_cell[i_cell];
                const size_t i_node  = node_local_number_in_its_cell[i_cell];
                compute_sum_theta[node_id] += theta(cell_id, i_node);
              }
            }
          });
        return compute_sum_theta;
      }();

      const Array<int> non_zeros{mesh->numberOfCells()};
      non_zeros.fill(Dimension);   // initial per-row non-zero estimate (chosen for performance)
      CRSMatrixDescriptor<double> S(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);

      // Matrix assembly: each node couples every pair (j, k) of its surrounding cells
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell                   = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
        if (not is_boundary_node[node_id]) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S(cell_id_j, cell_id_k) +=
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        } else if ((node_is_neumann[node_id]) && (not node_is_corner[node_id]) && (not node_is_dirichlet[node_id])) {
          const TinyMatrix<Dimension> I   = identity;
          const TinyVector<Dimension> n   = exterior_normal[node_id];
          const TinyMatrix<Dimension> nxn = tensorProduct(n, n);
          const TinyMatrix<Dimension> P   = I - nxn;   // projector onto the boundary tangent plane

          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              // The row term must use cell j's corner vector (i_node_j is a local index in cell j)
              S(cell_id_j, cell_id_k) +=
                dot(kappar_invBetar[node_id] *
                      (Cjr(cell_id_k, i_node_k) - theta(cell_id_k, i_node_k) / sum_theta[node_id] * sum_Cjr[node_id]),
                    P * Cjr(cell_id_j, i_node_j));
            }
          }
        } else if ((node_is_dirichlet[node_id]) && (not node_is_corner[node_id])) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S(cell_id_j, cell_id_k) +=
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        } else if (node_is_dirichlet[node_id] && node_is_corner[node_id]) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S(cell_id_j, cell_id_k) +=
                dot(corner_kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        }
      }

      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        S(cell_id, cell_id) += Vj[cell_id];
      }

      CellValue<const double> fj = f->cellValues();

      CRSMatrix A{S.getCRSMatrix()};

      Vector<double> b{mesh->numberOfCells()};
      b = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        b[cell_id] = fj[cell_id] * Vj[cell_id];
      }

      // Right-hand side: boundary contributions
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell                   = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
        if ((node_is_neumann[node_id]) && (not node_is_dirichlet[node_id])) {
          if ((is_boundary_node[node_id]) and (not node_is_corner[node_id])) {
            const TinyMatrix<Dimension> I   = identity;
            const TinyVector<Dimension> n   = exterior_normal[node_id];
            const TinyMatrix<Dimension> nxn = tensorProduct(n, n);
            const TinyMatrix<Dimension> P   = I - nxn;

            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              const size_t i_node  = node_local_number_in_its_cells[i_cell];

              b[cell_id] += node_boundary_values[node_id] *
                            (1. / (sum_theta[node_id] * l2Norm(node_kapparb[node_id] * n)) *
                               dot(kapparb_invBetar[node_id] * sum_Cjr[node_id], P * Cjr(cell_id, i_node)) +
                             dot(Cjr(cell_id, i_node), n));
            }
          } else if (node_is_corner[node_id]) {
            const auto& node_to_face = node_to_face_matrix[node_id];
            const CellId cell_id     = node_to_cell[0];
            double sum_mes_l         = 0;
            for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
              FaceId face_id = node_to_face[i_face];
              sum_mes_l += mes_l[face_id];
            }
            b[cell_id] += 0.5 * node_boundary_values[node_id] * sum_mes_l;
          }
        } else if (node_is_dirichlet[node_id]) {
          if (not node_is_corner[node_id]) {
            for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
              const CellId cell_id_j = node_to_cell[i_cell_j];
              const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

              for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
                const CellId cell_id_k = node_to_cell[i_cell_k];
                const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

                b[cell_id_j] +=
                  dot(node_boundary_values[node_id] * kapparb_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                      Cjr(cell_id_j, i_node_j));
              }
            }

          } else if (node_is_corner[node_id]) {
            for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
              const CellId cell_id_j = node_to_cell[i_cell_j];
              const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

              for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
                const CellId cell_id_k = node_to_cell[i_cell_k];
                const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

                b[cell_id_j] +=
                  dot(node_boundary_values[node_id] * corner_kapparb_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                      Cjr(cell_id_j, i_node_j));
              }
            }
          }
        }
      }

      // Solve once, after b has received every nodal contribution (this previously ran inside the
      // per-node loop, i.e. once per node with an incomplete right-hand side).
      Vector<double> T{mesh->numberOfCells()};
      T = zero;

      LinearSolver solver;
      solver.solveLocalSystem(A, T, b);

      m_solution     = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
      auto& solution = *m_solution;
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { solution[cell_id] = T[cell_id]; });
    }
  }
};

std::shared_ptr<const IDiscreteFunction>
ScalarNodalSchemeHandler::solution() const
{
  return m_scheme->getSolution();
}

// Builds the dimension-specific scheme from type-erased discrete functions.
//  cell_k_b, cell_k : P0 tensor-valued functions; f : P0 scalar source.
// Throws NormalError if the functions do not share a mesh or the mesh is 1D,
// and UnexpectedError for an unsupported dimension.
ScalarNodalSchemeHandler::ScalarNodalSchemeHandler(
  const std::shared_ptr<const IDiscreteFunction>& cell_k_b,
  const std::shared_ptr<const IDiscreteFunction>& cell_k,
  const std::shared_ptr<const IDiscreteFunction>& f,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  // All three functions are read cell-by-cell on the same mesh, so cell_k_b must be part of the check
  // (it was previously only discretization-checked, not mesh-checked).
  const std::shared_ptr i_mesh = getCommonMesh({cell_k_b, cell_k, f});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }

  checkDiscretizationType({cell_k_b, cell_k, f}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    throw NormalError("The scheme is not implemented in dimension 1");
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<2, TinyMatrix<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    m_scheme =
      std::make_unique<ScalarNodalScheme<2>>(mesh,
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_b),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  case 3: {
    using MeshType                   = Mesh<Connectivity<3>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<3, TinyMatrix<3>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    m_scheme =
      std::make_unique<ScalarNodalScheme<3>>(mesh,
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_b),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

// Defined out of line in this file, where IScalarNodalScheme is a complete type; presumably required
// because m_scheme (assigned via std::make_unique) owns it — confirm against the header.
ScalarNodalSchemeHandler::~ScalarNodalSchemeHandler()
{
  // nothing to do beyond member destruction
}
