#include <scheme/ScalarNodalScheme.hpp>

#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

// Dimension-agnostic base of ScalarNodalScheme<Dimension>: it only exposes
// the solution computed by the concrete, dimension-specific scheme.
class ScalarNodalSchemeHandler::IScalarNodalScheme
{
 public:
  IScalarNodalScheme() = default;

  // Virtual so that a concrete ScalarNodalScheme<Dimension> is destroyed
  // correctly when owned through a pointer to this interface.
  virtual ~IScalarNodalScheme() = default;

  // Returns the discrete solution held by the concrete scheme.
  virtual std::shared_ptr<const IDiscreteFunction> getSolution() const = 0;
};

template <size_t Dimension>
class ScalarNodalSchemeHandler::ScalarNodalScheme : public ScalarNodalSchemeHandler::IScalarNodalScheme
{
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_solution;

  // Dirichlet boundary condition: one prescribed value per boundary node.
  // The faces of the same boundary are stored alongside and exposed through
  // faceList().
  class DirichletBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;
    const Array<const NodeId> m_node_list;

   public:
    // value_list[i] is the value attached to node_list[i], so both arrays
    // must have the same size (checked by Assert).
    DirichletBoundaryCondition(const Array<const FaceId>& face_list,
                               const Array<const NodeId>& node_list,
                               const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}, m_node_list{node_list}
    {
      Assert(m_value_list.size() == m_node_list.size());
    }

    ~DirichletBoundaryCondition() = default;

    // Boundary nodes carrying a prescribed value.
    const Array<const NodeId>& nodeList() const { return m_node_list; }

    // Faces of the same boundary.
    const Array<const FaceId>& faceList() const { return m_face_list; }

    // Prescribed values, indexed like nodeList().
    const Array<const double>& valueList() const { return m_value_list; }
  };

  // Neumann boundary condition: one prescribed value per boundary face.
  // The nodes of the same boundary are stored alongside and exposed through
  // nodeList().
  class NeumannBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;
    const Array<const NodeId> m_node_list;

   public:
    // value_list[i] is the value attached to face_list[i], so both arrays
    // must have the same size (checked by Assert).
    NeumannBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const NodeId>& node_list,
                             const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}, m_node_list{node_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~NeumannBoundaryCondition() = default;

    // Nodes of the same boundary.
    const Array<const NodeId>& nodeList() const { return m_node_list; }

    // Boundary faces carrying a prescribed value.
    const Array<const FaceId>& faceList() const { return m_face_list; }

    // Prescribed values, indexed like faceList().
    const Array<const double>& valueList() const { return m_value_list; }
  };

  // Fourier boundary condition: one coefficient and one value per boundary
  // face. NOTE(review): not yet constructed by the scheme (the fourier case
  // throws NotImplementedError), so its exact mathematical meaning is not
  // established by this file.
  class FourierBoundaryCondition
  {
   private:
    const Array<const double> m_coef_list;
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    // coef_list[i] and value_list[i] belong to face_list[i]; all three arrays
    // must have the same size (checked by Assert).
    FourierBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const double>& coef_list,
                             const Array<const double>& value_list)
      : m_coef_list{coef_list}, m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_coef_list.size() == m_face_list.size());
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~FourierBoundaryCondition() = default;

    // Boundary faces the condition applies to.
    const Array<const FaceId>& faceList() const { return m_face_list; }

    // Per-face values, indexed like faceList().
    const Array<const double>& valueList() const { return m_value_list; }

    // Per-face coefficients, indexed like faceList().
    const Array<const double>& coefList() const { return m_coef_list; }
  };

 public:
  std::shared_ptr<const IDiscreteFunction>
  getSolution() const final
  {
    return m_solution;
  }

  ScalarNodalScheme(const std::shared_ptr<const MeshType>& mesh,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& f,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& f_prev,
                    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_prev_descriptor_list,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& P,
                    const double& dt,
                    const double& cn_coeff)
  {
    using BoundaryCondition =
      std::variant<DirichletBoundaryCondition, FourierBoundaryCondition, NeumannBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();
          // MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::node>(g_id, mesh->xr(),
                                                                                                      mesh_node_boundary
                                                                                                        .nodeList());
          boundary_condition_list.push_back(
            DirichletBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Dirichlet BC in 1d");
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fourier: {
        throw NotImplementedError("NIY");
        break;
      }
      case IBoundaryConditionDescriptor::Type::neumann: {
        const NeumannBoundaryConditionDescriptor& neumann_bc_descriptor =
          dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = neumann_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(
            NeumannBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Neumann BC in 1d");
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for heat equation";
        throw NormalError(error_msg.str());
      }
    }

    BoundaryConditionList boundary_prev_condition_list;

    for (const auto& bc_prev_descriptor : bc_prev_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_prev_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_prev_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_prev_descriptor);
        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_prev_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_prev_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = dirichlet_bc_prev_descriptor.rhsSymbolId();
          //          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::node>(g_id, mesh->xr(),
                                                                                                      mesh_node_boundary
                                                                                                        .nodeList());
          boundary_prev_condition_list.push_back(
            DirichletBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Dirichlet BC in 1d");
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fourier: {
        throw NotImplementedError("NIY");
        break;
      }
      case IBoundaryConditionDescriptor::Type::neumann: {
        const NeumannBoundaryConditionDescriptor& neumann_bc_prev_descriptor =
          dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_prev_descriptor);

        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, neumann_bc_prev_descriptor.boundaryDescriptor());
          MeshNodeBoundary<Dimension> mesh_node_boundary =
            getMeshNodeBoundary(*mesh, neumann_bc_prev_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = neumann_bc_prev_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_prev_condition_list.push_back(
            NeumannBoundaryCondition{mesh_face_boundary.faceList(), mesh_node_boundary.nodeList(), value_list});

        } else {
          throw NotImplementedError("Neumann BC in 1d");
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_prev_descriptor << " is an invalid boundary condition for heat equation";
        throw NormalError(error_msg.str());
      }
    }

    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();

    const NodeValuePerCell<const TinyVector<Dimension>>& Cjr = mesh_data.Cjr();

    const auto is_boundary_node = mesh->connectivity().isBoundaryNode();
    const auto is_boundary_face = mesh->connectivity().isBoundaryFace();

    const auto& face_to_cell_matrix               = mesh->connectivity().faceToCellMatrix();
    const auto& node_to_face_matrix               = mesh->connectivity().nodeToFaceMatrix();
    const auto& face_to_node_matrix               = mesh->connectivity().faceToNodeMatrix();
    const auto& cell_to_node_matrix               = mesh->connectivity().cellToNodeMatrix();
    const auto& node_local_numbers_in_their_cells = mesh->connectivity().nodeLocalNumbersInTheirCells();
    const CellValue<const double> Vj              = mesh_data.Vj();
    const auto& node_to_cell_matrix               = mesh->connectivity().nodeToCellMatrix();
    const auto& nl                                = mesh_data.nl();

    const NodeValue<const bool> node_is_neumann = [&] {
      NodeValue<bool> compute_node_is_neumann{mesh->connectivity()};
      compute_node_is_neumann.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list = bc.faceList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];
                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id             = face_nodes[i_node];
                  compute_node_is_neumann[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_neumann;
    }();

    const NodeValue<const bool> node_is_dirichlet = [&] {
      NodeValue<bool> compute_node_is_dirichlet{mesh->connectivity()};
      compute_node_is_dirichlet.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& node_list = bc.nodeList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId node_id = node_list[i_node];

                compute_node_is_dirichlet[node_id] = true;
              }
            }
          },
          boundary_condition);
      }
      return compute_node_is_dirichlet;
    }();

    const NodeValue<const bool> node_is_corner = [&] {
      NodeValue<bool> compute_node_is_corner{mesh->connectivity()};
      compute_node_is_corner.fill(false);
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          if (is_boundary_node[node_id]) {
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const unsigned int i_node = node_local_number_in_its_cell[i_cell];
              const CellId cell_id      = node_to_cell[i_cell];
              const auto& cell_nodes    = cell_to_node_matrix[cell_id];
              NodeId i_node_prev;
              NodeId i_node_next;
              if (i_node == 0) {
                i_node_prev = cell_nodes.size() - 1;
              } else {
                i_node_prev = i_node - 1;
              }
              if (i_node == cell_nodes.size() - 1) {
                i_node_next = 0;
              } else {
                i_node_next = i_node + 1;
              }
              const NodeId prev_node_id = cell_to_node_matrix[cell_id][i_node_prev];
              const NodeId next_node_id = cell_to_node_matrix[cell_id][i_node_next];
              if (is_boundary_node[prev_node_id] and is_boundary_node[next_node_id]) {
                compute_node_is_corner[node_id] = true;
              }
            }
          }
        });
      return compute_node_is_corner;
    }();

    const NodeValue<const bool> node_is_angle = [&] {
      NodeValue<bool> compute_node_is_angle{mesh->connectivity()};
      compute_node_is_angle.fill(false);
      // for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          if (is_boundary_node[node_id]) {
            const auto& node_to_face = node_to_face_matrix[node_id];

            TinyVector<Dimension> n1 = zero;
            TinyVector<Dimension> n2 = zero;

            for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
              FaceId face_id = node_to_face[i_face];
              if (is_boundary_face[face_id]) {
                if (l2Norm(n1) == 0) {
                  n1 = nl[face_id];
                } else {
                  n2 = nl[face_id];
                }
              }
            }
            if (l2Norm(n1 - n2) > (1.E-15) and l2Norm(n1 + n2) > (1.E-15) and not node_is_corner[node_id]) {
              compute_node_is_angle[node_id] = true;
            }
            // std::cout << node_id << "  " << n1 << "  " << n2 << "  " << compute_node_is_angle[node_id] << "\n";
          }
        });
      return compute_node_is_angle;
    }();

    const NodeValue<const TinyVector<Dimension>> exterior_normal = [&] {
      NodeValue<TinyVector<Dimension>> compute_exterior_normal{mesh->connectivity()};
      compute_exterior_normal.fill(zero);
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        if (is_boundary_node[node_id]) {
          const auto& node_to_cell                  = node_to_cell_matrix[node_id];
          const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
          for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
            const CellId cell_id      = node_to_cell[i_cell];
            const unsigned int i_node = node_local_number_in_its_cell[i_cell];
            compute_exterior_normal[node_id] += Cjr(cell_id, i_node);
          }
          const double norm_exterior_normal = l2Norm(compute_exterior_normal[node_id]);
          compute_exterior_normal[node_id] *= 1. / norm_exterior_normal;
          // std::cout << node_id << "  " << compute_exterior_normal[node_id] << "\n";
        }
      }
      return compute_exterior_normal;
    }();

    // for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
    //  std::cout << node_id << "  " << exterior_normal[node_id] << "  " << node_is_angle[node_id] << "\n";
    //};

    const FaceValue<const double> mes_l = [&] {
      if constexpr (Dimension == 1) {
        FaceValue<double> compute_mes_l{mesh->connectivity()};
        compute_mes_l.fill(1);
        return compute_mes_l;
      } else {
        return mesh_data.ll();
      }
    }();

    const NodeValue<const double> node_boundary_values = [&] {
      NodeValue<double> compute_node_boundary_values{mesh->connectivity()};
      NodeValue<double> sum_mes_l{mesh->connectivity()};
      compute_node_boundary_values.fill(0);
      sum_mes_l.fill(0);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];
                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (not node_is_dirichlet[node_id]) {
                    if (node_is_angle[node_id]) {
                      compute_node_boundary_values[node_id] +=
                        value_list[i_face] * std::abs(dot(nl[face_id], exterior_normal[node_id]));

                    } else {
                      compute_node_boundary_values[node_id] += value_list[i_face] * mes_l[face_id];
                      sum_mes_l[node_id] += mes_l[face_id];
                    }
                  }
                }
              }

            } else if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId node_id = node_list[i_node];

                compute_node_boundary_values[node_id] = value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        if ((not node_is_dirichlet[node_id]) && (node_is_neumann[node_id]) and not node_is_angle[node_id]) {
          compute_node_boundary_values[node_id] /= sum_mes_l[node_id];
        }
        // std::cout << node_id << "  " << compute_node_boundary_values[node_id] << "\n";
      }
      return compute_node_boundary_values;
    }();

    const NodeValue<const double> node_boundary_prev_values = [&] {
      NodeValue<double> compute_node_boundary_values{mesh->connectivity()};
      NodeValue<double> sum_mes_l{mesh->connectivity()};
      compute_node_boundary_values.fill(0);
      sum_mes_l.fill(0);
      for (const auto& boundary_condition : boundary_prev_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, NeumannBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id   = face_list[i_face];
                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  const NodeId node_id = face_nodes[i_node];
                  if (not node_is_dirichlet[node_id]) {
                    compute_node_boundary_values[node_id] += value_list[i_face] * mes_l[face_id];
                    sum_mes_l[node_id] += mes_l[face_id];
                  }
                }
              }
            } else if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId node_id = node_list[i_node];

                compute_node_boundary_values[node_id] = value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        if ((not node_is_dirichlet[node_id]) && (node_is_neumann[node_id])) {
          compute_node_boundary_values[node_id] /= sum_mes_l[node_id];
        }
      }
      return compute_node_boundary_values;
    }();

    {
      CellValue<const TinyMatrix<Dimension>> cell_kappaj = cell_k->cellValues();

      const NodeValue<const TinyMatrix<Dimension>> node_kappar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa{mesh->connectivity()};
        kappa.fill(zero);
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            kappa[node_id]           = zero;
            const auto& node_to_cell = node_to_cell_matrix[node_id];
            double weight            = 0;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              double local_weight  = 1. / l2Norm(xr[node_id] - xj[cell_id]);
              kappa[node_id] += local_weight * cell_kappaj[cell_id];
              weight += local_weight;
            }
            kappa[node_id] = 1. / weight * kappa[node_id];
          });
        return kappa;
      }();

      const NodeValue<const TinyMatrix<Dimension>> node_betar = [&] {
        NodeValue<TinyMatrix<Dimension>> beta{mesh->connectivity()};
        beta.fill(zero);
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            beta[node_id]                             = zero;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id      = node_to_cell[i_cell];
              const unsigned int i_node = node_local_number_in_its_cell[i_cell];
              beta[node_id] += tensorProduct(Cjr(cell_id, i_node), xr[node_id] - xj[cell_id]);
            }
          });
        return beta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> kappar_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (not node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kappar[node_id] * inverse(node_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> corner_betar = [&] {
        NodeValue<TinyMatrix<Dimension>> beta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (node_is_corner[node_id]) {
              const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
              const auto& node_to_cell                  = node_to_cell_matrix[node_id];

              size_t i_cell             = 0;
              const CellId cell_id      = node_to_cell[i_cell];
              const unsigned int i_node = node_local_number_in_its_cell[i_cell];

              const auto& cell_nodes  = cell_to_node_matrix[cell_id];
              const NodeId node_id_p1 = cell_to_node_matrix[cell_id][(i_node + 1) % cell_nodes.size()];
              const NodeId node_id_p2 = cell_to_node_matrix[cell_id][(i_node + 2) % cell_nodes.size()];
              const NodeId node_id_m1 = cell_to_node_matrix[cell_id][(i_node - 1) % cell_nodes.size()];

              const TinyVector<Dimension> xj1 = 1. / 3. * (xr[node_id] + xr[node_id_m1] + xr[node_id_p2]);
              const TinyVector<Dimension> xj2 = 1. / 3. * (xr[node_id] + xr[node_id_p2] + xr[node_id_p1]);

              const TinyVector<Dimension> xrm1 = xr[node_id_m1];
              const TinyVector<Dimension> xrp1 = xr[node_id_p1];
              const TinyVector<Dimension> xrp2 = xr[node_id_p2];

              TinyVector<Dimension> Cjr1;
              TinyVector<Dimension> Cjr2;

              if (Dimension == 2) {
                Cjr1[0] = -0.5 * (xrm1[1] - xrp2[1]);
                Cjr1[1] = 0.5 * (xrm1[0] - xrp2[0]);
                Cjr2[0] = -0.5 * (xrp2[1] - xrp1[1]);
                Cjr2[1] = 0.5 * (xrp2[0] - xrp1[0]);
              } else {
                throw NotImplementedError("The scheme is not implemented in this dimension.");
              }

              beta[node_id] = tensorProduct(Cjr1, (xr[node_id] - xj1)) + tensorProduct(Cjr2, (xr[node_id] - xj2));
            }
          });
        return beta;
      }();

      const NodeValue<const TinyMatrix<Dimension>> corner_kappar_invBetar = [&] {
        NodeValue<TinyMatrix<Dimension>> kappa_invBeta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (node_is_corner[node_id]) {
              kappa_invBeta[node_id] = node_kappar[node_id] * inverse(corner_betar[node_id]);
            }
          });
        return kappa_invBeta;
      }();

      const NodeValue<const TinyVector<Dimension>> sum_Cjr = [&] {
        NodeValue<TinyVector<Dimension>> compute_sum_Cjr{mesh->connectivity()};
        compute_sum_Cjr.fill(zero);
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            const auto& node_to_cell                  = node_to_cell_matrix[node_id];
            const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
            compute_sum_Cjr[node_id]                  = zero;
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              const size_t i_node  = node_local_number_in_its_cell[i_cell];
              compute_sum_Cjr[node_id] += Cjr(cell_id, i_node);
            }
          });
        return compute_sum_Cjr;
      }();

      const NodeValue<const TinyVector<Dimension>> v = [&] {
        NodeValue<TinyVector<Dimension>> compute_v{mesh->connectivity()};
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if (is_boundary_node[node_id]) {
              compute_v[node_id] = 1. / l2Norm(node_kappar[node_id] * exterior_normal[node_id]) * node_kappar[node_id] *
                                   exterior_normal[node_id];
            }
          });
        return compute_v;
      }();

      const NodeValuePerCell<const double> theta = [&] {
        NodeValuePerCell<double> compute_theta{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
            const auto& cell_nodes = cell_to_node_matrix[cell_id];
            for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
              const NodeId node_id = cell_nodes[i_node];
              if (is_boundary_node[node_id] && not node_is_corner[node_id]) {
                compute_theta(cell_id, i_node) = dot(inverse(node_betar[node_id]) * Cjr(cell_id, i_node), v[node_id]);
              }
            }
          });
        return compute_theta;
      }();

      const NodeValue<const double> sum_theta = [&] {
        NodeValue<double> compute_sum_theta{mesh->connectivity()};
        compute_sum_theta.fill(0);
        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
            if ((is_boundary_node[node_id]) && (not node_is_corner[node_id])) {
              const auto& node_to_cell                  = node_to_cell_matrix[node_id];
              const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
              compute_sum_theta[node_id]                = 0;
              for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
                const CellId cell_id = node_to_cell[i_cell];
                const size_t i_node  = node_local_number_in_its_cell[i_cell];
                compute_sum_theta[node_id] += theta(cell_id, i_node);
              }
            }
          });
        return compute_sum_theta;
      }();

      const Array<int> non_zeros{mesh->numberOfCells()};
      non_zeros.fill(4);   // Modif pour optimiser
      CRSMatrixDescriptor<double> S1(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);
      CRSMatrixDescriptor<double> S2(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);
      CRSMatrixDescriptor<double> C(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);

      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell = node_to_cell_matrix[node_id];
        for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
          const CellId cell_id_j = node_to_cell[i_cell_j];
          for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
            const CellId cell_id_k   = node_to_cell[i_cell_k];
            S1(cell_id_j, cell_id_k) = 0;
            S2(cell_id_j, cell_id_k) = 0;
            C(cell_id_j, cell_id_k)  = 0;
          }
        }
      }

      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell                   = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
        if (not is_boundary_node[node_id]) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S1(cell_id_j, cell_id_k) +=
                dt * (1. - cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
              S2(cell_id_j, cell_id_k) -=
                dt * (cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        } else if ((node_is_neumann[node_id]) && (not node_is_corner[node_id]) && (not node_is_dirichlet[node_id])) {
          const TinyMatrix<Dimension> I   = identity;
          const TinyVector<Dimension> n   = exterior_normal[node_id];
          const TinyMatrix<Dimension> nxn = tensorProduct(n, n);
          const TinyMatrix<Dimension> Q   = I - nxn;

          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S1(cell_id_j, cell_id_k) +=
                dt * (1. - cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] *
                      (Cjr(cell_id_k, i_node_k) - theta(cell_id_k, i_node_k) / sum_theta[node_id] * sum_Cjr[node_id]),
                    Q * Cjr(cell_id_j, i_node_j));
              S2(cell_id_j, cell_id_k) -=
                dt * (cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] *
                      (Cjr(cell_id_k, i_node_k) - theta(cell_id_k, i_node_k) / sum_theta[node_id] * sum_Cjr[node_id]),
                    Q * Cjr(cell_id_j, i_node_j));
            }
          }
        } else if ((node_is_dirichlet[node_id]) && (not node_is_corner[node_id])) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S1(cell_id_j, cell_id_k) +=
                dt * (1. - cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
              S2(cell_id_j, cell_id_k) -=
                dt * (cn_coeff / 2.) *
                dot(kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        } else if (node_is_dirichlet[node_id] && node_is_corner[node_id]) {
          for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
            const CellId cell_id_j = node_to_cell[i_cell_j];
            const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

            for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
              const CellId cell_id_k = node_to_cell[i_cell_k];
              const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

              S1(cell_id_j, cell_id_k) +=
                dt * (1. - cn_coeff / 2.) *
                dot(corner_kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
              S2(cell_id_j, cell_id_k) -=
                dt * (cn_coeff / 2.) *
                dot(corner_kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k), Cjr(cell_id_j, i_node_j));
            }
          }
        }
      }

      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (not is_boundary_face[face_id]) {
          const auto& face_to_cell = face_to_cell_matrix[face_id];
          CellId cell_id_j         = face_to_cell[0];
          CellId cell_id_k         = face_to_cell[1];
          C(cell_id_j, cell_id_j) += dt * mes_l[face_id] / l2Norm(xj[cell_id_j] - xj[cell_id_k]);
          C(cell_id_k, cell_id_k) += dt * mes_l[face_id] / l2Norm(xj[cell_id_j] - xj[cell_id_k]);
          C(cell_id_j, cell_id_k) -= dt * mes_l[face_id] / l2Norm(xj[cell_id_j] - xj[cell_id_k]);
          C(cell_id_k, cell_id_j) -= dt * mes_l[face_id] / l2Norm(xj[cell_id_j] - xj[cell_id_k]);
        } else {
          const auto& face_to_cell = face_to_cell_matrix[face_id];
          CellId cell_id_j         = face_to_cell[0];
          C(cell_id_j, cell_id_j) += dt * mes_l[face_id] / l2Norm(xj[cell_id_j] - xl[face_id]);
        }
      }

      const double epsilon = 0.5;

      for (CellId cell_id_j = 0; cell_id_j < mesh->numberOfCells(); ++cell_id_j) {
        for (CellId cell_id_k = 0; cell_id_k < mesh->numberOfCells(); ++cell_id_k) {
          if (S1(cell_id_j, cell_id_k) != 0 or C(cell_id_j, cell_id_k) != 0) {
            S1(cell_id_j, cell_id_k) = (1. - epsilon) * S1(cell_id_j, cell_id_k) + epsilon * C(cell_id_j, cell_id_k);
          }
        }
      }

      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        S1(cell_id, cell_id) += Vj[cell_id];
        S2(cell_id, cell_id) += Vj[cell_id];
      }

      CellValue<const double> fj      = f->cellValues();
      CellValue<const double> fj_prev = f_prev->cellValues();
      CellValue<const double> Pj      = P->cellValues();

      Vector<double> Ph{mesh->numberOfCells()};
      Ph = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        Ph[cell_id] = Pj[cell_id];
      };

      CRSMatrix A1{S1.getCRSMatrix()};
      CRSMatrix A2{S2.getCRSMatrix()};

      // std::cout << A1 << "\n";

      Vector<double> b{mesh->numberOfCells()};
      b = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        b[cell_id] = Vj[cell_id] * fj[cell_id];
      };

      Vector<double> b_prev{mesh->numberOfCells()};
      b_prev = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        b_prev[cell_id] = Vj[cell_id] * fj_prev[cell_id];
      };

      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell                   = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
        const TinyMatrix<Dimension> I              = identity;
        const TinyVector<Dimension> n              = exterior_normal[node_id];
        const TinyMatrix<Dimension> nxn            = tensorProduct(n, n);
        const TinyMatrix<Dimension> Q              = I - nxn;
        if ((node_is_neumann[node_id]) && (not node_is_dirichlet[node_id])) {
          if ((is_boundary_node[node_id]) and (not node_is_corner[node_id])) {
            for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
              const CellId cell_id = node_to_cell[i_cell];
              const size_t i_node  = node_local_number_in_its_cells[i_cell];

              b[cell_id] += node_boundary_values[node_id] *
                            (1. / (sum_theta[node_id] * l2Norm(node_kappar[node_id] * n)) *
                               dot(kappar_invBetar[node_id] * sum_Cjr[node_id], Q * Cjr(cell_id, i_node)) +
                             dot(Cjr(cell_id, i_node), n));
              b_prev[cell_id] += node_boundary_prev_values[node_id] *
                                 (1. / (sum_theta[node_id] * l2Norm(node_kappar[node_id] * n)) *
                                    dot(kappar_invBetar[node_id] * sum_Cjr[node_id], Q * Cjr(cell_id, i_node)) +
                                  dot(Cjr(cell_id, i_node), n));
            }
          } else if (node_is_corner[node_id]) {
            const auto& node_to_face = node_to_face_matrix[node_id];
            const CellId cell_id     = node_to_cell[0];
            double sum_mes_l         = 0;
            for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
              FaceId face_id = node_to_face[i_face];
              sum_mes_l += mes_l[face_id];
            }
            b[cell_id] += 0.5 * node_boundary_values[node_id] * sum_mes_l;
            b_prev[cell_id] += 0.5 * node_boundary_prev_values[node_id] * sum_mes_l;
          }
        } else if (node_is_dirichlet[node_id]) {
          if (not node_is_corner[node_id]) {
            for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
              const CellId cell_id_j = node_to_cell[i_cell_j];
              const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

              for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
                const CellId cell_id_k = node_to_cell[i_cell_k];
                const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

                b[cell_id_j] += dot(node_boundary_values[node_id] * kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                                    Cjr(cell_id_j, i_node_j));
                b_prev[cell_id_j] +=
                  dot(node_boundary_prev_values[node_id] * kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                      Cjr(cell_id_j, i_node_j));
              }
            }

          } else if (node_is_corner[node_id]) {
            for (size_t i_cell_j = 0; i_cell_j < node_to_cell.size(); ++i_cell_j) {
              const CellId cell_id_j = node_to_cell[i_cell_j];
              const size_t i_node_j  = node_local_number_in_its_cells[i_cell_j];

              for (size_t i_cell_k = 0; i_cell_k < node_to_cell.size(); ++i_cell_k) {
                const CellId cell_id_k = node_to_cell[i_cell_k];
                const size_t i_node_k  = node_local_number_in_its_cells[i_cell_k];

                b[cell_id_j] +=
                  dot(node_boundary_values[node_id] * corner_kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                      Cjr(cell_id_j, i_node_j));
                b_prev[cell_id_j] +=
                  dot(node_boundary_prev_values[node_id] * corner_kappar_invBetar[node_id] * Cjr(cell_id_k, i_node_k),
                      Cjr(cell_id_j, i_node_j));
              }
            }
          }
        }
      };

      Vector<double> T{mesh->numberOfCells()};
      T = zero;

      // Array<double> Ph2  = Ph;
      Vector<double> A2P = A2 * Ph;

      Vector<double> B{mesh->numberOfCells()};
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          B[cell_id] = dt * ((1. - cn_coeff / 2.) * b[cell_id] + cn_coeff / 2. * b_prev[cell_id]) + A2P[cell_id];
        });

      // std::cout << "g = " << node_boundary_values << "\n";

      NodeValue<TinyVector<Dimension>> ur{mesh->connectivity()};
      ur.fill(zero);
      for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
        const auto& node_to_cell                  = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
        if (not is_boundary_node[node_id]) {
          for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
            const CellId cell_id = node_to_cell[i_cell];
            const size_t i_node  = node_local_number_in_its_cell[i_cell];
            ur[node_id] += Pj[cell_id] * Cjr(cell_id, i_node);
          }
          ur[node_id] = -inverse(node_betar[node_id]) * ur[node_id];
          // std::cout << node_id << "; ur = " << ur[node_id] << "\n";
        } else if (is_boundary_node[node_id] and node_is_dirichlet[node_id]) {
          for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
            const CellId cell_id = node_to_cell[i_cell];
            const size_t i_node  = node_local_number_in_its_cell[i_cell];
            ur[node_id] += (node_boundary_values[node_id] - Pj[cell_id]) * Cjr(cell_id, i_node);
          }
          if (not node_is_corner[node_id]) {
            ur[node_id] = inverse(node_betar[node_id]) * ur[node_id];
          } else {
            ur[node_id] = inverse(corner_betar[node_id]) * ur[node_id];
          }
          // std::cout << "bord : " << node_id << "; ur = " << ur[node_id] << "\n";
        }
      }

      LinearSolver solver;
      solver.solveLocalSystem(A1, T, B);

      m_solution     = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
      auto& solution = *m_solution;
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { solution[cell_id] = T[cell_id]; });
    }
  }
};

std::shared_ptr<const IDiscreteFunction>
ScalarNodalSchemeHandler::solution() const
{
  return m_scheme->getSolution();
}

// Builds the dimension-specific scalar nodal diffusion scheme and solves it.
//
//  cell_k                  : P0 diffusion tensor (cast to DiscreteFunctionP0<2, TinyMatrix<2>>)
//  f, f_prev               : P0 scalar source terms at the current / previous time level
//  bc_descriptor_list      : boundary conditions at the current time level
//  bc_prev_descriptor_list : boundary conditions at the previous time level
//  P                       : P0 scalar unknown at the previous time level
//  dt, cn_coeff            : time step and Crank-Nicolson weighting coefficient
//
// Throws NormalError if the discrete functions do not share a mesh, are not
// P0, do not have the expected data type, or if the mesh dimension is not 2.
// Throws UnexpectedError for an invalid mesh dimension.
ScalarNodalSchemeHandler::ScalarNodalSchemeHandler(
  const std::shared_ptr<const IDiscreteFunction>& cell_k,
  const std::shared_ptr<const IDiscreteFunction>& f,
  const std::shared_ptr<const IDiscreteFunction>& f_prev,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_prev_descriptor_list,
  const std::shared_ptr<const IDiscreteFunction>& P,
  const double& dt,
  const double& cn_coeff)
{
  // The scheme reads the cell values of cell_k, f, f_prev and P using the cell
  // ids of a single mesh, so all four functions must be defined on that mesh
  // (checking only cell_k and f let mismatched f_prev / P through).
  const std::shared_ptr i_mesh = getCommonMesh({cell_k, f, f_prev, P});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }

  checkDiscretizationType({cell_k, f, f_prev, P}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    throw NormalError("The scheme is not implemented in dimension 1");
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<2, TinyMatrix<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // checkDiscretizationType only validates "P0", not the value type: a wrong
    // value type makes dynamic_pointer_cast return nullptr, which the scheme
    // would later dereference. Report it explicitly instead.
    auto kappa_h  = std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k);
    auto f_h      = std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f);
    auto f_prev_h = std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f_prev);
    auto P_h      = std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(P);

    if (not kappa_h) {
      throw NormalError("the diffusion coefficient must be a P0 function of R^2x2");
    }
    if (not(f_h and f_prev_h and P_h)) {
      throw NormalError("f, f_prev and P must be P0 scalar functions");
    }

    m_scheme = std::make_unique<ScalarNodalScheme<2>>(mesh, kappa_h, f_h, f_prev_h, bc_descriptor_list,
                                                      bc_prev_descriptor_list, P_h, dt, cn_coeff);
    break;
  }
  case 3: {
    throw NormalError("The scheme is not implemented in dimension 3");
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

// Defaulted out of line, in this file where IScalarNodalScheme is fully defined.
// m_scheme is assigned from std::make_unique in the constructor; presumably it is
// a std::unique_ptr<IScalarNodalScheme>, whose deleter requires a complete type
// at the point of destruction.
// NOTE(review): assumes the header only forward-declares IScalarNodalScheme — confirm.
ScalarNodalSchemeHandler::~ScalarNodalSchemeHandler() = default;