#include <mesh/MeshRelaxer.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/Mesh.hpp>

template <typename ConnectivityType>
std::shared_ptr<const Mesh<ConnectivityType>>
MeshRelaxer::_relax(const Mesh<ConnectivityType>& source_mesh,
                    const Mesh<ConnectivityType>& destination_mesh,
                    const double& theta) const
{
  if (source_mesh.shared_connectivity() == destination_mesh.shared_connectivity()) {
    const ConnectivityType& connectivity = source_mesh.connectivity();
    NodeValue<TinyVector<ConnectivityType::Dimension>> theta_xr{connectivity};
    const NodeValue<const TinyVector<ConnectivityType::Dimension>> source_xr      = source_mesh.xr();
    const NodeValue<const TinyVector<ConnectivityType::Dimension>> destination_xr = destination_mesh.xr();

    const double one_minus_theta = 1 - theta;
    parallel_for(
      connectivity.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) {
        theta_xr[node_id] = one_minus_theta * source_xr[node_id] + theta * destination_xr[node_id];
      });

    return std::make_shared<Mesh<ConnectivityType>>(source_mesh.shared_connectivity(), theta_xr);
  } else {
    throw NormalError("relaxed meshes must share the same connectivity");
  }
}

/**
 * @brief Type-erased entry point: relaxes between two meshes of equal dimension.
 *
 * Dispatches on the mesh dimension (1, 2 or 3). Each mesh is downcast to
 * Mesh<Connectivity<d>>, and the work is forwarded to the templated _relax.
 * The pointers are dereferenced without a null check, so callers must pass
 * non-null meshes. A reference dynamic_cast throws std::bad_cast if a mesh
 * is not of the expected concrete type.
 *
 * @param p_source_mesh       mesh weighted by (1 - theta)
 * @param p_destination_mesh  mesh weighted by theta
 * @param theta               blending weight forwarded to _relax
 * @return the relaxed mesh, as an IMesh
 * @throws NormalError     if the two meshes have different dimensions
 * @throws UnexpectedError if the shared dimension is not 1, 2 or 3
 */
std::shared_ptr<const IMesh>
MeshRelaxer::relax(const std::shared_ptr<const IMesh>& p_source_mesh,
                   const std::shared_ptr<const IMesh>& p_destination_mesh,
                   const double& theta) const
{
  const auto dimension = p_source_mesh->dimension();

  if (dimension != p_destination_mesh->dimension()) {
    throw NormalError("incompatible mesh dimensions");
  }

  switch (dimension) {
  case 1: {
    using Mesh1D              = Mesh<Connectivity<1>>;
    const Mesh1D& source      = dynamic_cast<const Mesh1D&>(*p_source_mesh);
    const Mesh1D& destination = dynamic_cast<const Mesh1D&>(*p_destination_mesh);
    return this->_relax(source, destination, theta);
  }
  case 2: {
    using Mesh2D              = Mesh<Connectivity<2>>;
    const Mesh2D& source      = dynamic_cast<const Mesh2D&>(*p_source_mesh);
    const Mesh2D& destination = dynamic_cast<const Mesh2D&>(*p_destination_mesh);
    return this->_relax(source, destination, theta);
  }
  case 3: {
    using Mesh3D              = Mesh<Connectivity<3>>;
    const Mesh3D& source      = dynamic_cast<const Mesh3D&>(*p_source_mesh);
    const Mesh3D& destination = dynamic_cast<const Mesh3D&>(*p_destination_mesh);
    return this->_relax(source, destination, theta);
  }
  default: {
    break;
  }
  }

  // Only reached for a dimension outside {1, 2, 3}.
  throw UnexpectedError("invalid mesh dimension");
}
