#include "mpi.h"

#include <format>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <ranges>
#include <span>
#include <string>
#include <utility>
#include <vector>

class HaloReceiver
{
public:
    HaloReceiver(std::vector< int > offsets, std::vector< int > ranks)
        : offsets_(std::move(offsets)), ranks_(std::move(ranks)), requests_(ranks_.size())
    {}

    void postReceives(std::span< double > data)
    {
        for (auto&& [ofs, dest, req] : std::views::zip(std::views::adjacent< 2 >(offsets_), ranks_, requests_))
        {
            const auto [begin, end] = ofs;
            const auto recv_range   = data.subspan(begin, end - begin);
            MPI_Irecv(recv_range.data(), recv_range.size(), MPI_DOUBLE, dest, 0, MPI_COMM_WORLD, &req);
        }
    }

    void wait() { MPI_Waitall(requests_.size(), requests_.data(), MPI_STATUSES_IGNORE); }

private:
    std::vector< int >         offsets_, ranks_;
    std::vector< MPI_Request > requests_;
};

class HaloSender
{
public:
    struct RankData
    {
        std::vector< int > inds;
        int                rank;
    };

    HaloSender(std::vector< RankData > rank_data)
        : rank_data_(std::move(rank_data)), offsets_(rank_data_.size() + 1), requests_(rank_data_.size())
    {
        std::transform_inclusive_scan(
            rank_data_.begin(), rank_data_.end(), std::next(offsets_.begin()), std::plus{}, [](const RankData& rd) {
                return rd.inds.size();
            });
        recv_buf_.resize(offsets_.back());
    }

    void postSends(std::span< const double > data)
    {
        // Pack
        for (auto&& [ofs, rd] : std::views::zip(std::views::adjacent< 2 >(offsets_), rank_data_))
        {
            const auto [begin, end] = ofs;
            for (auto&& [buf_ind, data_ind] : std::views::zip(std::views::iota(begin, end), rd.inds))
                recv_buf_[buf_ind] = data[data_ind];
        }

        // Send
        for (auto&& [ofs, rd, req] : std::views::zip(std::views::adjacent< 2 >(offsets_), rank_data_, requests_))
        {
            const auto [begin, end] = ofs;
            const auto send_range   = std::span{recv_buf_}.subspan(begin, end - begin);
            MPI_Isend(send_range.data(), send_range.size(), MPI_DOUBLE, rd.rank, 0, MPI_COMM_WORLD, &req);
        }
    }

    void wait() { MPI_Waitall(requests_.size(), requests_.data(), MPI_STATUSES_IGNORE); }

private:
    std::vector< RankData >    rank_data_;
    std::vector< int >         offsets_;
    std::vector< double >      recv_buf_;
    std::vector< MPI_Request > requests_;
};

/// Maps `value` onto the periodic index range `[0, size)`.
///
/// Values any number of periods below or above the range are wrapped back into it, e.g.
/// `wrapAround(-1, 4) == 3`, `wrapAround(4, 4) == 0` and `wrapAround(-5, 3) == 1`.
///
/// @param value Index to wrap; may be negative.
/// @param size  Length of the period. For a non-positive `size` there is no range to wrap into and
///              `value` is returned unchanged.
/// @return The wrapped index.
int wrapAround(int value, int size)
{
    // Guard against division by zero (and meaningless negative periods).
    if (size <= 0)
        return value;

    // `%` truncates toward zero, so a negative operand yields a remainder in (-size, 0].
    const int remainder = value % size;
    return remainder < 0 ? remainder + size : remainder;
}

/// Prints `data` once per rank, with the ranks taking turns in ascending order, followed by a "---"
/// separator line written by rank 0.
///
/// This is a collective operation: it calls MPI_Barrier on MPI_COMM_WORLD `comm_size + 1` times, so every
/// rank of that communicator has to call it with the same `comm_size`.
///
/// @param my_rank   Rank of the calling process; decides in which turn it prints.
/// @param comm_size Number of turns, i.e. the number of ranks that get to print.
/// @param data      Values to print on this rank's line.
void printData(int my_rank, int comm_size, std::span< const double > data)
{
    for (int printer_rank = 0; printer_rank != comm_size; ++printer_rank)
    {
        if (my_rank == printer_rank)
        {
            std::string message = std::format("Rank {}: ", my_rank);
            // Append in place instead of building a temporary string per value.
            for (const double v : data)
                std::format_to(std::back_inserter(message), "{}, ", v);
            // The line is already formatted, so it is streamed as is. std::endl is deliberate: the flush
            // pushes the line out before this rank enters the barrier that lets the next rank print.
            std::cout << message << std::endl;
        }
        MPI_Barrier(MPI_COMM_WORLD);
    }
    if (my_rank == 0)
        std::cout << "---" << std::endl;
    MPI_Barrier(MPI_COMM_WORLD);
}

/// Demo driver: every rank owns four values plus one halo slot, sends its first owned value to the
/// previous rank on a periodic ring and receives the halo value from the next rank, printing the local
/// array before and after the exchange.
int main(int argc, char* argv[])
{
    MPI_Init(&argc, &argv);

    int my_rank   = 0;
    int comm_size = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);

    // Neighbours on the periodic ring of ranks.
    const int prev_rank = wrapAround(my_rank - 1, comm_size);
    const int next_rank = wrapAround(my_rank + 1, comm_size);

    constexpr auto num_owned_local = 4uz;
    constexpr auto num_shared      = 1uz;

    // The halo slot sits right behind the owned entries and is filled by the next rank ...
    auto receiver = HaloReceiver({num_owned_local, num_owned_local + num_shared}, {next_rank});
    // ... while the first owned entry (index 0) goes to the previous rank.
    auto sender = HaloSender({{{0}, prev_rank}});

    // Owned entries get the values rank * 100 + 1, rank * 100 + 2, ...; the halo slot starts out as 0.
    std::vector< double > x(num_owned_local + num_shared);
    for (auto i = 0uz; i != num_owned_local; ++i)
        x[i] = static_cast< double >(my_rank * 100) + static_cast< double >(i + 1);

    printData(my_rank, comm_size, x);

    sender.postSends(x);
    receiver.postReceives(x);
    receiver.wait();

    printData(my_rank, comm_size, x);

    sender.wait();

    MPI_Finalize();
}
