130 lines
3.9 KiB
C++
130 lines
3.9 KiB
C++
/**
|
|
* \file v0.hpp
|
|
* \brief
|
|
*
|
|
* \author
|
|
* Christos Choutouridis AEM:8997
|
|
* <cchoutou@ece.auth.gr>
|
|
*/
|
|
#ifndef V0_HPP_
|
|
#define V0_HPP_
|
|
|
|
#include <cblas.h>
|
|
#include <cmath>
|
|
#include <vector>
|
|
#include <algorithm>
|
|
|
|
#include "matrix.hpp"
|
|
#include "config.h"
|
|
|
|
namespace v0 {
|
|
|
|
/*!
|
|
* Function to compute squared Euclidean distances
|
|
*
|
|
* \fn void pdist2(const double*, const double*, double*, int, int, int)
|
|
* \param X m x d matrix (Column major)
|
|
* \param Y n x d matrix (Column major)
|
|
* \param D2 m x n matrix to store distances (Column major)
|
|
* \param m number of rows in X
|
|
* \param n number of rows in Y
|
|
* \param d number of columns in both X and Y
|
|
*/
|
|
template<typename Matrix>
|
|
void pdist2(const Matrix& X, const Matrix& Y, Matrix& D2) {
|
|
using DataType = typename Matrix::dataType;
|
|
|
|
int M = X.rows();
|
|
int N = Y.rows();
|
|
int d = X.columns();
|
|
|
|
// Compute the squared norms of each row in X and Y
|
|
std::vector<DataType> X_norms(M), Y_norms(N);
|
|
for (int i = 0; i < M ; ++i) {
|
|
X_norms[i] = cblas_ddot(d, X.data() + i * d, 1, X.data() + i * d, 1);
|
|
}
|
|
for (int j = 0; j < N ; ++j) {
|
|
Y_norms[j] = cblas_ddot(d, Y.data() + j * d, 1, Y.data() + j * d, 1);
|
|
}
|
|
|
|
// Compute -2 * X * Y'
|
|
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, M, N, d, -2.0, X.data(), d, Y.data(), d, 0.0, D2.data(), N);
|
|
|
|
// Step 3: Add the squared norms to each entry in D2
|
|
for (int i = 0; i < M ; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
D2.set(D2.get(i, j) + X_norms[i] + Y_norms[j], i, j);
|
|
D2.set(std::max(D2.get(i, j), 0.0), i, j); // Ensure non-negative
|
|
D2.set(std::sqrt(D2.get(i, j)), i, j); // Take the square root of each
|
|
}
|
|
}
|
|
M++;
|
|
}
|
|
|
|
/*!
 * Quick select implementation.
 *
 * Partially reorders \p vec so that it holds the \p k smallest
 * (distance, index) pairs, then shrinks it to exactly those elements.
 * The kept elements are NOT sorted among themselves.
 *
 * Pairs are compared lexicographically (distance first, then index), so when
 * several entries share the boundary distance the ones with the smaller index
 * are kept. This makes the selection deterministic.
 *
 * \fn void quickselect(std::vector<std::pair<DataType,IndexType>>&, int)
 * \tparam DataType  Distance type (less-than comparable)
 * \tparam IndexType Index type (less-than comparable)
 * \param vec Vector of pair(distance, index) to partially sort over distance
 * \param k   The number of elements to select. k <= 0 empties the vector;
 *            k >= vec.size() leaves it unchanged.
 */
template<typename DataType, typename IndexType>
void quickselect(std::vector<std::pair<DataType, IndexType>>& vec, int k) {
   if (k <= 0) {
      vec.clear();
      return;
   }
   // Nothing to discard. This also avoids forming an iterator past end()
   // and growing the vector with default-constructed pairs in resize().
   if (static_cast<std::size_t>(k) >= vec.size())
      return;

   // std::pair::operator< : by distance, ties by index
   std::nth_element(vec.begin(), vec.begin() + k, vec.end());
   vec.resize(static_cast<std::size_t>(k));   // keep only the k smallest elements
}
|
|
|
|
/*!
|
|
* \param C Is a MxD matrix (Corpus)
|
|
* \param Q Is a NxD matrix (Query)
|
|
* \param idx_offset The offset of the indexes for output (to match with the actual Corpus indexes)
|
|
* \param k The number of nearest neighbors needed
|
|
* \param idx Is the Nxk matrix with the k indexes of the C points, that are
|
|
* neighbors of the nth point of Q
|
|
* \param dst Is the Nxk matrix with the k distances to the C points of the nth
|
|
* point of Q
|
|
*/
|
|
template<typename MatrixD, typename MatrixI>
|
|
void knnsearch(MatrixD& C, MatrixD& Q, size_t idx_offset, size_t k, size_t m, MatrixI& idx, MatrixD& dst) {
|
|
|
|
using DstType = typename MatrixD::dataType;
|
|
using IdxType = typename MatrixI::dataType;
|
|
|
|
size_t M = C.rows();
|
|
size_t N = Q.rows();
|
|
|
|
mtx::Matrix<DstType> D(M, N);
|
|
|
|
pdist2(C, Q, D);
|
|
|
|
for (size_t j = 0; j < N; ++j) {
|
|
// Create a vector of pairs (distance, index) for the j-th query
|
|
std::vector<std::pair<DstType, IdxType>> dst_idx(M);
|
|
for (size_t i = 0; i < M; ++i) {
|
|
dst_idx[i] = {D.data()[i * N + j], i};
|
|
}
|
|
// Find the k smallest distances using quickSelectKSmallest
|
|
quickselect(dst_idx, k);
|
|
|
|
// Sort the k smallest results by distance for consistency
|
|
std::sort(dst_idx.begin(), dst_idx.end());
|
|
|
|
// Store the indices and distances
|
|
for (size_t i = 0; i < k; ++i) {
|
|
dst.set(dst_idx[i].first, j, i);
|
|
idx.set(dst_idx[i].second + idx_offset, j, i);
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
#endif /* V0_HPP_ */
|