#include <scheme/ScalarNodalScheme.hpp>

#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>


/// Type-erased interface over the dimension-specific nodal schemes, so the
/// handler can hold any of them behind a single pointer.
class ScalarNodalSchemeHandler::IScalarNodalScheme
{
 public:
  IScalarNodalScheme()          = default;
  virtual ~IScalarNodalScheme() = default;

  /// Returns the discrete solution computed by the concrete scheme.
  virtual std::shared_ptr<const IDiscreteFunction> getSolution() const = 0;
};

/// Dimension-specific nodal scheme.
///
/// The constructor assembles and solves a linear system whose unknowns are one
/// value per cell followed by one value per boundary face that carries a
/// boundary condition; the cell part of the solution is stored in m_solution.
template <size_t Dimension>
class ScalarNodalSchemeHandler::ScalarNodalScheme : public ScalarNodalSchemeHandler::IScalarNodalScheme
{
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  // Cell values of the computed solution. Only assigned when Dimension > 1:
  // in 1d the constructor computes nothing and this stays null.
  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_solution;

  // Dirichlet condition: one prescribed value per boundary face.
  class DirichletBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    DirichletBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~DirichletBoundaryCondition() = default;
  };

  // Neumann condition: one value per boundary face; the value is multiplied by
  // the face measure when it is added to the right-hand side.
  class NeumannBoundaryCondition
  {
   private:
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    NeumannBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~NeumannBoundaryCondition() = default;
  };

  // Fourier condition: one coefficient and one value per boundary face.
  // Not yet produced by the constructor (it throws NotImplementedError).
  class FourierBoundaryCondition
  {
   private:
    const Array<const double> m_coef_list;
    const Array<const double> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    const Array<const double>&
    valueList() const
    {
      return m_value_list;
    }

    const Array<const double>&
    coefList() const
    {
      return m_coef_list;
    }

   public:
    FourierBoundaryCondition(const Array<const FaceId>& face_list,
                             const Array<const double>& coef_list,
                             const Array<const double>& value_list)
      : m_coef_list{coef_list}, m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_coef_list.size() == m_face_list.size());
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~FourierBoundaryCondition() = default;
  };

 public:
  std::shared_ptr<const IDiscreteFunction>
  getSolution() const final
  {
    return m_solution;
  }

  /// Builds and solves the linear system on @p mesh.
  ///
  /// @param mesh               primal mesh
  /// @param alpha              cell coefficient, added as alpha_j * V_j on the cell diagonal
  /// @param cell_k_bound       tensor coefficient averaged to nodes (node_kapparb)
  /// @param cell_k             tensor coefficient averaged to nodes (node_kappar)
  /// @param f                  cell source, assembled as f_j * V_j in the right-hand side
  /// @param bc_descriptor_list boundary conditions; Dirichlet and Neumann are supported
  ///                           when Dimension > 1
  /// @throws NotImplementedError for Fourier conditions, or any condition in 1d
  /// @throws NormalError for unsupported condition types, or a face used by two conditions
  ///
  /// NOTE(review): the handler requires cell_k / cell_k_bound to live on the diamond
  /// dual mesh, yet their cell values are indexed below with primal CellIds (taken
  /// from node_to_cell_matrix) — confirm which mesh these coefficients belong to.
  ScalarNodalScheme(const std::shared_ptr<const MeshType>& mesh,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k_bound,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyMatrix<Dimension>>>& cell_k,
                    const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& f,
                    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    using BoundaryCondition = std::variant<DirichletBoundaryCondition, FourierBoundaryCondition,
                                           NeumannBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    // Translate the descriptors into concrete per-face boundary conditions.
    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          // Evaluate g at the points mesh_data.xl() of the boundary faces.
          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});

        } else {
          throw NotImplementedError("Dirichlet BC in 1d");
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fourier: {
        throw NotImplementedError("NIY");
        break;
      }
      case IBoundaryConditionDescriptor::Type::neumann: {
        const NeumannBoundaryConditionDescriptor& neumann_bc_descriptor =
          dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFaceBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());

          const FunctionSymbolId g_id = neumann_bc_descriptor.rhsSymbolId();
          MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

          Array<const double> value_list =
            InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                      mesh_data.xl(),
                                                                                                      mesh_face_boundary
                                                                                                        .faceList());

          boundary_condition_list.push_back(NeumannBoundaryCondition{mesh_face_boundary.faceList(), value_list});

        } else {
          throw NotImplementedError("Neumann BC in 1d");
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for heat equation";
        throw NormalError(error_msg.str());
      }
    }

    if constexpr (Dimension > 1) {
      // Cells are numbered first: dof(cell j) = j.
      const CellValue<const size_t> cell_dof_number = [&] {
        CellValue<size_t> compute_cell_dof_number{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { compute_cell_dof_number[cell_id] = cell_id; });
        return compute_cell_dof_number;
      }();
      size_t number_of_dof = mesh->numberOfCells();

      // Boundary-condition faces get the next dofs; other faces keep the
      // sentinel numeric_limits<size_t>::max().
      const FaceValue<const size_t> face_dof_number = [&] {
        FaceValue<size_t> compute_face_dof_number{mesh->connectivity()};
        compute_face_dof_number.fill(std::numeric_limits<size_t>::max());
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NeumannBoundaryCondition>) or
                            (std::is_same_v<T, FourierBoundaryCondition>) or
                            (std::is_same_v<T, DirichletBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id = face_list[i_face];
                  if (compute_face_dof_number[face_id] != std::numeric_limits<size_t>::max()) {
                    std::ostringstream os;
                    os << "The face " << face_id << " is used at least twice for boundary conditions";
                    throw NormalError(os.str());
                  } else {
                    compute_face_dof_number[face_id] = number_of_dof++;
                  }
                }
              }
            },
            boundary_condition);
        }

        return compute_face_dof_number;
      }();

      // Faces carrying a Neumann or Fourier condition.
      const FaceValue<const bool> face_is_neumann = [&] {
        FaceValue<bool> compute_face_is_neumann{mesh->connectivity()};
        compute_face_is_neumann.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NeumannBoundaryCondition>) or
                            (std::is_same_v<T, FourierBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id             = face_list[i_face];
                  compute_face_is_neumann[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return compute_face_is_neumann;
      }();

      // Faces carrying a Dirichlet condition (used when assembling their identity rows).
      const FaceValue<const bool> face_is_dirichlet = [&] {
        FaceValue<bool> compute_face_is_dirichlet{mesh->connectivity()};
        compute_face_is_dirichlet.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id               = face_list[i_face];
                  compute_face_is_dirichlet[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return compute_face_is_dirichlet;
      }();

      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

      const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
      const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
      const NodeValue<const TinyVector<Dimension>>& xr = mesh_data.xr();

      const auto& node_to_face_matrix               = mesh->connectivity().nodeToFaceMatrix();
      const auto& face_to_cell_matrix               = mesh->connectivity().faceToCellMatrix();
      const auto& cell_to_node_matrix               = mesh->connectivity().cellToNodeMatrix();
      const auto& node_local_numbers_in_their_cells = mesh->connectivity().nodeLocalNumbersInTheirCells();

      const NodeValuePerCell<const TinyVector<Dimension>>& Cjr = mesh_data.Cjr();

      {
        CellValue<const TinyMatrix<Dimension>> cell_kappaj  = cell_k->cellValues();
        CellValue<const TinyMatrix<Dimension>> cell_kappajb = cell_k_bound->cellValues();

        const CellValue<const double> Vj = mesh_data.Vj();
        const auto& node_to_cell_matrix  = mesh->connectivity().nodeToCellMatrix();

        // Face measures (this path is only compiled for Dimension > 1).
        const FaceValue<const double> mes_l = mesh_data.ll();

        // Node tensor = average of the adjacent cell tensors weighted by 1/|x_r - x_j|.
        auto inverse_distance_average = [&](const CellValue<const TinyMatrix<Dimension>>& cell_tensor) {
          NodeValue<TinyMatrix<Dimension>> node_tensor{mesh->connectivity()};
          parallel_for(
            mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
              node_tensor[node_id]     = zero;
              const auto& node_to_cell = node_to_cell_matrix[node_id];
              double weight            = 0;
              for (size_t j = 0; j < node_to_cell.size(); ++j) {
                const CellId J      = node_to_cell[j];
                double local_weight = 1 / l2Norm(xr[node_id] - xj[J]);
                node_tensor[node_id] += cell_tensor[J] * local_weight;
                weight += local_weight;
              }
              node_tensor[node_id] /= weight;
            });
          return node_tensor;
        };

        const NodeValue<const TinyMatrix<Dimension>> node_kappar  = inverse_distance_average(cell_kappaj);
        const NodeValue<const TinyMatrix<Dimension>> node_kapparb = inverse_distance_average(cell_kappajb);

        // beta_r = sum over cells j around r of Cjr (x) (x_r - x_j).
        const NodeValue<const TinyMatrix<Dimension>> node_betar = [&] {
          NodeValue<TinyMatrix<Dimension>> beta{mesh->connectivity()};
          parallel_for(
            mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
              const auto& node_local_number_in_its_cell = node_local_numbers_in_their_cells.itemArray(node_id);
              const auto& node_to_cell                  = node_to_cell_matrix[node_id];
              beta[node_id]                             = zero;
              for (size_t j = 0; j < node_to_cell.size(); ++j) {
                const CellId J       = node_to_cell[j];
                const unsigned int R = node_local_number_in_its_cell[j];
                beta[node_id] += tensorProduct(Cjr(J, R), xr[node_id] - xj[J]);
              }
            });
          return beta;
        }();

        FaceValue<const double> alpha_l = [&] {
          FaceValue<double> alpha_j{mesh->connectivity()};

          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) { alpha_j[face_id] = 1; });

          return alpha_j;
        }();

        FaceValue<const double> alphab_l = [&] {
          FaceValue<double> alpha_lb{mesh->connectivity()};

          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
              alpha_lb[face_id] = 1;   // TODO: redo (placeholder value)
            });

          return alpha_lb;
        }();

        const Array<int> non_zeros{number_of_dof};
        non_zeros.fill(Dimension);   // TODO: refine the per-row non-zero estimate
        CRSMatrixDescriptor<double> S(number_of_dof, number_of_dof, non_zeros);

        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          const auto& face_to_cell = face_to_cell_matrix[face_id];
          // TODO(review): dualClj is not defined anywhere in this scheme; it must be
          // provided (e.g. from the dual mesh data) before this assembly can compile.
          const double beta_l  = l2Norm(dualClj[face_id]) * alpha_l[face_id] * mes_l[face_id];
          const double betab_l = l2Norm(dualClj[face_id]) * alphab_l[face_id] * mes_l[face_id];
          for (size_t i_cell = 0; i_cell < face_to_cell.size(); ++i_cell) {
            const CellId cell_id1 = face_to_cell[i_cell];
            for (size_t j_cell = 0; j_cell < face_to_cell.size(); ++j_cell) {
              const CellId cell_id2 = face_to_cell[j_cell];
              if (i_cell == j_cell) {
                S(cell_dof_number[cell_id1], cell_dof_number[cell_id2]) += beta_l;
                if (face_is_neumann[face_id]) {
                  S(face_dof_number[face_id], cell_dof_number[cell_id2]) -= betab_l;
                }
              } else {
                S(cell_dof_number[cell_id1], cell_dof_number[cell_id2]) -= beta_l;
              }
            }
          }
        }

        // Reaction term alpha_j * V_j on the cell diagonal.
        for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
          const size_t j = cell_dof_number[cell_id];
          S(j, j) += (*alpha)[cell_id] * Vj[cell_id];
        }

        // Identity rows for Dirichlet face unknowns.
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          if (face_is_dirichlet[face_id]) {
            S(face_dof_number[face_id], face_dof_number[face_id]) += 1;
          }
        }

        CellValue<const double> fj = f->cellValues();

        CRSMatrix A{S.getCRSMatrix()};
        Vector<double> b{number_of_dof};
        b = zero;
        for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
          b[cell_dof_number[cell_id]] = fj[cell_id] * Vj[cell_id];
        }
        // Dirichlet values on the face unknowns.
        {
          for (const auto& boundary_condition : boundary_condition_list) {
            std::visit(
              [&](auto&& bc) {
                using T = std::decay_t<decltype(bc)>;
                if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
                  const auto& face_list  = bc.faceList();
                  const auto& value_list = bc.valueList();
                  for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                    const FaceId face_id = face_list[i_face];
                    b[face_dof_number[face_id]] += value_list[i_face];
                  }
                }
              },
              boundary_condition);
          }
        }
        // Neumann/Fourier values, scaled by the face measure.
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NeumannBoundaryCondition>) or
                            (std::is_same_v<T, FourierBoundaryCondition>)) {
                const auto& face_list  = bc.faceList();
                const auto& value_list = bc.valueList();
                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  FaceId face_id = face_list[i_face];
                  b[face_dof_number[face_id]] += mes_l[face_id] * value_list[i_face];
                }
              }
            },
            boundary_condition);
        }

        Vector<double> T{number_of_dof};
        T = zero;

        LinearSolver solver;
        solver.solveLocalSystem(A, T, b);

        // Keep only the cell unknowns as the P0 solution.
        m_solution     = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
        auto& solution = *m_solution;
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { solution[cell_id] = T[cell_dof_number[cell_id]]; });
      }
    }
  }
};

std::shared_ptr<const IDiscreteFunction>
ScalarNodalSchemeHandler::solution() const
{
  return m_scheme->getSolution();
}

/// Checks the inputs and builds the dimension-specific ScalarNodalScheme.
///
/// @param alpha              P0 scalar function, must share its mesh with @p f
/// @param cell_k_bound       P0 tensor function, must share its mesh with @p cell_k
/// @param cell_k             P0 tensor function on the diamond dual of the primal mesh
/// @param f                  P0 scalar source on the primal mesh
/// @param bc_descriptor_list boundary condition descriptors forwarded to the scheme
/// @throws NormalError if the primal or dual functions do not share a mesh, or if the
///         dual mesh is not the diamond dual of the primal mesh
/// @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
ScalarNodalSchemeHandler::ScalarNodalSchemeHandler(
  const std::shared_ptr<const IDiscreteFunction>& alpha,
  const std::shared_ptr<const IDiscreteFunction>& cell_k_bound,
  const std::shared_ptr<const IDiscreteFunction>& cell_k,
  const std::shared_ptr<const IDiscreteFunction>& f,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  const std::shared_ptr i_mesh = getCommonMesh({alpha, f});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }
  const std::shared_ptr i_dual_mesh = getCommonMesh({cell_k_bound, cell_k});
  if (not i_dual_mesh) {
    throw NormalError("dual discrete functions are not defined on the same mesh");
  }
  checkDiscretizationType({alpha, cell_k_bound, cell_k, f}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    using MeshType                   = Mesh<Connectivity<1>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<1, TinyMatrix<1>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<1, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_scheme =
      std::make_unique<ScalarNodalScheme<1>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_bound),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<2, TinyMatrix<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_scheme =
      std::make_unique<ScalarNodalScheme<2>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_bound),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  case 3: {
    using MeshType                   = Mesh<Connectivity<3>>;
    using DiscreteTensorFunctionType = DiscreteFunctionP0<3, TinyMatrix<3>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_scheme =
      std::make_unique<ScalarNodalScheme<3>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k_bound),
                                             std::dynamic_pointer_cast<const DiscreteTensorFunctionType>(cell_k),
                                             std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(f),
                                             bc_descriptor_list);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

// Defined out of line in this file, presumably so that IScalarNodalScheme is a
// complete type where m_scheme is destroyed — confirm against the header.
ScalarNodalSchemeHandler::~ScalarNodalSchemeHandler() = default;