#include <mesh/MeshTransformer.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/Mesh.hpp>

#include <language/utils/EvaluateAtPoints.hpp>

template <typename OutputType, typename InputType>
class MeshTransformer::MeshTransformation<OutputType(InputType)>
{
  // Spatial dimension of the produced mesh, taken from the output vector type.
  static constexpr size_t Dimension = OutputType::Dimension;

 public:
  // Builds a new mesh that shares the connectivity of *p_mesh. Its node
  // coordinates are obtained by evaluating the user function identified by
  // function_symbol_id at every node position of the given mesh.
  //
  // *p_mesh must actually be a Mesh<Connectivity<Dimension>>. Otherwise the
  // reference dynamic_cast below throws std::bad_cast.
  static inline std::shared_ptr<Mesh<Connectivity<Dimension>>>
  transform(const FunctionSymbolId& function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
  {
    using MeshType = Mesh<Connectivity<Dimension>>;

    // Checked downcast from the dimension-agnostic interface to the concrete mesh.
    const MeshType& source_mesh = dynamic_cast<const MeshType&>(*p_mesh);

    // Destination node values, allocated on the same connectivity as the source.
    NodeValue<OutputType> transformed_xr(source_mesh.connectivity());

    // Original node coordinates, used as evaluation points.
    NodeValue<const InputType> source_xr = source_mesh.xr();

    EvaluateAtPoints<OutputType(InputType)>::evaluateTo(function_symbol_id, source_xr, transformed_xr);

    // The new mesh reuses the shared connectivity; only the coordinates differ.
    return std::make_shared<MeshType>(source_mesh.shared_connectivity(), transformed_xr);
  }
};

std::shared_ptr<const IMesh>
MeshTransformer::transform(const FunctionSymbolId& function_id, std::shared_ptr<const IMesh> p_mesh)

{
  switch (p_mesh->dimension()) {
  case 1: {
    using TransformT = TinyVector<1>(TinyVector<1>);
    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
  }
  case 2: {
    using TransformT = TinyVector<2>(TinyVector<2>);
    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
  }
  case 3: {
    using TransformT = TinyVector<3>(TinyVector<3>);
    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}
