|
|
@@ -1,15 +1,107 @@ |
|
|
|
#include <iostream> |
|
|
|
#include <cblas.h> |
|
|
|
#include <cmath> |
|
|
|
#include <vector> |
|
|
|
#include <algorithm> |
|
|
|
#include <queue> |
|
|
|
|
|
|
|
/*! |
|
|
|
* Function to compute squared Euclidean distances |
|
|
|
* |
|
|
|
* \fn void pdist2(const double*, const double*, double*, int, int, int) |
|
|
|
* \param X m x d matrix |
|
|
|
* \param Y n x d matrix |
|
|
|
* \param D2 m x n matrix to store distances |
|
|
|
* \param m number of rows in X |
|
|
|
* \param n number of rows in Y |
|
|
|
* \param d number of columns in both X and Y |
|
|
|
*/ |
|
|
|
void pdist2(const double* X, const double* Y, double* D2, int m, int n, int d){ |
|
|
|
// Compute the squared norms of each row in X and Y |
|
|
|
std::vector<double> X_norms(m), Y_norms(n); |
|
|
|
for (int i = 0; i < m; ++i) { |
|
|
|
X_norms[i] = cblas_ddot(d, X + i * d, 1, X + i * d, 1); |
|
|
|
} |
|
|
|
for (int j = 0; j < n; ++j) { |
|
|
|
Y_norms[j] = cblas_ddot(d, Y + j * d, 1, Y + j * d, 1); |
|
|
|
} |
|
|
|
|
|
|
|
// Compute -2 * X * Y' |
|
|
|
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, d, -2.0, X, d, Y, d, 0.0, D2, n); |
|
|
|
|
|
|
|
// Step 3: Add the squared norms to each entry in D2 |
|
|
|
for (int i = 0; i < m; ++i) { |
|
|
|
for (int j = 0; j < n; ++j) { |
|
|
|
D2[i * n + j] += X_norms[i] + Y_norms[j]; |
|
|
|
D2[i * n + j] = std::max(D2[i * n + j], 0.0); // Ensure non-negative |
|
|
|
D2[i * n + j] = std::sqrt(D2[i * n + j]); // Take the square root of each |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void quickselect(std::vector<std::pair<double, int>>& vec, int k) { |
|
|
|
std::nth_element( |
|
|
|
vec.begin(), |
|
|
|
vec.begin() + k, |
|
|
|
vec.end(), |
|
|
|
[](const std::pair<double, int>& a, const std::pair<double, int>& b) { |
|
|
|
return a.first < b.first; |
|
|
|
}); |
|
|
|
vec.resize(k); // Keep only the k smallest elements |
|
|
|
} |
|
|
|
|
|
|
|
// K-nearest neighbor search function |
|
|
|
void knnsearch(const double* C, const double* Q, int m, int n, int d, int k, |
|
|
|
std::vector<std::vector<int>>& idx, std::vector<std::vector<double>>& dst) { |
|
|
|
std::vector<double> D(m * n); |
|
|
|
pdist2(C, Q, D.data(), m, n, d); |
|
|
|
|
|
|
|
idx.resize(n, std::vector<int>(k)); |
|
|
|
dst.resize(n, std::vector<double>(k)); |
|
|
|
|
|
|
|
for (int j = 0; j < n; ++j) { |
|
|
|
// Create a vector of pairs (distance, index) for the j-th query |
|
|
|
std::vector<std::pair<double, int>> dst_idx(m); |
|
|
|
for (int i = 0; i < m; ++i) { |
|
|
|
dst_idx[i] = {D[i * n + j], i}; |
|
|
|
} |
|
|
|
|
|
|
|
// Find the k smallest distances using quickSelectKSmallest |
|
|
|
quickselect(dst_idx, k); |
|
|
|
|
|
|
|
// Sort the k smallest results by distance for consistency |
|
|
|
std::sort(dst_idx.begin(), dst_idx.end()); |
|
|
|
|
|
|
|
// Store the indices and distances |
|
|
|
for (int i = 0; i < k; ++i) { |
|
|
|
idx[j][i] = dst_idx[i].second; |
|
|
|
dst[j][i] = dst_idx[i].first; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int main(){ |
|
|
|
|
|
|
|
double A[6] = {1.0,2.0,1.0,-3.0,4.0,-1.0}; |
|
|
|
double B[6] = {1.0,2.0,1.0,-3.0,4.0,-1.0}; |
|
|
|
double C[9] = {.5,.5,.5,.5,.5,.5,.5,.5,.5}; |
|
|
|
cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans,3,3,2,1,A, 3, B, 3,2,C,3); |
|
|
|
int m = 5; // Number of points in C (corpus) |
|
|
|
int n = 3; // Number of points in Q (query) |
|
|
|
int d = 2; // Dimensions |
|
|
|
int k = 2; // Number of nearest neighbors to find |
|
|
|
|
|
|
|
double C[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; // m x d matrix |
|
|
|
double Q[] = {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}; // n x d matrix |
|
|
|
|
|
|
|
std::vector<std::vector<int>> idx; |
|
|
|
std::vector<std::vector<double>> dst; |
|
|
|
|
|
|
|
knnsearch(C, Q, m, n, d, k, idx, dst); |
|
|
|
|
|
|
|
// Print results |
|
|
|
for (int i = 0; i < n; ++i) { |
|
|
|
std::cout << "Query point " << i << ":\n"; |
|
|
|
for (int j = 0; j < k; ++j) { |
|
|
|
std::cout << " Neighbor " << j <<": Index = " << idx[i][j] <<", Distance = " << dst[i][j] << '\n'; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (int i=0 ; i<9 ; ++i) |
|
|
|
std::cout << C[i] << ' '; |
|
|
|
std::cout << '\n'; |
|
|
|
return 0; |
|
|
|
return 0; |
|
|
|
} |