#include <scheme/DiscreteFunctionUtils.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/IMesh.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <utils/Stringify.hpp>

std::shared_ptr<const IMesh>
getCommonMesh(const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list)
{
  std::shared_ptr<const IMesh> i_mesh;
  bool is_same_mesh = true;
  for (const auto& discrete_function_variant : discrete_function_variant_list) {
    std::visit(
      [&](auto&& discrete_function) {
        if (not i_mesh.use_count()) {
          i_mesh = discrete_function.mesh();
        } else {
          if (i_mesh != discrete_function.mesh()) {
            is_same_mesh = false;
          }
        }
      },
      discrete_function_variant->discreteFunction());
  }
  if (not is_same_mesh) {
    i_mesh.reset();
  }
  return i_mesh;
}

bool
hasSameMesh(const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list)
{
  std::shared_ptr<const IMesh> i_mesh;
  bool is_same_mesh = true;
  for (const auto& discrete_function_variant : discrete_function_variant_list) {
    std::visit(
      [&](auto&& discrete_function) {
        if (not i_mesh.use_count()) {
          i_mesh = discrete_function.mesh();
        } else {
          if (i_mesh != discrete_function.mesh()) {
            is_same_mesh = false;
          }
        }
      },
      discrete_function_variant->discreteFunction());
  }

  return is_same_mesh;
}

// Builds a new discrete function of type DiscreteFunctionT that lives on `mesh`
// and shares the value storage (cell values or cell arrays) of `f`.
//
// mesh: target mesh, already cast to its concrete type by the caller (may be
//       empty if that cast failed).
// f:    discrete function whose values are reused.
//
// Throws NormalError when `mesh` and the mesh of `f` do not share the same
// connectivity. Throws UnexpectedError when the mesh types are inconsistent or
// the discrete function kind is not supported.
template <typename MeshType, typename DiscreteFunctionT>
std::shared_ptr<const DiscreteFunctionVariant>
shallowCopy([[maybe_unused]] const std::shared_ptr<const MeshType>& mesh, [[maybe_unused]] const DiscreteFunctionT& f)
{
  if constexpr (std::is_same_v<MeshType, typename DiscreteFunctionT::MeshType>) {
    // The cast is only meaningful once the mesh types are known to match.
    // Performing it (and dereferencing the result) earlier would dereference a
    // null pointer in every mismatched instantiation.
    const std::shared_ptr function_mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
    if ((not mesh) or (not function_mesh)) {
      throw UnexpectedError("invalid mesh types");
    }

    // Sharing value storage is only valid if both meshes use the same connectivity
    if (mesh->shared_connectivity() != function_mesh->shared_connectivity()) {
      throw NormalError("cannot shallow copy when connectivity changes");
    }

    if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
      return std::make_shared<DiscreteFunctionVariant>(DiscreteFunctionT(mesh, f.cellValues()));
    } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
      return std::make_shared<DiscreteFunctionVariant>(DiscreteFunctionT(mesh, f.cellArrays()));
    } else {
      throw UnexpectedError("invalid discrete function type");
    }
  } else {
    throw UnexpectedError("invalid mesh types");
  }
}

// Returns a discrete function equivalent to `discrete_function_variant` but
// attached to `mesh`, sharing the original value storage.
//
// If the function already lives on `mesh`, the input variant itself is
// returned. Otherwise both meshes must have the same dimension (NormalError
// "incompatible mesh dimensions" otherwise). The work is then delegated to the
// typed shallowCopy overload selected from that dimension.
std::shared_ptr<const DiscreteFunctionVariant>
shallowCopy(const std::shared_ptr<const IMesh>& mesh,
            const std::shared_ptr<const DiscreteFunctionVariant>& discrete_function_variant)
{
  auto rebind_on_mesh = [&](auto&& function) -> std::shared_ptr<const DiscreteFunctionVariant> {
    // Nothing to rebind: hand back the caller's variant unchanged
    if (function.mesh() == mesh) {
      return discrete_function_variant;
    }

    if (function.mesh()->dimension() != mesh->dimension()) {
      throw NormalError("incompatible mesh dimensions");
    }

    // Recover the concrete mesh type from the dimension and let the typed
    // overload check connectivity and perform the copy
    switch (mesh->dimension()) {
    case 1:
      return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<1>>>(mesh), function);
    case 2:
      return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<2>>>(mesh), function);
    case 3:
      return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<3>>>(mesh), function);
      // LCOV_EXCL_START
    default:
      throw UnexpectedError("invalid mesh dimension");
      // LCOV_EXCL_STOP
    }
  };

  return std::visit(rebind_on_mesh, discrete_function_variant->discreteFunction());
}
