#include <scheme/ScalarNodalScheme.hpp>

#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

class ScalarNodalSchemeHandler::IScalarNodalScheme
{
 public:
  virtual std::shared_ptr<const IDiscreteFunction> getSolution() const = 0;

  IScalarNodalScheme()          = default;
  virtual ~IScalarNodalScheme() = default;
};

template <size_t Dimension>
class ScalarNodalSchemeHandler::ScalarNodalScheme : public ScalarNodalSchemeHandler::IScalarNodalScheme
{
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_solution;

  class DirichletBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    DirichletBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~DirichletBoundaryCondition() = default;
  };

  class NeumannBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;
    const Array<const NodeId> m_node_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const NodeId>&
    nodeList() const
    {
      return m_node_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    NeumannBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const NodeId>& node_list,
                             const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}, m_node_list{node_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~NeumannBoundaryCondition() = default;
  };

  class FourierBoundaryCondition
  {
   private:
    const Array<const double> m_coef_list;
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    const Array<const double>&
    coefList() const
    {
      return m_coef_list;
    }

   public:
    FourierBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const double>& coef_list,
                             const Array<const double>& value_list)
      : m_coef_list{coef_list}, m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_coef_list.size() == m_face_list.size());
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~FourierBoundaryCondition() = default;
  };

 public:
  std::shared_ptr<const IDiscreteFunction>
  getSolution() const final
  {
    return m_solution;
  }

  ScalarNodalScheme(const std::shared_ptr<const MeshType>& mesh,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k_bound,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& f,
                    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    using BoundaryCondition =
      std::variant<DirichletBoundaryCondition, FourierBoundaryCondition, NeumannBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});

        } else {
          throw NotImplementedError("Dirichlet BC in 1d");
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fourier: {
        throw NotImplementedError("NIY");
        break;
      }
      case IBoundaryConditionDescriptor::Type::neumann: {
        const NeumannBoundaryConditionDescriptor& neumann_bc_descriptor =
          dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = neumann_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(
            NeumannBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Neumann BC in 1d");
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for heat equation";
        throw NormalError(error_msg.str());
      }
    }

    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();

    const NodeValuePerCell<const TinyVector<Dimension>>& Cjr = mesh_data.Cjr();

    const auto is_boundary_node = mesh->connectivity().isBoundaryNode();

    // const auto& node_to_face_matrix               = mesh->connectivity().nodeToFaceMatrix();
    const auto& face_to_node_matrix               = mesh->connectivity().faceToNodeMatrix();
    const auto& cell_to_node_matrix               = mesh->connectivity().cellToNodeMatrix();
    const auto& node_local_numbers_in_their_cells = mesh->connectivity().nodeLocalNumbersInTheirCells();
    const CellValue<const double> Vj              = mesh_data.Vj();
    const auto& node_to_cell_matrix               = mesh->connectivity().nodeToCellMatrix();

    const NodeValue<const bool> node_is_neumann = [&] {
      NodeValue<bool> compute_node_is_neumann{mesh->connectivity()};
      compute_node_is_neumann.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];

                  compute_node_is_neumann[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_neumann;
    }();

    const NodeValue<const bool> node_is_dirichlet = [&] {
      NodeValue<bool> compute_node_is_dirichlet{mesh->connectivity()};
      compute_node_is_dirichlet.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];

                  compute_node_is_dirichlet[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_dirichlet;
    }();

    const NodeValue<const bool> node_is_corner = [&] {
      NodeValue<bool> compute_node_is_corner{mesh->connectivity()};
      compute_node_is_corner.fill(false);
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          if (is_boundary_node[node_id]) {
            const auto& node_to_cell = node_to_cell_matrix[node_id];

            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id      = node_to_cell[i_cell];
              const auto& cell_nodes    = cell_to_node_matrix[cell_id];
              const NodeId prev_node_id = cell_to_node_matrix[cell_id][(node_id - 1) % cell_nodes.size()];
              const NodeId next_node_id = cell_to_node_matrix[cell_id][(node_id + 1) % cell_nodes.size()];

              if (is_boundary_node[prev_node_id] and is_boundary_node[next_node_id]) {
                compute_node_is_corner[node_id] = true;
              }
            }
          }
        });

      return compute_node_is_corner;
    }();

    const NodeValue<const TinyVector<Dimension>> exterior_normal = [&] {
      NodeValue<TinyVector<Dimension>> compute_exterior_normal;
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          compute_exterior_normal[node_id]          = zero;
          const auto& node_to_cell                  = node_to_cell_matrix[node_id];
          const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
          for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
            const CellId cell_id          = node_to_cell[i_cell];
            const unsigned int i_node_loc = node_local_number_in_its_cell[cell_id];
            compute_exterior_normal[node_id] += Cjr(cell_id, i_node_loc);
          }
          compute_exterior_normal[node_id] =
            1. / l2Norm(compute_exterior_normal[node_id]) * compute_exterior_normal[node_id];
        });
      return compute_exterior_normal;
    }();

    const FaceValue<const double> mes_l = [&] {
      if constexpr (Dimension == 1) {
        FaceValue<double> compute_mes_l{mesh->connectivity()};
        compute_mes_l.fill(1);
        return compute_mes_l;
      } else {
        return mesh_data.ll();
      }
    }();

    const NodeValue<const double> node_boundary_values = [&] {
      NodeValue<double> compute_node_boundary_values;
      compute_node_boundary_values.fill(0);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (node_is_dirichlet[node_id] == false) {
                    compute_node_boundary_values[node_id] += 0.5 * value_list[i_face] * mes_l[face_id];
                  }
                }
              }
            } else if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (node_is_neumann[node_id] == false) {
                    compute_node_boundary_values[node_id] += 0.5 * value_list[i_face] * mes_l[face_id];
                  } else {
                    compute_node_boundary_values[node_id] = value_list[i_face] * mes_l[face_id];
                  }
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_boundary_values;
    }();

    {
      CellValue<const TinyMatrix<Dimension>> cell_kappaj  = cell_k->cellValues();
      CellValue<const TinyMatrix<Dimension>> cell_kappajb = cell_k_bound->cellValues();

      const NodeValue<const TinyMatrix<Dimension>> node_kappar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa;
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa[node_id]           = zero;
            const auto& node_to_cell = node_to_cell_matrix[node_id];
            double weight            = 0;
            for (size_t j = 0; j < node_to_cell.size(); ++j) {
              const CellId J      = node_to_cell[j];
              double local_weight = 1. / l2Norm(xr[node_id] - xj[J]);
              kappa[node_id] += local_weight * cell_kappaj[J];
              weight += local_weight;
            }
            kappa[node_id] = 1. / weight * kappa[node_id];
          });
        return kappa;
      }();

      const NodeValue<const TinyMatrix<Dimension>> node_kapparb = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa;
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa[node_id]           = zero;
            const auto& node_to_cell = node_to_cell_matrix[node_id];
            double weight            = 0;
            for (size_t j = 0; j < node_to_cell.size(); ++j) {
              const CellId J      = node_to_cell[j];
              double local_weight = 1. / l2Norm(xr[node_id] - xj[J]);
              kappa[node_id] += local_weight * cell_kappajb[J];
              weight += local_weight;
            }
            kappa[node_id] = 1. / weight * kappa[node_id];
          });
        return kappa;
      }();

      const NodeValue<const TinyMatrix<Dimension>> node_betar = [&] {
        NodeValue<TinyMatrix<Dimension>> beta;
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            beta[node_id]                             = zero;
            for (size_t j = 0; j < node_to_cell.size(); ++j) {
              const CellId J       = node_to_cell[j];
              const unsigned int R = node_local_number_in_its_cell[j];
              beta[node_id] += tensorProduct(Cjr(J, R), xr[node_id] - xj[J]);
            }
          });
        return beta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> kappar_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta;
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa_invBeta[node_id] = node_kappar[node_id] * inverse(node_betar[node_id]);
          });
        return kappa_invBeta;
      }();

      const Array<int> non_zeros{mesh->numberOfCells()};
      non_zeros.fill(Dimension);   // Modif pour optimiser
      CRSMatrixDescriptor<double> S(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);

      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell = node_to_cell_matrix[node_id];

        if (not is_boundary_node[node_id]) {
          for (size_t cell_id_j = 0; cell_id_j < node_to_cell.size(); ++cell_id_j) {
            const CellId J                          = node_to_cell[cell_id_j];
            const auto& node_local_number_in_j_cell = node_local_numbers_in_their_cells.itemArray(J);
            const size_t node_id_j                  = node_local_number_in_j_cell[node_id];

            for (size_t cell_id_k = 0; cell_id_k < node_to_cell.size(); ++cell_id_k) {
              const CellId K                          = node_to_cell[cell_id_k];
              const auto& node_local_number_in_k_cell = node_local_numbers_in_their_cells.itemArray(K);
              const size_t node_id_k                  = node_local_number_in_k_cell[node_id];

              S(J, K) += dot(kappar_invBetar[node_id] * Cjr(K, node_id_k), Cjr(J, node_id_j));
            }
          }
        } else if (not node_is_corner[node_id]) {
        }
      }

      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& node_list = bc.nodeList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId node_id     = node_list[i_node];
                const auto& node_to_cell = node_to_cell_matrix[node_id];

                for (size_t cell_id_j = 0; cell_id_j < node_to_cell.size(); ++cell_id_j) {
                  const CellId J                          = node_to_cell[cell_id_j];
                  const auto& node_local_number_in_j_cell = node_local_numbers_in_their_cells.itemArray(J);
                  const size_t node_id_j                  = node_local_number_in_j_cell[node_id];

                  for (size_t cell_id_k = 0; cell_id_k < node_to_cell.size(); ++cell_id_k) {
                    const CellId K                          = node_to_cell[cell_id_k];
                    const auto& node_local_number_in_k_cell = node_local_numbers_in_their_cells.itemArray(K);
                    const size_t node_id_k                  = node_local_number_in_k_cell[node_id];

                    const TinyMatrix<Dimension> I   = identity;
                    const TinyVector<Dimension> n   = exterior_normal[node_id];
                    const TinyMatrix<Dimension> nxn = tensorProduct(n, n);
                    const TinyMatrix<Dimension> P   = I - nxn;

                    S(J, K) += dot(kappar_invBetar[node_id] * Cjr(K, node_id_k), P * Cjr(J, node_id_j));
                  }
                }
              }
            } else if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& node_list = bc.faceList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId node_id     = node_list[i_node];
                const auto& node_to_cell = node_to_cell_matrix[node_id];

                for (size_t cell_id_j = 0; cell_id_j < node_to_cell.size(); ++cell_id_j) {
                  const CellId J                          = node_to_cell[cell_id_j];
                  const auto& node_local_number_in_j_cell = node_local_numbers_in_their_cells.itemArray(J);
                  const size_t node_id_j                  = node_local_number_in_j_cell[node_id];

                  for (size_t cell_id_k = 0; cell_id_k < node_to_cell.size(); ++cell_id_k) {
                    const CellId K                          = node_to_cell[cell_id_k];
                    const auto& node_local_number_in_k_cell = node_local_numbers_in_their_cells.itemArray(K);
                    const size_t node_id_k                  = node_local_number_in_k_cell[node_id];

                    S(J, K) += dot(kappar_invBetar[node_id] * Cjr(K, node_id_k), Cjr(J, node_id_j));
                  }
                }
              }
            }
          },
          boundary_condition);
      }

      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        S(cell_id, cell_id) += Vj[cell_id];
      }

      CellValue<const double> fj = f->cellValues();

      CRSMatrix A{S.getCRSMatrix()};
      Vector<double> b{mesh->numberOfCells()};
      b = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        b[cell_id] = fj[cell_id] * Vj[cell_id];
      }

      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {   // To do
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id = face_list[i_face];
                b[face_id] += value_list[i_face];
              }
            } else if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& node_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                NodeId node_id                = node_list[i_node];
                const auto& node_to_cell      = node_to_cell_matrix[node_id];
                const TinyVector<Dimension> n = exterior_normal[node_id];

                for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
                  CellId J                              = node_to_cell[i_cell];
                  const auto& node_local_number_in_cell = node_local_numbers_in_their_cells.itemArray(J);
                  const size_t node_id_j                = node_local_number_in_cell[node_id];
                  b[J] -= dot(Cjr(J, node_id_j), n) * value_list[i_node];
                }
              }
            }
          },
          boundary_condition);
      }

      Vector<double> T{mesh->numberOfCells()};
      T = zero;

      LinearSolver solver;
      solver.solveLocalSystem(A, T, b);

      m_solution     = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
      auto& solution = *m_solution;
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { solution[cell_id] = T[cell_id]; });
    }
  }
};

std::shared_ptr<const IDiscreteFunction>
ScalarNodalSchemeHandler::solution() const
{
  return m_scheme->getSolution();
}

ScalarNodalSchemeHandler::ScalarNodalSchemeHandler(
  const std::shared_ptr<const IDiscreteFunction>& alpha,
  const std::shared_ptr<const IDiscreteFunction>& cell_k_bound,
  const std::shared_ptr<const IDiscreteFunction>& cell_k,
  const std::shared_ptr<const IDiscreteFunction>& f,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  const std::shared_ptr i_mesh = getCommonMesh({alpha, f});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }

  checkDiscretizationType({alpha, cell_k_bound, cell_k, f}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    throw NormalError("The scheme is not implemented in dimension 1");
    break;
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<2, TinyMatrix<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    m_scheme =
      std::make_unique<ScalarNodalScheme<2>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_bound),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  case 3: {
    using MeshType                   = Mesh<Connectivity<3>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<3, TinyMatrix<3>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    m_scheme =
      std::make_unique<ScalarNodalScheme<3>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_bound),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

ScalarNodalSchemeHandler::~ScalarNodalSchemeHandler() = default;