#include <scheme/DiscreteFunctionUtils.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/IMesh.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/DiscreteFunctionP0.hpp>

/// Builds a DiscreteFunctionP0 defined on @p mesh from the cell values of
/// @p discrete_function.
///
/// The new function is constructed from `discrete_function->cellValues()`.
/// Presumably the value container shares its underlying storage on copy,
/// hence "shallow". NOTE(review): confirm against the CellValue copy semantics.
///
/// @param mesh               target mesh for the new function
/// @param discrete_function  source P0 function; null when the caller's
///                           dynamic_pointer_cast to
///                           DiscreteFunctionP0<Dimension, DataType> failed
/// @return a new DiscreteFunctionP0<Dimension, DataType> built on @p mesh
/// @throws UnexpectedError if @p discrete_function is null (its declared data
///         type did not match its concrete type)
template <size_t Dimension, typename DataType>
std::shared_ptr<const IDiscreteFunction>
shallowCopy(const std::shared_ptr<const Mesh<Connectivity<Dimension>>>& mesh,
            const std::shared_ptr<const DiscreteFunctionP0<Dimension, DataType>>& discrete_function)
{
  // Callers hand us the raw result of a dynamic_pointer_cast; refuse a failed
  // downcast instead of dereferencing a null pointer.
  if (!discrete_function) {
    throw UnexpectedError("invalid discrete function: data type does not match the concrete function type");
  }

  return std::make_shared<DiscreteFunctionP0<Dimension, DataType>>(mesh, discrete_function->cellValues());
}

/// Dispatches a shallow copy of @p discrete_function onto @p mesh according
/// to the function's discretization kind and data type.
///
/// Only P0 functions are handled. Supported data types are double, TinyVector<1..3>
/// and square TinyMatrix<1..3>.
///
/// @param mesh               target mesh (already downcast to dimension Dimension;
///                           null if that downcast failed in the caller)
/// @param discrete_function  function to copy; its mesh must share the same
///                           connectivity object as @p mesh
/// @return the copied function, defined on @p mesh
/// @throws NormalError     if either mesh is not a Mesh<Connectivity<Dimension>>,
///                         if the connectivities differ, or if the
///                         discretization type is not P0
/// @throws UnexpectedError on an unsupported data type / vector dimension /
///                         matrix shape
template <size_t Dimension>
std::shared_ptr<const IDiscreteFunction>
shallowCopy(const std::shared_ptr<const Mesh<Connectivity<Dimension>>>& mesh,
            const std::shared_ptr<const IDiscreteFunction>& discrete_function)
{
  const std::shared_ptr function_mesh =
    std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(discrete_function->mesh());

  // Either downcast (here or in the caller) yields null when the mesh is not a
  // Mesh<Connectivity<Dimension>>; reject that before dereferencing below.
  if (!mesh || !function_mesh) {
    throw NormalError("incompatible mesh dimensions");
  }

  // A shallow copy only makes sense when both meshes share the very same
  // connectivity object (pointer comparison, not structural equality).
  if (mesh->shared_connectivity() != function_mesh->shared_connectivity()) {
    throw NormalError("incompatible connectivities");
  }

  switch (discrete_function->descriptor().type()) {
  case DiscreteFunctionType::P0: {
    switch (discrete_function->dataType()) {
    case ASTNodeDataType::double_t: {
      return shallowCopy(mesh,
                         std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, double>>(discrete_function));
    }
    case ASTNodeDataType::vector_t: {
      switch (discrete_function->dataType().dimension()) {
      case 1: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>>(
                                   discrete_function));
      }
      case 2: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>>(
                                   discrete_function));
      }
      case 3: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>>(
                                   discrete_function));
      }
      default: {
        throw UnexpectedError("invalid data vector dimension: " +
                              std::to_string(discrete_function->dataType().dimension()));
      }
      }
    }
    case ASTNodeDataType::matrix_t: {
      // Only square matrices are instantiated.
      if (discrete_function->dataType().nbRows() != discrete_function->dataType().nbColumns()) {
        throw UnexpectedError(
          "invalid data matrix dimensions: " + std::to_string(discrete_function->dataType().nbRows()) + "x" +
          std::to_string(discrete_function->dataType().nbColumns()));
      }
      switch (discrete_function->dataType().nbRows()) {
      case 1: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>>(
                                   discrete_function));
      }
      case 2: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>>(
                                   discrete_function));
      }
      case 3: {
        return shallowCopy(mesh, std::dynamic_pointer_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>>(
                                   discrete_function));
      }
      default: {
        throw UnexpectedError(
          "invalid data matrix dimensions: " + std::to_string(discrete_function->dataType().nbRows()) + "x" +
          std::to_string(discrete_function->dataType().nbColumns()));
      }
      }
    }
    default: {
      throw UnexpectedError("invalid kind of P0 function: invalid data type");
    }
    }
  }
  default: {
    throw NormalError("invalid discretization type");
  }
  }
}

/// Public entry point: returns a discrete function equivalent to
/// @p discrete_function but defined on @p mesh.
///
/// If @p mesh is the very same mesh object the function already lives on,
/// the input pointer is returned unchanged. Otherwise the mesh is downcast to
/// its concrete dimension and the dimension-specific overload does the work.
///
/// @throws NormalError     if the two meshes have different dimensions
/// @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
std::shared_ptr<const IDiscreteFunction>
shallowCopy(const std::shared_ptr<const IMesh>& mesh, const std::shared_ptr<const IDiscreteFunction>& discrete_function)
{
  // Identical mesh object: nothing to rebind.
  if (discrete_function->mesh() == mesh) {
    return discrete_function;
  }

  if (discrete_function->mesh()->dimension() != mesh->dimension()) {
    throw NormalError("incompatible mesh dimensions");
  }

  const auto dimension = mesh->dimension();
  if (dimension == 1) {
    return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<1>>>(mesh), discrete_function);
  } else if (dimension == 2) {
    return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<2>>>(mesh), discrete_function);
  } else if (dimension == 3) {
    return shallowCopy(std::dynamic_pointer_cast<const Mesh<Connectivity<3>>>(mesh), discrete_function);
  }

  throw UnexpectedError("invalid mesh dimension");
}
