/**
 * \file    tests.cpp
 * \brief   PDS homework_1 tests
 *
 * \author
 *    Christos Choutouridis AEM:8997
 *    <cchoutou@ece.auth.gr>
 */

#include <gtest/gtest.h>

#include "matrix.hpp"

#include "v0.hpp"
#include "v1.hpp"
#include "utils.hpp"
#include "config.h"


// Shorthand for an integer-valued matrix.
using matrix_t = mtx::Matrix<int>;

// HDF5 I/O helpers implemented in another translation unit.
// NOTE(review): they presumably take their file/dataset names from the global
// `session` object, since the live tests fill in `session` right before calling them.
void loadMtx(MatrixDst& Corpus, MatrixDst& Query);
void storeMtx(MatrixIdx& Idx, MatrixDst& Dst);

// =====================================
// C1, Q1: small 2-D fixture.
// C1 is a corpus of 10 points and Q1 a query set of 5 points, one point per
// row. pdist2_test1 checks the resulting 10x5 table as D(corpus row, query row),
// and knn_v0_test1 searches Q1 against C1.
mtx::Matrix<double> C1(10,2, {
   0.8147,    0.1576,
   0.9058,    0.9706,
   0.1270,    0.9572,
   0.9134,    0.4854,
   0.6324,    0.8003,
   0.0975,    0.1419,
   0.2785,    0.4218,
   0.5469,    0.9157,
   0.9575,    0.7922,
   0.9649,    0.9595
});

// Query points for C1 (5 points, 2-D, one per row).
mtx::Matrix<double> Q1(5,2, {
   0.6557,    0.7577,
   0.0357,    0.7431,
   0.8491,    0.3922,
   0.9340,    0.6555,
   0.6787,    0.1712
});

// =====================================
// C2, Q2: 4-D fixture.
// C2 is a corpus of 16 points and Q2 a query set of 8 points, one point per
// row. C2 is also queried against itself in the all-to-all tests
// (pdist2_test3, knn_v1_4slice).
mtx::Matrix<double> C2(16,4, {
   0.7060,    0.4456,    0.5060,    0.6160,
   0.0318,    0.6463,    0.6991,    0.4733,
   0.2769,    0.7094,    0.8909,    0.3517,
   0.0462,    0.7547,    0.9593,    0.8308,
   0.0971,    0.2760,    0.5472,    0.5853,
   0.8235,    0.6797,    0.1386,    0.5497,
   0.6948,    0.6551,    0.1493,    0.9172,
   0.3171,    0.1626,    0.2575,    0.2858,
   0.9502,    0.1190,    0.8407,    0.7572,
   0.0344,    0.4984,    0.2543,    0.7537,
   0.4387,    0.9597,    0.8143,    0.3804,
   0.3816,    0.3404,    0.2435,    0.5678,
   0.7655,    0.5853,    0.9293,    0.0759,
   0.7952,    0.2238,    0.3500,    0.0540,
   0.1869,    0.7513,    0.1966,    0.5308,
   0.4898,    0.2551,    0.2511,    0.7792
});

// Query points for C2 (8 points, 4-D, one per row).
mtx::Matrix<double> Q2(8,4, {
   0.9340,    0.3112,    0.4505,    0.0782,
   0.1299,    0.5285,    0.0838,    0.4427,
   0.5688,    0.1656,    0.2290,    0.1067,
   0.4694,    0.6020,    0.9133,    0.9619,
   0.0119,    0.2630,    0.1524,    0.0046,
   0.3371,    0.6541,    0.8258,    0.7749,
   0.1622,    0.6892,    0.5383,    0.8173,
   0.7943,    0.7482,    0.9961,    0.8687
});



/*
 * ==========================================
 * pdist2
 */
// v0::pdist2(C1, Q1, D) must fill the 10x5 matrix D so that D(i, j) is the
// Euclidean distance between corpus row i and query row j. Reference values
// are given with 4 decimals, so they are compared with an absolute tolerance.
TEST(Tv0_UT, pdist2_test1) {
   constexpr double tol = 0.01;

   mtx::Matrix<double> D1_exp(10, 5, {
      0.6208,    0.9745,    0.2371,    0.5120,    0.1367,
      0.3284,    0.8993,    0.5811,    0.3164,    0.8310,
      0.5651,    0.2327,    0.9169,    0.8616,    0.9603,
      0.3749,    0.9147,    0.1132,    0.1713,    0.3921,
      0.0485,    0.5994,    0.4621,    0.3346,    0.6308,
      0.8312,    0.6044,    0.7922,    0.9815,    0.5819,
      0.5052,    0.4028,    0.5714,    0.6959,    0.4722,
      0.1919,    0.5395,    0.6045,    0.4665,    0.7561,
      0.3037,    0.9231,    0.4144,    0.1387,    0.6807,
      0.3692,    0.9540,    0.5790,    0.3056,    0.8386
   });

   mtx::Matrix<double> D (10,5);

   v0::pdist2(C1, Q1, D);

   // EXPECT_NEAR reports both values and the position on failure.
   for (size_t i = 0 ; i < D.rows() ; ++i)
      for (size_t j = 0 ; j < D.columns() ; ++j)
         EXPECT_NEAR(D1_exp.get(i, j), D(i, j), tol) << "at (" << i << ", " << j << ")";
}

// v0::pdist2(C2, Q2, D) must fill the 16x8 matrix D so that D(i, j) is the
// distance between corpus row i and query row j. Reference values are given
// with 4 decimals, so they are compared with an absolute tolerance.
TEST(Tv0_UT, pdist2_test2) {
   constexpr double tol = 0.01;

   mtx::Matrix<double> D2_exp(16, 8, {
      0.6020,    0.7396,    0.6583,    0.6050,    1.0070,    0.5542,    0.6298,    0.6352,
      1.0696,    0.6348,    0.9353,    0.6914,    0.8160,    0.4475,    0.4037,    0.9145,
      0.9268,    0.8450,    0.9376,    0.6492,    0.9671,    0.4360,    0.5956,    0.7400,
      1.3455,    0.9876,    1.2953,    0.4709,    1.2557,    0.3402,    0.4417,    0.7500,
      0.9839,    0.5476,    0.7517,    0.7216,    0.7074,    0.5605,    0.4784,    0.9954,
      0.6839,    0.7200,    0.7305,    0.9495,    1.0628,    0.8718,    0.8178,    0.9179,
      0.9850,    0.7514,    0.9585,    0.7996,    1.2054,    0.7784,    0.6680,    0.8591,
      0.6950,    0.4730,    0.3103,    1.0504,    0.4397,    0.8967,    0.8140,    1.2066,
      0.8065,    1.2298,    0.9722,    0.7153,    1.3933,    0.8141,    1.0204,    0.6758,
      1.1572,    0.3686,    0.9031,    0.8232,    0.7921,    0.6656,    0.3708,    1.0970,
      0.9432,    0.9049,    1.0320,    0.6905,    1.1167,    0.5094,    0.6455,    0.6653,
      0.7672,    0.3740,    0.5277,    0.8247,    0.6842,    0.6945,    0.5648,    0.9968,
      0.5768,    1.1210,    0.8403,    0.9345,    1.1316,    0.8292,    1.0380,    0.8127,
      0.1939,    0.8703,    0.2684,    1.1794,    0.8103,    1.0683,    1.1115,    1.1646,
      1.0106,    0.2708,    0.8184,    0.8954,    0.7402,    0.6982,    0.4509,    1.0594,
      0.8554,    0.5878,    0.6834,    0.7699,    0.9155,    0.7161,    0.6162,    0.9481
   });

   mtx::Matrix<double> D (16,8);

   v0::pdist2(C2, Q2, D);

   // EXPECT_NEAR reports both values and the position on failure.
   for (size_t i = 0 ; i < D.rows() ; ++i)
      for (size_t j = 0 ; j < D.columns() ; ++j)
         EXPECT_NEAR(D2_exp.get(i, j), D(i, j), tol) << "at (" << i << ", " << j << ")";
}


// All-to-all: v0::pdist2(C2, C2, D) must produce the 16x16 distance table of
// C2 with itself (zero diagonal). Reference values are given with 4 decimals,
// so they are compared with an absolute tolerance.
TEST(Tv0_UT, pdist2_test3) {
   constexpr double tol = 0.01;

   mtx::Matrix<double> D2_exp(16, 16, {
            0,  0.7433,  0.6868,  0.8846,  0.6342,  0.4561,  0.5118,  0.6341,  0.5461,  0.7322,  0.6974,  0.4330,  0.7028,  0.6303,  0.6826,  0.4179,
       0.7433,       0,  0.3400,  0.4555,  0.4207,  0.9736,  0.9690,  0.7386,  1.1055,  0.5462,  0.5345,  0.6576,  0.8677,  1.0291,  0.5393,  0.8106,
       0.6868,  0.3400,       0,  0.5380,  0.6268,  0.9512,  1.0234,  0.8403,  0.9843,  0.8187,  0.3091,  0.7829,  0.5759,  0.9411,  0.7239,  0.9186,
       0.8846,  0.4555,  0.5380,       0,  0.6796,  1.1672,  1.0460,  1.1016,  1.1139,  0.7542,  0.6480,  0.9304,  1.0568,  1.3482,  0.8316,  0.9750,
       0.6342,  0.4207,  0.6268,  0.6796,       0,  0.9267,  0.8772,  0.4847,  0.9317,  0.4093,  0.8351,  0.4215,  0.9736,  0.9007,  0.5999,  0.5291,
       0.4561,  0.9736,  0.9512,  1.1672,  0.9267,       0,  0.3903,  0.7795,  0.9308,  0.8429,  0.8436,  0.5672,  0.9284,  0.7064,  0.6435,  0.5975,
       0.5118,  0.9690,  1.0234,  1.0460,  0.8772,  0.3903,       0,  0.8920,  0.9253,  0.7060,  0.9427,  0.5728,  1.1515,  0.9907,  0.6471,  0.4811,
       0.6341,  0.7386,  0.8403,  1.1016,  0.4847,  0.7795,  0.8920,       0,  0.9824,  0.6416,  0.9844,  0.3398,  0.9355,  0.5428,  0.6536,  0.5309,
       0.5461,  1.1055,  0.9843,  1.1139,  0.9317,  0.9308,  0.9253,  0.9824,       0,  1.1517,  1.0541,  0.8746,  0.8506,  0.8777,  1.2036,  0.7607,
       0.7322,  0.5462,  0.8187,  0.7542,  0.4093,  0.8429,  0.7060,  0.6416,  1.1517,       0,  0.9106,  0.4245,  1.2071,  1.0738,  0.3745,  0.5170,
       0.6974,  0.5345,  0.3091,  0.6480,  0.8351,  0.8436,  0.9427,  0.9844,  1.0541,  0.9106,       0,  0.8647,  0.5941,  0.9954,  0.7148,  0.9876,
       0.4330,  0.6576,  0.7829,  0.9304,  0.4215,  0.5672,  0.5728,  0.3398,  0.8746,  0.4245,  0.8647,       0,  0.9590,  0.6782,  0.4586,  0.2525,
       0.7028,  0.8677,  0.5759,  1.0568,  0.9736,  0.9284,  1.1515,  0.9355,  0.8506,  1.2071,  0.5941,  0.9590,       0,  0.6838,  1.0517,  1.0675,
       0.6303,  1.0291,  0.9411,  1.3482,  0.9007,  0.7064,  0.9907,  0.5428,  0.8777,  1.0738,  0.9954,  0.6782,  0.6838,       0,  0.9482,  0.7937,
       0.6826,  0.5393,  0.7239,  0.8316,  0.5999,  0.6435,  0.6471,  0.6536,  1.2036,  0.3745,  0.7148,  0.4586,  1.0517,  0.9482,       0,  0.6345,
       0.4179,  0.8106,  0.9186,  0.9750,  0.5291,  0.5975,  0.4811,  0.5309,  0.7607,  0.5170,  0.9876,  0.2525,  1.0675,  0.7937,  0.6345,       0
   });

   mtx::Matrix<double> D (16,16);

   v0::pdist2(C2, C2, D);

   // EXPECT_NEAR reports both values and the position on failure.
   for (size_t i = 0 ; i < D.rows() ; ++i)
      for (size_t j = 0 ; j < D.columns() ; ++j)
         EXPECT_NEAR(D2_exp.get(i, j), D(i, j), tol) << "at (" << i << ", " << j << ")";
}


/*
 * ==========================================
 * v0::knn
 */
// v0::knnsearch of Q1 against C1 with k = 3: for each of the 5 queries, the
// indices and distances of the 3 closest corpus points, nearest first.
// Expected indices are 1-based (MATLAB); knnsearch returns 0-based indices.
TEST(Tv0_UT, knn_v0_test1) {
   constexpr double tol = 0.01;
   size_t k = 3;
   mtx::Matrix<uint32_t> Idx_exp(5, k, {
      5,     8,     9,
      3,     7,     8,
      4,     1,     9,
      9,     4,     10,
      1,     4,     7
   });

   mtx::Matrix<double> Dst_exp(5, k, {
      0.0485,    0.1919,    0.3037,
      0.2327,    0.4028,    0.5395,
      0.1132,    0.2371,    0.4144,
      0.1387,    0.1713,    0.3056,
      0.1367,    0.3921,    0.4722
   });

   mtx::Matrix<uint32_t> Idx(5, k);
   mtx::Matrix<double>   Dst(5, k);

   v0::knnsearch(C1, Q1, 0, k, 0, Idx, Dst);

   for (size_t i = 0 ; i < Idx.rows() ; ++i)
      for (size_t j = 0 ; j < Idx.columns() ; ++j) {
         // matlab starts from 1
         EXPECT_EQ(Idx_exp(i, j), Idx(i, j) + 1) << "Idx at (" << i << ", " << j << ")";
         EXPECT_NEAR(Dst_exp.get(i, j), Dst(i, j), tol) << "Dst at (" << i << ", " << j << ")";
      }
}

// v0::knnsearch of Q2 against C2 with k = 3: for each of the 8 queries, the
// indices and distances of the 3 closest corpus points, nearest first.
// Expected indices are 1-based (MATLAB); knnsearch returns 0-based indices.
TEST(Tv0_UT, knn_v0_test2) {
   constexpr double tol = 0.01;
   size_t k = 3;
   mtx::Matrix<uint32_t> Idx_exp(8, k, {
      14,    13,     1,
      15,    10,    12,
      14,     8,    12,
       4,     1,     3,
       8,    12,     5,
       4,     3,     2,
      10,     2,     4,
       1,    11,     9
   });

   mtx::Matrix<double> Dst_exp(8, k, {
      0.1939,    0.5768,    0.6020,
      0.2708,    0.3686,    0.3740,
      0.2684,    0.3103,    0.5277,
      0.4709,    0.6050,    0.6492,
      0.4397,    0.6842,    0.7074,
      0.3402,    0.4360,    0.4475,
      0.3708,    0.4037,    0.4417,
      0.6352,    0.6653,    0.6758
   });

   mtx::Matrix<uint32_t> Idx(8, k);
   mtx::Matrix<double>   Dst(8, k);

   v0::knnsearch(C2, Q2, 0, k, 0, Idx, Dst);

   for (size_t i = 0 ; i < Idx.rows() ; ++i)
      for (size_t j = 0 ; j < Idx.columns() ; ++j) {
         // matlab starts from 1
         EXPECT_EQ(Idx_exp(i, j), Idx(i, j) + 1) << "Idx at (" << i << ", " << j << ")";
         EXPECT_NEAR(Dst_exp.get(i, j), Dst(i, j), tol) << "Dst at (" << i << ", " << j << ")";
      }
}

/*
 * ==========================================
 * v1::knn
 */
// v1::knnsearch of Q2 against C2 with k = 3 and the third argument set to 1
// (a single slice, per the test name). The expected results are the same as
// knn_v0_test2. Expected indices are 1-based (MATLAB); knnsearch returns
// 0-based indices.
TEST(Tv1_UT, knn_v1_1slice) {
   constexpr double tol = 0.01;
   size_t k = 3;
   mtx::Matrix<uint32_t> Idx_exp(8, k, {
      14,    13,     1,
      15,    10,    12,
      14,     8,    12,
       4,     1,     3,
       8,    12,     5,
       4,     3,     2,
      10,     2,     4,
       1,    11,     9
   });

   mtx::Matrix<double> Dst_exp(8, k, {
      0.1939,    0.5768,    0.6020,
      0.2708,    0.3686,    0.3740,
      0.2684,    0.3103,    0.5277,
      0.4709,    0.6050,    0.6492,
      0.4397,    0.6842,    0.7074,
      0.3402,    0.4360,    0.4475,
      0.3708,    0.4037,    0.4417,
      0.6352,    0.6653,    0.6758
   });

   mtx::Matrix<uint32_t> Idx(8, k);
   mtx::Matrix<double>   Dst(8, k);

   v1::knnsearch(C2, Q2, 1, k, k, Idx, Dst);

   for (size_t i = 0 ; i < Idx.rows() ; ++i)
      for (size_t j = 0 ; j < Idx.columns() ; ++j) {
         // matlab starts from 1
         EXPECT_EQ(Idx_exp(i, j), Idx(i, j) + 1) << "Idx at (" << i << ", " << j << ")";
         EXPECT_NEAR(Dst_exp.get(i, j), Dst(i, j), tol) << "Dst at (" << i << ", " << j << ")";
      }
}

// v1::knnsearch of Q2 against C2 with k = 3 and the third argument set to 2
// (two slices, per the test name). The expected results are the same as
// knn_v0_test2. Expected indices are 1-based (MATLAB); knnsearch returns
// 0-based indices.
TEST(Tv1_UT, knn_v1_2slice) {
   constexpr double tol = 0.01;
   size_t k = 3;
   mtx::Matrix<uint32_t> Idx_exp(8, k, {
      14,    13,     1,
      15,    10,    12,
      14,     8,    12,
       4,     1,     3,
       8,    12,     5,
       4,     3,     2,
      10,     2,     4,
       1,    11,     9
   });

   mtx::Matrix<double> Dst_exp(8, k, {
      0.1939,    0.5768,    0.6020,
      0.2708,    0.3686,    0.3740,
      0.2684,    0.3103,    0.5277,
      0.4709,    0.6050,    0.6492,
      0.4397,    0.6842,    0.7074,
      0.3402,    0.4360,    0.4475,
      0.3708,    0.4037,    0.4417,
      0.6352,    0.6653,    0.6758
   });

   mtx::Matrix<uint32_t> Idx(8, k);
   mtx::Matrix<double>   Dst(8, k);

   v1::knnsearch(C2, Q2, 2, k, k, Idx, Dst);

   for (size_t i = 0 ; i < Idx.rows() ; ++i)
      for (size_t j = 0 ; j < Idx.columns() ; ++j) {
         // matlab starts from 1
         EXPECT_EQ(Idx_exp(i, j), Idx(i, j) + 1) << "Idx at (" << i << ", " << j << ")";
         EXPECT_NEAR(Dst_exp.get(i, j), Dst(i, j), tol) << "Dst at (" << i << ", " << j << ")";
      }
}

// all-to-all
// v1::knnsearch of C2 against itself with k = 3 and the third argument set to
// 4 (four slices, per the test name). Each point's nearest neighbour is
// itself at distance 0. Expected indices are 1-based (MATLAB); knnsearch
// returns 0-based indices.
TEST(Tv1_UT, knn_v1_4slice) {
   constexpr double tol = 0.01;
   size_t k = 3;
   mtx::Matrix<uint32_t> Idx_exp(16, k, {
      1,    16,    12,
      2,     3,     5,
      3,    11,     2,
      4,     2,     3,
      5,    10,     2,
      6,     7,     1,
      7,     6,    16,
      8,    12,     5,
      9,     1,    16,
     10,    15,     5,
     11,     3,     2,
     12,    16,     8,
     13,     3,    11,
     14,     8,     1,
     15,    10,    12,
     16,    12,     1
   });

   mtx::Matrix<double> Dst_exp(16, k, {
      0,    0.4179,    0.4331,
      0,    0.3401,    0.4207,
      0,    0.3092,    0.3401,
      0,    0.4555,    0.5381,
      0,    0.4093,    0.4207,
      0,    0.3903,    0.4560,
      0,    0.3903,    0.4811,
      0,    0.3398,    0.4846,
      0,    0.5461,    0.7607,
      0,    0.3745,    0.4093,
      0,    0.3092,    0.5345,
      0,    0.2524,    0.3398,
      0,    0.5759,    0.5941,
      0,    0.5428,    0.6304,
      0,    0.3745,    0.4586,
      0,    0.2524,    0.4179
   });

   mtx::Matrix<uint32_t> Idx(16, k);
   mtx::Matrix<double>   Dst(16, k);

   v1::knnsearch(C2, C2, 4, k, k, Idx, Dst);

   for (size_t i = 0 ; i < Idx.rows() ; ++i)
      for (size_t j = 0 ; j < Idx.columns() ; ++j) {
         // matlab starts from 1
         EXPECT_EQ(Idx_exp(i, j), Idx(i, j) + 1) << "Idx at (" << i << ", " << j << ")";
         EXPECT_NEAR(Dst_exp.get(i, j), Dst(i, j), tol) << "Dst at (" << i << ", " << j << ")";
      }
}



/*
 * ============== Live HDF5 tests ===============
 *
 * To run these tests, the following HDF5 files must be present in the ./mtx directory:
 *
 * - fashion-mnist-784-euclidean.hdf5
 * - mnist-784-euclidean.hdf5
 * - sift-128-euclidean.hdf5
 * - gist-960-euclidean.hdf5
 *
 * The tests below currently load only sift-128-euclidean.hdf5 and write
 * their results into the ./test directory.
 */

// Live run of v0::knnsearch on the SIFT "/test" dataset, with the corpus
// queried against itself (no separate query matrix). The result is written
// to test/knn_v0.hdf5. This is a smoke test: it passes if load, search and
// store complete without a fatal failure or an exception.
TEST(Tlive_UT, knn_v0_sift_test) {
   // Instantiate matrices
   MatrixDst Corpus;
   MatrixDst Query;
   MatrixIdx Idx;
   MatrixDst Dst;

   // Set up the environment
   session.corpusMtxFile = "mtx/sift-128-euclidean.hdf5";
   session.corpusDataSet = "/test";
   session.queryMtx = false;
   session.k = 100;
   size_t m = session.k;
   session.timing = true;
   session.outMtxFile = "test/knn_v0.hdf5";

   loadMtx(Corpus, Query);

   // Prepare output memory (there is no Query, so size it from Corpus)
   Idx.resize(Corpus.rows(), session.k);
   Dst.resize(Corpus.rows(), session.k);

   v0::knnsearch(Corpus, Corpus, 0, session.k, m, Idx, Dst);
   storeMtx(Idx, Dst);

   // Reaching this point is the whole check; SUCCEED() states that explicitly
   // instead of a vacuous EXPECT_EQ(true, true).
   SUCCEED();
}


// Live run of v1::knnsearch on the SIFT "/test" dataset, with the corpus
// queried against itself (no separate query matrix). The result is written
// to test/knn_v1ser.hdf5. This is a smoke test: it passes if load, search and
// store complete without a fatal failure or an exception.
TEST(Tlive_UT, knn_v1_sift_test_1slice) {
   // Instantiate matrices
   MatrixDst Corpus;
   MatrixDst Query;
   MatrixIdx Idx;
   MatrixDst Dst;

   // Set up the environment
   session.corpusMtxFile = "mtx/sift-128-euclidean.hdf5";
   session.corpusDataSet = "/test";
   session.queryMtx = false;
   session.k = 100;
   size_t m = session.k;
   session.timing = true;
   session.outMtxFile = "test/knn_v1ser.hdf5";

   loadMtx(Corpus, Query);

   // Prepare output memory (there is no Query, so size it from Corpus)
   Idx.resize(Corpus.rows(), session.k);
   Dst.resize(Corpus.rows(), session.k);

   // NOTE(review): the unit test knn_v1_1slice passes 1 as the third argument,
   // while this live test passes 0 — confirm which value selects a single slice.
   v1::knnsearch(Corpus, Corpus, 0, session.k, m, Idx, Dst);
   storeMtx(Idx, Dst);

   // Reaching this point is the whole check; SUCCEED() states that explicitly
   // instead of a vacuous EXPECT_EQ(true, true).
   SUCCEED();
}

// Live run of v1::knnsearch with 2 slices on the SIFT "/test" dataset, with
// the corpus queried against itself (no separate query matrix). This is a
// smoke test: it passes if load, search and store complete without a fatal
// failure or an exception.
TEST(Tlive_UT, knn_v1_sift_test_2slice) {
   // Instantiate matrices
   MatrixDst Corpus;
   MatrixDst Query;
   MatrixIdx Idx;
   MatrixDst Dst;

   // Set up the environment
   session.corpusMtxFile = "mtx/sift-128-euclidean.hdf5";
   session.corpusDataSet = "/test";
   session.queryMtx = false;
   session.k = 100;
   size_t m = session.k;
   session.timing = true;
   // Dedicated output file: previously this shared test/knn_v1ser.hdf5 with
   // the serial run, so each run overwrote the other's result.
   session.outMtxFile = "test/knn_v1_2slice.hdf5";

   loadMtx(Corpus, Query);

   // Prepare output memory (there is no Query, so size it from Corpus)
   Idx.resize(Corpus.rows(), session.k);
   Dst.resize(Corpus.rows(), session.k);

   v1::knnsearch(Corpus, Corpus, 2, session.k, m, Idx, Dst);
   storeMtx(Idx, Dst);

   // Reaching this point is the whole check; SUCCEED() states that explicitly
   // instead of a vacuous EXPECT_EQ(true, true).
   SUCCEED();
}

// Live run of v1::knnsearch with 4 slices on the SIFT "/test" dataset, with
// the corpus queried against itself (no separate query matrix). This is a
// smoke test: it passes if load, search and store complete without a fatal
// failure or an exception.
TEST(Tlive_UT, knn_v1_sift_test_4slice) {
   // Instantiate matrices
   MatrixDst Corpus;
   MatrixDst Query;
   MatrixIdx Idx;
   MatrixDst Dst;

   // Set up the environment
   session.corpusMtxFile = "mtx/sift-128-euclidean.hdf5";
   session.corpusDataSet = "/test";
   session.queryMtx = false;
   session.k = 100;
   size_t m = session.k;
   session.timing = true;
   // Dedicated output file: previously this shared test/knn_v1ser.hdf5 with
   // the serial run, so each run overwrote the other's result.
   session.outMtxFile = "test/knn_v1_4slice.hdf5";

   loadMtx(Corpus, Query);

   // Prepare output memory (there is no Query, so size it from Corpus)
   Idx.resize(Corpus.rows(), session.k);
   Dst.resize(Corpus.rows(), session.k);

   v1::knnsearch(Corpus, Corpus, 4, session.k, m, Idx, Dst);
   storeMtx(Idx, Dst);

   // Reaching this point is the whole check; SUCCEED() states that explicitly
   // instead of a vacuous EXPECT_EQ(true, true).
   SUCCEED();
}