/*!
 * \file
 * \brief   Main application file for PDS HW2 (MPI)
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */

#include <algorithm>
#include <cerrno>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <limits>
#include <random>

#include "utils.hpp"
#include "config.h"
#include "distsort.hpp"


// Global session data
config_t        config;     // Run-time options; filled by get_options() and read by init()/main()
MPI_t<>         mpi;        // MPI session wrapper; init() calls mpi.init(), main() calls mpi.finalize()
distBuffer_t    Data;       // Local array of this process; resized in init(), filled and sorted in main()
Log             logger;     // Stream-style log object used by init() and main()
distStat_t      localStat, remoteStat;  // Not referenced in this file; presumably used by the sort code -- TODO confirm

// Mersenne Twister engine, seeded once from std::random_device.
// main() draws from it through a uniform distribution spanning the
// whole range of distValue_t: [type_min, type_max]
std::random_device  rd;
std::mt19937        gen(rd());

//! Performance timers for each one of the "costly" functions.
//! In this file only Timer_total is started and stopped (around the sort call in main());
//! the others are merely initialized, advanced and printed here, so they are
//! presumably driven from the sort implementation -- TODO confirm.
Timing Timer_total;
Timing Timer_fullSort;
Timing Timer_exchange;
Timing Timer_minmax;
Timing Timer_elbowSort;

//! Init timing objects for extra rounds
void measurements_init() {
    if (config.perf > 1) {
        Timer_total.init(config.perf);
        Timer_fullSort.init(config.perf);
        Timer_exchange.init(config.perf);
        Timer_minmax.init(config.perf);
        Timer_elbowSort.init(config.perf);
    }
}

//! iterate ot the next round of measurements for all measurement objects
void measurements_next() {
    if (config.perf > 1) {
        Timer_total.next();
        Timer_fullSort.next();
        Timer_exchange.next();
        Timer_minmax.next();
        Timer_elbowSort.next();
    }
}

/*!
 * A small command line argument parser.
 *
 * Walks argv[1] .. argv[argc-1] and stores every recognised option in the
 * global `config`. `--version` and `-h | --help` print to stdout and terminate
 * the process with exit(0).
 *
 * Numeric option values must be complete base-10 integers. A missing,
 * malformed or out-of-range value, as well as an unknown argument, makes the
 * parse fail; in that case a single hint is printed to stdout before returning.
 *
 * \param   argc    Number of entries in \p argv
 * \param   argv    The argument vector, as received by main
 * \return  The status of the operation: true when every argument was accepted
 */
bool get_options(int argc, char* argv[]){
    bool status =true;

    // `1 << q` is evaluated as int, so q must stay inside [0, digits-1]
    // (a negative or too large shift count is undefined behaviour / overflow).
    constexpr int maxExponent = std::numeric_limits<int>::digits - 1;

    // Consume the value that follows the option at argv[i] and parse it as int.
    // `i` is advanced onto the value when one exists. Fails when the value is
    // missing, is not a complete decimal number, or does not fit in an int.
    auto next_int = [argc, argv](int& i, int& out) -> bool {
        if (i + 1 >= argc)
            return false;
        const char* text = argv[++i];
        char* end = nullptr;
        errno = 0;
        const long value = std::strtol(text, &end, 10);
        if (end == text || *end != '\0' || errno == ERANGE)
            return false;
        if (value < std::numeric_limits<int>::min() || value > std::numeric_limits<int>::max())
            return false;
        out = static_cast<int>(value);
        return true;
    };

    // iterate over the passed arguments
    for (int i=1 ; i<argc ; ++i) {
        std::string arg(argv[i]);     // get current argument

        if (arg == "-q" || arg == "--array-size") {
            int q = 0;
            if (next_int(i, q) && q >= 0 && q <= maxExponent)
                config.arraySize = 1 << q;
            else
                status = false;
        }
        else if (arg == "-e" || arg == "--exchange-opt") {
            config.exchangeOpt = true;
        }
        else if (arg == "-p" || arg == "--pipeline") {  // "-p" is the short form promised by the help text
            int stages = 0;
            // The explicit `> 0` keeps zero/negative values away from isPowerOfTwo(),
            // whose behaviour for them is not visible from here.
            if (next_int(i, stages)
                    && stages > 0
                    && isPowerOfTwo(stages)
                    && stages <= static_cast<int>(MAX_PIPELINE_SIZE))
                config.pipeline = stages;
            else
                status = false;
        }
        else if (arg == "--validation") {
            config.validation = true;
        }
        else if (arg == "--perf") {
            int rounds = 0;
            // At least one round is needed, otherwise the sort would never run
            if (next_int(i, rounds) && rounds >= 1)
                config.perf = rounds;
            else
                status = false;
        }
        else if (arg == "--ndebug") {
            config.ndebug = true;
        }
        else if (arg == "-v" || arg == "--verbose") {
            config.verbose = true;
        }
        else if (arg == "--version") {
            std::cout << "distbitonic/distbubbletonic - A distributed sort utility\n";
            std::cout << "version: " << version << "\n\n";
            exit(0);
        }
        else if (arg == "-h" || arg == "--help") {
            std::cout << "distbitonic/distbubbletonic - A distributed sort utility\n\n";
            std::cout << "  distbitonic -q <N> [-e] [-p | --pipeline <N>] [--validation] [--perf <N>] [--ndebug] [-v]\n";
            std::cout << "  distbitonic -h\n";
            std::cout << "  distbubbletonic -q <N> [-e] [-p | --pipeline <N>] [--validation] [--perf <N> ] [--ndebug] [-v]\n";
            std::cout << "  distbubbletonic -h\n";
            std::cout << '\n';
            std::cout << "Options:\n\n";
            std::cout << "   -q | --array-size <N>\n";
            std::cout << "      Selects the array size according to size = 2^N\n\n";
            std::cout << "   -e | --exchange-opt\n";
            std::cout << "      Request an MPI data exchange optimization \n\n";
            std::cout << "   -p <N> | --pipeline <N>\n";
            std::cout << "      Request a pipeline of <N> stages for exchange-minmax\n";
            std::cout << "      N must be power of 2 up to " << MAX_PIPELINE_SIZE << "\n\n";
            std::cout << "   --validation\n";
            std::cout << "      Request a full validation at the end, performed by process rank 0\n\n";
            std::cout << "   --perf <N> \n";
            std::cout << "      Enable performance timing measurements and prints, and repeat\n";
            std::cout << "      the sorting <N> times.\n\n";
            std::cout << "   --ndebug\n";
            std::cout << "      Skip debug breakpoint when on debug build.\n\n";
            std::cout << "   -v | --verbose\n";
            std::cout << "      Request a more verbose output to stdout.\n\n";
            std::cout << "   -h | --help\n";
            std::cout << "      Prints this and exit.\n\n";
            std::cout << "   --version\n";
            std::cout << "      Prints version and exit.\n\n";
            std::cout << "Examples:\n\n";
            std::cout << "   mpirun -np 4 distbitonic -q 24\n";
            std::cout << "      Runs distbitonic in 4 MPI processes with 2^24 array points each\n\n";
            std::cout << "   mpirun -np 16 distbubbletonic -q 20\n";
            std::cout << "      Runs distbubbletonic in 16 MPI processes with 2^20 array points each\n\n";

            exit(0);
        }
        else {   // parse error
            status = false;
        }
    }

    // One hint for any kind of failure (unknown argument, missing or bad value)
    if (!status)
        std::cout << "Invocation error. Try -h for details.\n";

    return status;
}

/*!
 * A simple validator for the entire distributed process
 *
 * Every process reports three values to rank 0 via MPI_Gather: its first
 * element, its last element and whether its local buffer is sorted. Rank 0
 * then requires that every local buffer (its own included) is sorted and that
 * the last element of each rank does not exceed the first element of the next.
 *
 * This is a collective operation: it must be called by all processes.
 * \p data must not be empty, since front() and back() are read unconditionally.
 *
 * @tparam ShadowedDataT    A Shadowed buffer type with random access iterator.
 *
 * @param data          [ShadowedDataT] The local to MPI process
 * @param Processes     [mpi_id_t]      The total number of MPI processes
 * @param rank          [mpi_id_t]      The current process id
 *
 * @return              [bool]          On rank 0: true if all are sorted and in total ascending order.
 *                                      On every other rank the result is always true and carries no information.
 */
template<typename ShadowedDataT>
bool validator(ShadowedDataT& data, mpi_id_t Processes, mpi_id_t rank) {
    using value_t = typename ShadowedDataT::value_type;
    bool ret = true;    // Have faith!

    // Local results. The sorted flag is carried as a value_t (1 or 0) so that
    // all three gathers can use the same MPI datatype.
    value_t lmin   = data.front();
    value_t lmax   = data.back();
    value_t lsort  = static_cast<value_t>(std::is_sorted(data.begin(), data.end()));

    // Gather min/max/sort to rank 0
    std::vector<value_t> mins(Processes);
    std::vector<value_t> maxes(Processes);
    std::vector<value_t> sorts(Processes);

    MPI_Datatype datatype = MPI_TypeMapper<value_t>::getType();
    MPI_Gather(&lmin,  1, datatype, mins.data(),  1, datatype, 0, MPI_COMM_WORLD);
    MPI_Gather(&lmax,  1, datatype, maxes.data(), 1, datatype, 0, MPI_COMM_WORLD);
    MPI_Gather(&lsort, 1, datatype, sorts.data(), 1, datatype, 0, MPI_COMM_WORLD);

    // Check all results
    if (rank == 0) {
        // Start from rank 0 so that its own sorted flag is inspected as well;
        // this also covers the single-process case.
        for (mpi_id_t r = 0; r < Processes; ++r) {
            if (sorts[r] == 0)
                ret = false;
            // The boundary check needs a predecessor, so it starts at rank 1
            if (r > 0 && maxes[r - 1] > mins[r])
                ret = false;
        }
    }
    return ret;
}

/*!
 * Sets up the session: command line, MPI, local buffer and timers.
 * Must be called from each process.
 *
 * The process terminates with exit(1) when the command line is rejected,
 * which happens before the MPI environment is initialized.
 *
 * @param argc  [int*]      POINTER to main's argc argument
 * @param argv  [char***]   POINTER to main's argv argument
 */
void init(int* argc, char*** argv) {
    // Command line first
    const bool parsed = get_options(*argc, *argv);
    if (!parsed) {
        exit(1);
    }

    // Bring up the MPI environment
    mpi.init(argc, argv);

    logger << "MPI environment initialized."
           << " Rank: " << mpi.rank()
           << " Size: " << mpi.size()
           << logger.endl;

    #ifdef DEBUG
        /*
         * In case of a debug build we will wait here until sleep_wait
         * will reset via debugger. In order to do that the user must attach
         * debugger to all processes. For example:
         *  $> mpirun -np 2 ./<program path>
         *  $> ps aux | grep <program>
         *  $> gdb <program> <PID1>
         *  $> gdb <program> <PID2>
         *
         * Testing builds start with the flag cleared, so they never wait.
         */
    #ifdef TESTING
        volatile bool sleep_wait = false;
    #else
        volatile bool sleep_wait = true;
    #endif
        while (sleep_wait && !config.ndebug) {
            sleep(1);
        }
    #endif

    // Local buffer and timing objects
    Data.resize(config.arraySize);
    measurements_init();
}

#if !defined TESTING
/*!
 * Application entry point.
 *
 * Initializes the session, then for each of the config.perf rounds fills the
 * local buffer with random values and runs the distributed sort selected at
 * compile time by CODE_VERSION. Afterwards it prints the timer medians (when
 * more than one round ran) and, if requested, the validation verdict.
 *
 * A std::exception that escapes is reported on stderr and the process ends
 * with exit(1); no MPI call is made on that path.
 *
 * @return Returns 0, but.... we may throw or exit(0) / exit(1)
 */
int main(int argc, char* argv[]) try {

    // Init everything
    init(&argc, &argv);

    // Full value range of distValue_t. The distribution does not depend on the
    // round, so it is built once instead of on every iteration.
    std::uniform_int_distribution<distValue_t > dis(
            std::numeric_limits<distValue_t>::min(),
            std::numeric_limits<distValue_t>::max()
    );

    for (size_t it = 0 ; it < config.perf ; ++it) {
        // Initialize local data
        logger << "Initialize local array of " << config.arraySize << " elements" << logger.endl;
        std::generate(Data.begin(), Data.end(), [&]() { return dis(gen); });
        // Run distributed sort
        if (mpi.rank() == 0)
            logger << "Starting distributed sorting ... ";
        Timer_total.start();
    #if CODE_VERSION == BUBBLETONIC
        distBubbletonic(Data, mpi.size(), mpi.rank());
    #else
        distBitonic(Data, mpi.size(), mpi.rank());
    #endif
        Timer_total.stop();
        measurements_next();
        if (mpi.rank() == 0)
            logger << " Done." << logger.endl;
    }

    // Print-outs and validation
    if (config.perf > 1) {
        Timing::print_duration(Timer_total.median(),    "Total     ", mpi.rank());
        Timing::print_duration(Timer_fullSort.median(), "Full-Sort ", mpi.rank());
        Timing::print_duration(Timer_exchange.median(), "Exchange  ", mpi.rank());
        Timing::print_duration(Timer_minmax.median(),   "Min-Max   ", mpi.rank());
        Timing::print_duration(Timer_elbowSort.median(),"Elbow-Sort", mpi.rank());
    }
    if (config.validation) {
        // If requested, we have the chance to fail!
        if (mpi.rank() == 0)
            std::cout << "[Validation] Results validation ...";
        // Called by every rank; only rank 0 prints the verdict
        bool val = validator(Data, mpi.size(), mpi.rank());
        if (mpi.rank() == 0)
            // Green (32) for success, red (31) for failure
            std::cout << ((val) ? "\x1B[32m [PASSED] \x1B[0m\n" : "\x1B[31m [FAILED] \x1B[0m\n");
    }
    mpi.finalize();
    return 0;
}
catch (const std::exception& e) {
    //we probably pollute the user's screen. Comment `cerr << ...` if you don't like it.
    std::cerr << "Error: " << e.what() << '\n';
    exit(1);
}

#else

#include <gtest/gtest.h>
#include <exception>

/*!
 * The testing version of our program
 *
 * Hands the command line to GoogleTest and runs every registered test.
 *
 * @return  The value of RUN_ALL_TESTS(), or 1 when a std::exception escapes.
 */
GTEST_API_ int main(int argc, char **argv) try {
   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
}
catch (const std::exception& e) {
    std::cout << "Exception: " << e.what() << '\n';
    // Without an explicit return the handler would fall off the end of main,
    // which is equivalent to `return 0` and would report success.
    return 1;
}

#endif