#include <mesh/CartesianMeshBuilder.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/LogicalConnectivityBuilder.hpp>
#include <utils/Array.hpp>
#include <utils/Messenger.hpp>

// Generic fallback of _getNodeCoordinates.
//
// Only dimensions 1, 2 and 3 are supported, and each of them has an explicit
// specialization below, so this primary template must never be instantiated.
// The assertion depends on Dimension, so it only fires when the template is
// instantiated for an unsupported dimension. Dimension 0 is rejected as well:
// otherwise this body (which returns nothing) would compile and fall off the
// end of a non-void function.
template <size_t Dimension>
NodeValue<TinyVector<Dimension>>
CartesianMeshBuilder::_getNodeCoordinates(const TinyVector<Dimension>&,
                                          const TinyVector<Dimension>&,
                                          const TinyVector<Dimension, uint64_t>&,
                                          const IConnectivity&) const
{
  static_assert((Dimension >= 1) and (Dimension <= 3), "invalid dimension");
}

// 1D specialization: node r is placed at a + r * step, where step is the
// uniform cell length (b - a) / cell_size.
template <>
NodeValue<TinyVector<1>>
CartesianMeshBuilder::_getNodeCoordinates(const TinyVector<1>& a,
                                          const TinyVector<1>& b,
                                          const TinyVector<1, uint64_t>& cell_size,
                                          const IConnectivity& connectivity) const
{
  const double step = (b[0] - a[0]) / cell_size[0];

  NodeValue<TinyVector<1>> node_coordinates(connectivity);
  parallel_for(
    connectivity.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
      const size_t position = node_id;
      node_coordinates[node_id][0] = a[0] + position * step;
    });

  return node_coordinates;
}

// 2D specialization: node r is mapped to its logical position (i, j) in a grid
// of (cell_size[0] + 1) x (cell_size[1] + 1) nodes, the second index varying
// fastest, and placed at (a[0] + i * h[0], a[1] + j * h[1]).
template <>
NodeValue<TinyVector<2>>
CartesianMeshBuilder::_getNodeCoordinates(const TinyVector<2>& a,
                                          const TinyVector<2>& b,
                                          const TinyVector<2, uint64_t>& cell_size,
                                          const IConnectivity& connectivity) const
{
  // Number of nodes along the second (fastest varying) direction.
  const uint64_t nodes_per_row = cell_size[1] + 1;

  const TinyVector<2> cell_length{(b[0] - a[0]) / cell_size[0], (b[1] - a[1]) / cell_size[1]};

  NodeValue<TinyVector<2>> node_coordinates(connectivity);

  parallel_for(
    connectivity.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
      const size_t linear_id = node_id;

      const uint64_t index0 = linear_id / nodes_per_row;
      const uint64_t index1 = linear_id % nodes_per_row;

      node_coordinates[node_id][0] = a[0] + index0 * cell_length[0];
      node_coordinates[node_id][1] = a[1] + index1 * cell_length[1];
    });

  return node_coordinates;
}

// 3D specialization: node r is mapped to its logical position (i, j, k) in a
// grid of (cell_size[0] + 1) x (cell_size[1] + 1) x (cell_size[2] + 1) nodes,
// the last index varying fastest, and placed at a + (i, j, k) * h component-wise.
template <>
NodeValue<TinyVector<3>>
CartesianMeshBuilder::_getNodeCoordinates(const TinyVector<3>& a,
                                          const TinyVector<3>& b,
                                          const TinyVector<3, uint64_t>& cell_size,
                                          const IConnectivity& connectivity) const
{
  // Node counts along the two fastest varying directions, and the number of
  // nodes in one layer of constant first index.
  const uint64_t nodes_dir1      = cell_size[1] + 1;
  const uint64_t nodes_dir2      = cell_size[2] + 1;
  const uint64_t nodes_per_layer = nodes_dir1 * nodes_dir2;

  const TinyVector<3> cell_length{(b[0] - a[0]) / cell_size[0],   //
                                  (b[1] - a[1]) / cell_size[1],   //
                                  (b[2] - a[2]) / cell_size[2]};

  NodeValue<TinyVector<3>> node_coordinates(connectivity);
  parallel_for(
    connectivity.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
      const uint64_t linear_id = node_id;

      // nodes_dir2 divides nodes_per_layer, so the last index is simply the
      // remainder modulo nodes_dir2.
      const TinyVector<3, uint64_t> logical_id{linear_id / nodes_per_layer,
                                               (linear_id % nodes_per_layer) / nodes_dir2,
                                               linear_id % nodes_dir2};

      for (size_t d = 0; d < 3; ++d) {
        node_coordinates[node_id][d] = a[d] + logical_id[d] * cell_length[d];
      }
    });

  return node_coordinates;
}

template <size_t Dimension>
void
CartesianMeshBuilder::_buildCartesianMesh(const TinyVector<Dimension>& a,
                                          const TinyVector<Dimension>& b,
                                          const TinyVector<Dimension, uint64_t>& cell_size)
{
  static_assert(Dimension <= 3, "unexpected dimension");

  LogicalConnectivityBuilder logical_connectivity_builder{cell_size};

  using ConnectivityType = Connectivity<Dimension>;

  std::shared_ptr<const ConnectivityType> p_connectivity =
    std::dynamic_pointer_cast<const ConnectivityType>(logical_connectivity_builder.connectivity());
  const ConnectivityType& connectivity = *p_connectivity;

  NodeValue<TinyVector<Dimension>> xr = _getNodeCoordinates(a, b, cell_size, connectivity);

  m_mesh = std::make_shared<Mesh<ConnectivityType>>(p_connectivity, xr);
}

// Builds a Cartesian mesh of the box whose opposite corners are a and b,
// subdivided into size[i] cells along direction i.
//
// The corners may be given in any order: they are reordered component-wise so
// that the mesh is built from the lower corner to the upper one. The mesh is
// only built on rank 0; _dispatch<Dimension>() is then called on every rank.
//
// Throws NormalError if a and b share a component (degenerate box) or if some
// direction has zero cells (which would divide by zero when computing the cell
// length and produce non-finite node coordinates).
template <size_t Dimension>
CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<Dimension>& a,
                                           const TinyVector<Dimension>& b,
                                           const TinyVector<Dimension, uint64_t>& size)
{
  // Validation is performed on every rank, before the rank-0-only build, so
  // that all processes throw consistently.
  const TinyVector<Dimension> length = b - a;
  for (size_t i = 0; i < Dimension; ++i) {
    if (length[i] == 0) {
      throw NormalError("invalid box definition corners share a component");
    }
  }

  for (size_t i = 0; i < Dimension; ++i) {
    if (size[i] == 0) {
      throw NormalError("invalid mesh size: number of cells must be positive in each direction");
    }
  }

  if (parallel::rank() == 0) {
    TinyVector<Dimension> corner0 = a;
    TinyVector<Dimension> corner1 = b;

    for (size_t i = 0; i < Dimension; ++i) {
      if (corner0[i] > corner1[i]) {
        std::swap(corner0[i], corner1[i]);
      }
    }

    this->_buildCartesianMesh(corner0, corner1, size);
  }
  this->_dispatch<Dimension>();
}

// Explicit instantiations of the templated constructor for the supported
// dimensions (1, 2 and 3), so that the definition given in this file can be
// used from other translation units.
template CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<1>&,
                                                    const TinyVector<1>&,
                                                    const TinyVector<1, uint64_t>&);

template CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<2>&,
                                                    const TinyVector<2>&,
                                                    const TinyVector<2, uint64_t>&);

template CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<3>&,
                                                    const TinyVector<3>&,
                                                    const TinyVector<3, uint64_t>&);