/**
 * \file
 * \brief   Utilities header
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */
#ifndef UTILS_HPP_
#define UTILS_HPP_

#include <unistd.h>

#include <chrono>
#include <cstdio>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <mpi.h>

#include "config.h"

/*
 * MPI_<type> dispatcher mechanism
 *
 * MPI_TypeMapper<T>::getType() returns the predefined MPI datatype that matches the C++ type T.
 * Only the specializations below exist; any other T is rejected at compile time.
 */
template <typename T> struct MPI_TypeMapper {
    // sizeof(T) == 0 is never true but depends on T, so this only fires when the primary
    // template is instantiated for a type that has no specialization below.
    static_assert(sizeof(T) == 0, "MPI_TypeMapper: no MPI datatype mapping for this type");
};

template <> struct MPI_TypeMapper<char>          { static MPI_Datatype getType() { return MPI_CHAR; } };
template <> struct MPI_TypeMapper<signed char>   { static MPI_Datatype getType() { return MPI_SIGNED_CHAR; } };
template <> struct MPI_TypeMapper<short>         { static MPI_Datatype getType() { return MPI_SHORT; } };
template <> struct MPI_TypeMapper<int>           { static MPI_Datatype getType() { return MPI_INT; } };
template <> struct MPI_TypeMapper<long>          { static MPI_Datatype getType() { return MPI_LONG; } };
template <> struct MPI_TypeMapper<long long>     { static MPI_Datatype getType() { return MPI_LONG_LONG; } };
template <> struct MPI_TypeMapper<unsigned char> { static MPI_Datatype getType() { return MPI_UNSIGNED_CHAR; } };
template <> struct MPI_TypeMapper<unsigned short>{ static MPI_Datatype getType() { return MPI_UNSIGNED_SHORT; } };
template <> struct MPI_TypeMapper<unsigned int>  { static MPI_Datatype getType() { return MPI_UNSIGNED; } };
template <> struct MPI_TypeMapper<unsigned long> { static MPI_Datatype getType() { return MPI_UNSIGNED_LONG; } };
template <> struct MPI_TypeMapper<unsigned long long> { static MPI_Datatype getType() { return MPI_UNSIGNED_LONG_LONG; } };
template <> struct MPI_TypeMapper<float>         { static MPI_Datatype getType() { return MPI_FLOAT; } };
template <> struct MPI_TypeMapper<double>        { static MPI_Datatype getType() { return MPI_DOUBLE; } };
template <> struct MPI_TypeMapper<long double>   { static MPI_Datatype getType() { return MPI_LONG_DOUBLE; } };

/*!
 * MPI wrapper type to provide MPI functionality and RAII to MPI as a resource
 *
 * @tparam TID  The MPI type for process id [default: int]
 */
template<typename TID = int>
struct MPI_t {
    using ID_t = TID; // Export TID type (currently int defined by the standard)

    /*!
     * Initializes the MPI environment, must called from each process
     *
     * @param argc  [int*]      POINTER to main's argc argument
     * @param argv  [char***]   POINTER to main's argv argument
     */
    void init(int* argc, char*** argv) {
        // Initialize the MPI environment
        int err;
        if ((err = MPI_Init(argc, argv)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Init() - ");
        initialized_ = true;

        // Get the number of processes
        int size_value, rank_value;
        if ((err = MPI_Comm_size(MPI_COMM_WORLD, &size_value)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Comm_size() - ");
        if ((err = MPI_Comm_rank(MPI_COMM_WORLD, &rank_value)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Comm_rank() - ");
        size_ = static_cast<ID_t>(size_value);
        rank_ = static_cast<ID_t>(rank_value);
        if (size_ > static_cast<ID_t>(MAX_MPI_SIZE))
            throw std::runtime_error(
                    "(MPI) size - Not supported number of nodes [over " + std::to_string(MAX_MPI_SIZE) + "]\n"
            );

        // Get the name of the processor
        char processor_name[MPI_MAX_PROCESSOR_NAME];
        int name_len;
        if ((err = MPI_Get_processor_name(processor_name, &name_len)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Get_processor_name() - ");
        name_ = std::string (processor_name, name_len);
    }


    /*!
     * Initiate a data exchange data with partner using non-blocking Isend-Irecv, as part of the
     * sorting network of both bubbletonic or bitonic sorting algorithms.
     *
     * This function matches a transmit and a receive in order for fully exchanged data between
     * current node and partner.
     * @note
     *      This call MUST paired with exchange_wait() for each MPI_t object.
     *      Calling 2 consecutive exchange_start() for the same MPI_t object is undefined.
     *
     * @tparam ValueT   The underlying value type used in buffers
     *
     * @param ldata     [const ValueT*] Pointer to local data to send
     * @param rdata     [ValueT*]       Pointer to buffer to receive data from partner
     * @param count     [size_t]        The number of data to exchange
     * @param partner   [mpi_id_t]      The partner for the exchange
     * @param tag       [int]           The tag to use for the MPI communication
     */
    template<typename ValueT>
    void exchange_start(const ValueT* ldata, ValueT* rdata, size_t count, ID_t partner, int tag) {
        if (tag < 0)
            throw std::runtime_error("(MPI) exchange_data() [tag] - Out of bound");

        MPI_Datatype datatype = MPI_TypeMapper<ValueT>::getType();
        int err;
        err = MPI_Isend(ldata, count, datatype, partner, tag, MPI_COMM_WORLD, &handle_tx);
        if (err != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Isend() - ");
        err = MPI_Irecv(rdata, count, datatype, partner, tag, MPI_COMM_WORLD, &handle_rx);
        if (err != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Irecv() - ");
    }

    /*!
     * Block wait for the completion of the previously called exchange_start()
     *
     * @note
     *      This call MUST paired with exchange_start() for each MPI_t object.
     *      Calling 2 consecutive exchange_wait() for the same MPI_t object is undefined.
     */
    void exchange_wait() {
        MPI_Status status;

        int err;
        if ((err = MPI_Wait(&handle_tx, &status)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Wait() [send] - ");

        if ((err = MPI_Wait(&handle_rx, &status)) != MPI_SUCCESS)
            mpi_throw(err, "(MPI) MPI_Wait() [recv] - ");
    }

    // Accessors
    [[nodiscard]] ID_t rank() const noexcept { return rank_; }
    [[nodiscard]] ID_t size() const noexcept { return size_; }
    [[nodiscard]] const std::string& name() const noexcept { return name_; }

    // Mutators
    ID_t rank(ID_t rank) noexcept { return rank_ = rank; }
    ID_t size(ID_t size) noexcept { return size_ = size; }
    std::string& name(const std::string& name) noexcept { return name_ = name; }

    /*!
     * Finalized the MPI
     */
    void finalize() {
        // Finalize the MPI environment
        initialized_ = false;
        MPI_Finalize();
    }

    //! RAII MPI finalization
    ~MPI_t() {
        // Finalize the MPI environment even on unexpected errors
        if (initialized_)
            MPI_Finalize();
    }


    // Local functionality
private:
    /*!
     * Throw exception helper. It bundles the prefix msg with the MPI error string retrieved by
     * MPI API.
     *
     * @param err           The MPI error code
     * @param prefixMsg     The prefix text for the exception error message
     */
    void mpi_throw(int err, const char* prefixMsg) {
        char err_msg[MPI_MAX_ERROR_STRING];
        int msg_len;
        MPI_Error_string(err, err_msg, &msg_len);
        throw std::runtime_error(prefixMsg + std::string (err_msg) + '\n');
    }

private:
    ID_t rank_{};           //!< MPI rank of the process
    ID_t size_{};           //!< MPI total size of the execution
    std::string name_{};    //!< The name of the local machine
    bool initialized_{};    //!< RAII helper flag
    MPI_Request handle_tx{};    //!< MPI async exchange handler for Transmission
    MPI_Request handle_rx{};    //!< MPI async exchange handler for Receptions
};

/*
 * Exported data types
 *  - mpi_id_t : the process-id type exported by the default MPI_t wrapper
 *  - mpi      : the program-wide MPI wrapper instance (defined in another translation unit)
 */
using mpi_id_t = MPI_t<>::ID_t;
extern MPI_t<mpi_id_t> mpi;



/*!
 * @brief A std::vector wrapper with 2 vectors, an active and a shadow.
 *
 * This type exposes the standard vector functionality of the active vector. The shadow can be
 * used when we need to use the vector as mutable data in algorithms that can not support
 * "in-place" editing (like elbow-sort for example). switch_active() swaps the two roles.
 *
 * @note
 *      Only resize() and reserve() act on both buffers. Every other mutator (push_back,
 *      pop_back, clear, swap, ...) touches the active buffer only, so the two sizes may diverge.
 *
 * @tparam Value_t  the underlying data type of the vectors
 */
template <typename Value_t>
struct ShadowedVec_t {
    // STL requirements
    using value_type     = Value_t;
    using iterator       = typename std::vector<Value_t>::iterator;
    using const_iterator = typename std::vector<Value_t>::const_iterator;
    using size_type      = typename std::vector<Value_t>::size_type;

    // Default constructor
    ShadowedVec_t() = default;

    //! Constructs from a copy of @p vec; the shadow gets the same size (value-initialized elements)
    explicit ShadowedVec_t(const std::vector<Value_t>& vec)
            : North(vec), South(), active(north) {
        South.resize(North.size());
    }

    //! Takes over @p vec; the shadow gets the same size (value-initialized elements)
    explicit ShadowedVec_t(std::vector<Value_t>&& vec)
            : North(std::move(vec)), South(), active(north) {
        South.resize(North.size());
    }

    // Copy/move constructors must be declared explicitly: the user-declared move assignment
    // below would otherwise delete the implicit copy constructor and suppress the implicit move
    // constructor, leaving the type impossible to copy- or move-construct.
    ShadowedVec_t(const ShadowedVec_t&) = default;
    ShadowedVec_t(ShadowedVec_t&&) = default;

    // Copy assignment operator
    ShadowedVec_t& operator=(const ShadowedVec_t& other) {
        if (this != &other) { // Avoid self-assignment
            North = other.North;
            South = other.South;
            active = other.active;
        }
        return *this;
    }

    // Move assignment operator
    ShadowedVec_t& operator=(ShadowedVec_t&& other) noexcept {
        if (this != &other) { // Avoid self-assignment
            North = std::move(other.North);
            South = std::move(other.South);
            active = other.active;

            // There is no need to zero out other since it is valid but in a non-defined state
        }
        return *this;
    }

    // Type accessors
    std::vector<Value_t>& getActive() { return (active == north) ? North : South; }
    std::vector<Value_t>& getShadow() { return (active == north) ? South : North; }
    const std::vector<Value_t>& getActive() const { return (active == north) ? North : South; }
    const std::vector<Value_t>& getShadow() const { return (active == north) ? South : North; }

    // Swap vectors
    void switch_active() { active = (active == north) ? south : north; }

    // Dispatch vector functionality to active vector
    Value_t& operator[](size_type index) { return getActive()[index]; }
    const Value_t& operator[](size_type index) const { return getActive()[index]; }

    Value_t& at(size_type index) { return getActive().at(index); }
    const Value_t& at(size_type index) const { return getActive().at(index); }

    void push_back(const Value_t& value) { getActive().push_back(value); }
    void push_back(Value_t&& value)      { getActive().push_back(std::move(value)); }
    void pop_back()                      { getActive().pop_back(); }
    Value_t& front() { return getActive().front(); }
    Value_t& back()  { return getActive().back(); }
    const Value_t& front() const { return getActive().front(); }
    const Value_t& back()  const { return getActive().back(); }

    iterator begin() { return getActive().begin(); }
    const_iterator begin() const { return getActive().begin(); }
    iterator end() { return getActive().end(); }
    const_iterator end() const { return getActive().end(); }

    size_type size() const { return getActive().size(); }
    //! Resizes BOTH buffers so the shadow can always hold a full copy of the active data
    void resize(size_t new_size) {
        North.resize(new_size);
        South.resize(new_size);
    }

    //! Reserves capacity in BOTH buffers
    void reserve(size_t new_capacity) {
        North.reserve(new_capacity);
        South.reserve(new_capacity);
    }
    [[nodiscard]] size_t capacity() const { return getActive().capacity(); }
    [[nodiscard]] bool empty() const { return getActive().empty(); }

    void clear() { getActive().clear(); }
    void swap(std::vector<Value_t>& other) { getActive().swap(other); }

    // Comparisons (active buffers only). They are const so const objects can be compared too.
    bool operator== (const ShadowedVec_t& other) const { return getActive() == other.getActive(); }
    bool operator!= (const ShadowedVec_t& other) const { return getActive() != other.getActive(); }
    bool operator== (const std::vector<value_type>& other) const { return getActive() == other; }
    bool operator!= (const std::vector<value_type>& other) const { return getActive() != other; }

private:
    std::vector<Value_t> North{};       //!< Actual buffer to be used either as active or shadow
    std::vector<Value_t> South{};       //!< Actual buffer to be used either as active or shadow
    enum {
        north, south
    } active{north};                    //!< Flag to select between North and South buffer
};

/*
 * Exported data types
 *  - distBuffer_t : active/shadow double buffer of distValue_t elements
 *  - Data         : the program-wide buffer instance (defined in another translation unit)
 */
using distBuffer_t = ShadowedVec_t<distValue_t>;
extern ShadowedVec_t<distValue_t> Data;   // same type as distBuffer_t

/*!
 * A Logger for entire program.
 *
 * Output goes to std::cout only while config.verbose is set. The first item written at the
 * start of a line (initially, and after each endl request) is prefixed with "[Log]: ".
 */
struct Log {
    struct Endl {} endl;    //!< a tag object to use as a new line request.

    //! Logging via the << operator
    template<typename T>
    Log &operator<<(T &&t) {
        if (!config.verbose)
            return *this;

        if (line_)
            std::cout << "[Log]: ";
        std::cout << t;
        line_ = false;      // following items continue the same line
        return *this;
    }

    //! End-of-line request: emits '\n' and re-arms the prefix for the next item
    Log &operator<<(Endl) {
        if (config.verbose) {
            std::cout << '\n';
            line_ = true;
        }
        return *this;
    }

private:
    bool line_{true};       //!< true when the next item starts a new line
};

extern Log logger;          //!< Program-wide logger (defined in another translation unit)

/*!
 * A small timing utility based on chrono.
 *
 * Every start()/stop() pair adds its interval to an accumulated total, which print_duration()
 * reports in a unit chosen from its magnitude.
 */
struct Timing {
    using Tpoint = std::chrono::steady_clock::time_point;
    using Tduration = std::chrono::microseconds;
    using microseconds = std::chrono::microseconds;
    using milliseconds = std::chrono::milliseconds;
    using seconds = std::chrono::seconds;

    //! tool to mark the starting point
    Tpoint start() noexcept { return mark_ = std::chrono::steady_clock::now(); }

    //! tool to mark the ending point; adds the interval since the last start() to the total
    Tpoint stop() noexcept {
        Tpoint now = std::chrono::steady_clock::now();
        duration_ += dt(now, mark_);
        return now;
    }

    //! A duration calculation utility: (t2 - t1) truncated to Tduration
    static Tduration dt(Tpoint t2, Tpoint t1) noexcept {
        return std::chrono::duration_cast<Tduration>(t2 - t1);
    }

    /*!
     * Tool to print the accumulated time interval
     *
     * Below 10000 usec it prints microseconds, below 10000 msec milliseconds, otherwise
     * seconds with two decimals (e.g. "12.05 [sec]").
     *
     * @param what  Label printed before the value
     * @param rank  The rank reported in the line
     */
    void print_duration(const char *what, mpi_id_t rank) noexcept {
        const auto usec = std::chrono::duration_cast<microseconds>(duration_).count();
        const auto msec = std::chrono::duration_cast<milliseconds>(duration_).count();

        if (usec < 10000)
            std::cout << "[Timing] (Rank " << rank << ") " << what << ": "
                      << std::to_string(usec) << " [usec]\n";
        else if (msec < 10000)
            std::cout << "[Timing] (Rank " << rank << ") " << what << ": "
                      << std::to_string(msec) << " [msec]\n";
        else {
            char stime[32];
            const auto sec   = std::chrono::duration_cast<seconds>(duration_).count();
            const auto hsec  = (msec % 1000) / 10;    // hundredths of a second, 0..99
            // "%02" keeps the leading zero (12.05 s must not print as "12.5"); the casts make the
            // arguments match %lld whatever the chrono rep type is on this platform.
            std::snprintf(stime, sizeof stime, "%lld.%02lld",
                          static_cast<long long>(sec), static_cast<long long>(hsec));
            std::cout << "[Timing] (Rank " << rank << ") " << what << ": " << stime << " [sec]\n";
        }
    }

private:
    Tpoint mark_{};         //!< Time point of the last start()
    Tduration duration_{};  //!< Accumulated duration of all start()/stop() intervals
};

/*!
 * Utility "high level function"-like macro to forward a function call
 * and accumulate the execution time to the corresponding timing object.
 *
 * The expansion is a single braced compound statement, so it stays whole when used as the body
 * of an unbraced if/for/while (a bare statement list would leave only Tim.start() guarded).
 * The trailing semicolon at the call site is optional. Do not follow the call with an `else`
 * in an unbraced if.
 *
 * @param   Tim     The Timing object [Needs to have methods start() and stop()]
 * @param   Func    The function name
 * @param   ...     The arguments to pass to function (the preprocessor way)
 */
#define timeCall(Tim, Func, ...)    \
    {                               \
        Tim.start();                \
        Func(__VA_ARGS__);          \
        Tim.stop();                 \
    }


/*!
 * A utility to check if a number is power of two
 *
 * A power of two has exactly one bit set, so clearing its lowest set bit (x & (x - 1)) gives 0.
 * The positivity test comes first. It rejects 0 and negative values, and it short-circuits
 * before x - 1 is evaluated, so the minimum value of a signed type never overflows.
 *
 * @tparam Integral     The integral type of the number to check
 * @param x             The number to check
 * @return              True if it is power of 2, false otherwise
 */
template <typename Integral>
constexpr inline bool isPowerOfTwo(Integral x) noexcept {
    return (x > 0) && !(x & (x - 1));
}


#endif /* UTILS_HPP_ */