/** * \file v0.hpp * \brief * * \author * Christos Choutouridis AEM:8997 * */ #ifndef V0_HPP_ #define V0_HPP_ #include #include #include #include #include #include namespace v0 { /*! * Function to compute squared Euclidean distances * * \fn void pdist2(const double*, const double*, double*, int, int, int) * \param X m x d matrix (Column major) * \param Y n x d matrix (Column major) * \param D2 m x n matrix to store distances (Column major) * \param m number of rows in X * \param n number of rows in Y * \param d number of columns in both X and Y */ template void pdist2(const mtx::Matrix& X, const mtx::Matrix& Y, mtx::Matrix& D2) { int M = X.rows(); int N = Y.rows(); int d = X.columns(); // Compute the squared norms of each row in X and Y std::vector X_norms(M), Y_norms(N); for (int i = 0; i < M ; ++i) { X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1); } for (int j = 0; j < N ; ++j) { Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1); } // Compute -2 * X * Y' cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N); // Step 3: Add the squared norms to each entry in D2 for (int i = 0; i < M ; ++i) { for (int j = 0; j < N; ++j) { D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j); //D2.set(std::max(D2.get(i, j), 0.0), i, j); // Ensure non-negative D2.set(std::sqrt(D2.get(i, j)), i, j); // Take the square root of each } } } template void quickselect(std::vector>& vec, int k) { std::nth_element( vec.begin(), vec.begin() + k, vec.end(), [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); vec.resize(k); // Keep only the k smallest elements } /*! * \param C Is a MxD matrix (Corpus) * \param Q Is a NxD matrix (Query) * \param k The number of nearest neighbors needed * \param idx Is the Nxk matrix with the k indexes of the C points, that are * neighbors of the nth point of Q * \param dst Is the Nxk matrix with the k distances to the C points of the nth * point of Q */ template void knnsearch(const mtx::Matrix& C, const mtx::Matrix& Q, int k, mtx::Matrix& idx, mtx::Matrix& dst) { int M = C.rows(); int N = Q.rows(); mtx::Matrix D(M, N); pdist2(C, Q, D); idx.resize(N, k); dst.resize(N, k); for (int j = 0; j < N; ++j) { // Create a vector of pairs (distance, index) for the j-th query std::vector> dst_idx(M); for (int i = 0; i < M; ++i) { dst_idx[i] = {D.data()[i * N + j], i}; } // Find the k smallest distances using quickSelectKSmallest quickselect(dst_idx, k); // Sort the k smallest results by distance for consistency std::sort(dst_idx.begin(), dst_idx.end()); // Store the indices and distances for (int i = 0; i < k; ++i) { idx(j, i) = dst_idx[i].second; dst(j, i) = dst_idx[i].first; } } } } #endif /* V0_HPP_ */