Thread-safe refactoring of the class KdTree.

Removed methods:
void setMaxNofNeighbors(unsigned int k);
inline int getNofFoundNeighbors(void);
inline const VectorType& getNeighbor(int i);
inline unsigned int getNeighborId(int i);
inline float getNeighborSquaredDistance(int i);

Added methods:
void doQueryDist(const VectorType& queryPoint, float dist, std::vector<unsigned int>& points, std::vector<Scalar>& sqrareDists);
void doQueryClosest(const VectorType& queryPoint, unsigned int& index, Scalar& dist);

Changed methods:
void doQueryK(const VectorType& queryPoint,  int k, PriorityQueue& mNeighborQueue);
This commit is contained in:
Gianpaolo Palma 2014-07-11 11:52:52 +00:00
parent 0491ceedeb
commit 31fb567321
1 changed files with 457 additions and 323 deletions
vcg/space/index/kdtree

View File

@ -1,18 +1,20 @@
#ifndef KDTREE_H
#define KDTREE_H
#ifndef KDTREE_VCG_H
#define KDTREE_VCG_H
#include <vcg/space/point3.h>
#include <vcg/space/box3.h>
#include <vcg/space/index/kdtree/priorityqueue.h>
#include "../../point3.h"
#include "../../box3.h"
#include "mlsutils.h"
#include "priorityqueue.h"
#include <vector>
#include <limits>
#include <iostream>
template<typename _DataType>
class ConstDataWrapper
{
public:
namespace vcg {
template<typename _DataType>
class ConstDataWrapper
{
public:
typedef _DataType DataType;
inline ConstDataWrapper()
: mpData(0), mStride(0), mSize(0)
@ -25,42 +27,45 @@ public:
return *reinterpret_cast<const DataType*>(mpData + i*mStride);
}
inline size_t size() const { return mSize; }
protected:
protected:
const unsigned char* mpData;
int mStride;
size_t mSize;
};
};
template<class StdVectorType>
class VectorConstDataWrapper :public ConstDataWrapper<typename StdVectorType::value_type>
{
public:
template<class StdVectorType>
class VectorConstDataWrapper :public ConstDataWrapper<typename StdVectorType::value_type>
{
public:
inline VectorConstDataWrapper(StdVectorType &vec):
ConstDataWrapper<typename StdVectorType::value_type> ( &(vec[0]), vec.size(), sizeof(typename StdVectorType::value_type))
{}
};
};
template<class MeshType>
class VertexConstDataWrapper :public ConstDataWrapper<typename MeshType::CoordType>
{
public:
template<class MeshType>
class VertexConstDataWrapper :public ConstDataWrapper<typename MeshType::CoordType>
{
public:
inline VertexConstDataWrapper(MeshType &m):
ConstDataWrapper<typename MeshType::CoordType> ( &(m.vert[0].P()), m.vert.size(), sizeof(typename MeshType::VertexType))
{}
};
};
/**
* This class allows to create a Kd-Tree thought to perform the k-nearest neighbour query
/**
* This class allows to create a Kd-Tree thought to perform the neighbour query (radius search, knn-nearest serach and closest search).
* The class implemetantion is thread-safe.
*/
template<typename _Scalar>
class KdTree
{
public:
template<typename _Scalar>
class KdTree
{
public:
typedef _Scalar Scalar;
typedef vcg::Point3<Scalar> VectorType;
typedef vcg::Box3<Scalar> AxisAlignedBoxType;
typedef HeapMaxPriorityQueue<int, Scalar> PriorityQueue;
struct Node
{
union {
@ -84,22 +89,19 @@ public:
inline const NodeList& _getNodes(void) { return mNodes; }
inline const std::vector<VectorType>& _getPoints(void) { return mPoints; }
void setMaxNofNeighbors(unsigned int k);
inline int getNofFoundNeighbors(void) { return mNeighborQueue.getNofElements(); }
inline const VectorType& getNeighbor(int i) { return mPoints[ mNeighborQueue.getIndex(i) ]; }
inline unsigned int getNeighborId(int i) { return mIndices[mNeighborQueue.getIndex(i)]; }
inline float getNeighborSquaredDistance(int i) { return mNeighborQueue.getWeight(i); }
public:
public:
KdTree(const ConstDataWrapper<VectorType>& points, unsigned int nofPointsPerCell = 16, unsigned int maxDepth = 64);
~KdTree();
void doQueryK(const VectorType& p);
void doQueryK(const VectorType& queryPoint, int k, PriorityQueue& mNeighborQueue);
protected:
void doQueryDist(const VectorType& queryPoint, float dist, std::vector<unsigned int>& points, std::vector<Scalar>& sqrareDists);
void doQueryClosest(const VectorType& queryPoint, unsigned int& index, Scalar& dist);
protected:
// element of the stack
struct QueryNode
@ -116,21 +118,18 @@ protected:
void createTree(unsigned int nodeId, unsigned int start, unsigned int end, unsigned int level, unsigned int targetCellsize, unsigned int targetMaxDepth);
protected:
protected:
AxisAlignedBoxType mAABB; //BoundingBox
NodeList mNodes; //kd-tree nodes
std::vector<VectorType> mPoints; //points read from the input DataWrapper
std::vector<int> mIndices; //points indices
std::vector<unsigned int> mIndices; //points indices
};
HeapMaxPriorityQueue<int,Scalar> mNeighborQueue; //used to perform the knn-query
QueryNode mNodeStack[64]; //used in the implementation of the knn-query
};
template<typename Scalar>
KdTree<Scalar>::KdTree(const ConstDataWrapper<VectorType>& points, unsigned int nofPointsPerCell, unsigned int maxDepth)
template<typename Scalar>
KdTree<Scalar>::KdTree(const ConstDataWrapper<VectorType>& points, unsigned int nofPointsPerCell, unsigned int maxDepth)
: mPoints(points.size()), mIndices(points.size())
{
{
// compute the AABB of the input
mPoints[0] = points[0];
mAABB.Set(mPoints[0]);
@ -147,20 +146,15 @@ KdTree<Scalar>::KdTree(const ConstDataWrapper<VectorType>& points, unsigned int
mNodes.resize(1);
mNodes.back().leaf = 0;
createTree(0, 0, mPoints.size(), 1, nofPointsPerCell, maxDepth);
}
}
template<typename Scalar>
KdTree<Scalar>::~KdTree()
{
}
template<typename Scalar>
KdTree<Scalar>::~KdTree()
{
}
template<typename Scalar>
void KdTree<Scalar>::setMaxNofNeighbors(unsigned int k)
{
mNeighborQueue.setMaxSize(k);
}
/** Performs the kNN query.
/** Performs the kNN query.
*
* This algorithm uses the simple distance to the split plane to prune nodes.
* A more elaborated approach consists to track the closest corner of the cell
@ -173,15 +167,17 @@ void KdTree<Scalar>::setMaxNofNeighbors(unsigned int k)
* But, again, priority queue insertions and deletions are quite involved, and therefore
* a simple stack is by far much faster.
*
* The result of the query, the k-nearest neighbors, are internally stored into a stack, where the
* topmost element [0] is NOT the nearest but the farthest!! (they are not sorted but arranged into a heap)
* The result of the query, the k-nearest neighbors, are stored into the stack mNeighborQueue, where the
* topmost element [0] is NOT the nearest but the farthest!! (they are not sorted but arranged into a heap).
*/
template<typename Scalar>
void KdTree<Scalar>::doQueryK(const VectorType& queryPoint)
{
template<typename Scalar>
void KdTree<Scalar>::doQueryK(const VectorType& queryPoint, int k, PriorityQueue& mNeighborQueue)
{
mNeighborQueue.setMaxSize(k);
mNeighborQueue.init();
mNeighborQueue.insert(0xffffffff, std::numeric_limits<Scalar>::max());
QueryNode mNodeStack[64];
mNodeStack[0].nodeId = 0;
mNodeStack[0].sq = 0.f;
unsigned int count = 1;
@ -208,7 +204,7 @@ void KdTree<Scalar>::doQueryK(const VectorType& queryPoint)
unsigned int end = node.start+node.size;
//adding the element of the leaf to the heap
for (unsigned int i=node.start ; i<end ; ++i)
mNeighborQueue.insert(i, vcg::SquaredNorm(queryPoint - mPoints[i]));
mNeighborQueue.insert(mIndices[i], vcg::SquaredNorm(queryPoint - mPoints[i]));
}
//otherwise, if we're not on a leaf
else
@ -242,16 +238,149 @@ void KdTree<Scalar>::doQueryK(const VectorType& queryPoint)
--count;
}
}
}
}
/**
/** Performs the distance query.
*
* The result of the query, all the points within the distance dist form the query point, is the vector of the indeces
* and the vector of the squared distances from the query point.
*/
template<typename Scalar>
void KdTree<Scalar>::doQueryDist(const VectorType& queryPoint, float dist, std::vector<unsigned int>& points, std::vector<Scalar>& sqrareDists)
{
QueryNode mNodeStack[64];
mNodeStack[0].nodeId = 0;
mNodeStack[0].sq = 0.f;
unsigned int count = 1;
float sqrareDist = dist*dist;
while (count)
{
QueryNode& qnode = mNodeStack[count-1];
Node & node = mNodes[qnode.nodeId];
if (qnode.sq < sqrareDist)
{
if (node.leaf)
{
--count; // pop
unsigned int end = node.start+node.size;
for (unsigned int i=node.start ; i<end ; ++i)
{
float pointSquareDist = vcg::SquaredNorm(queryPoint - mPoints[i]);
if (pointSquareDist < sqrareDist)
{
points.push_back(mIndices[i]);
sqrareDists.push_back(pointSquareDist);
}
}
}
else
{
// replace the stack top by the farthest and push the closest
float new_off = queryPoint[node.dim] - node.splitValue;
if (new_off < 0.)
{
mNodeStack[count].nodeId = node.firstChildId;
qnode.nodeId = node.firstChildId+1;
}
else
{
mNodeStack[count].nodeId = node.firstChildId+1;
qnode.nodeId = node.firstChildId;
}
mNodeStack[count].sq = qnode.sq;
qnode.sq = new_off*new_off;
++count;
}
}
else
{
// pop
--count;
}
}
}
/** Searchs the closest point.
*
* The result of the query, the closest point to the query point, is the index of the point and
* and the squared distance from the query point.
*/
template<typename Scalar>
void KdTree<Scalar>::doQueryClosest(const VectorType& queryPoint, unsigned int& index, Scalar& dist)
{
QueryNode mNodeStack[64];
mNodeStack[0].nodeId = 0;
mNodeStack[0].sq = 0.f;
unsigned int count = 1;
int minIndex = mIndices.size() / 2;
Scalar minDist = vcg::SquaredNorm(queryPoint - mPoints[minIndex]);
minIndex = mIndices[minIndex];
while (count)
{
QueryNode& qnode = mNodeStack[count-1];
Node & node = mNodes[qnode.nodeId];
if (qnode.sq < minDist)
{
if (node.leaf)
{
--count; // pop
unsigned int end = node.start+node.size;
for (unsigned int i=node.start ; i<end ; ++i)
{
float pointSquareDist = vcg::SquaredNorm(queryPoint - mPoints[i]);
if (pointSquareDist < minDist)
{
minDist = pointSquareDist;
minIndex = mIndices[i];
}
}
}
else
{
// replace the stack top by the farthest and push the closest
float new_off = queryPoint[node.dim] - node.splitValue;
if (new_off < 0.)
{
mNodeStack[count].nodeId = node.firstChildId;
qnode.nodeId = node.firstChildId+1;
}
else
{
mNodeStack[count].nodeId = node.firstChildId+1;
qnode.nodeId = node.firstChildId;
}
mNodeStack[count].sq = qnode.sq;
qnode.sq = new_off*new_off;
++count;
}
}
else
{
// pop
--count;
}
}
index = minIndex;
dist = minDist;
}
/**
* Split the subarray between start and end in two part, one with the elements less than splitValue,
* the other with the elements greater or equal than splitValue. The elements are compared
* using the "dim" coordinate [0 = x, 1 = y, 2 = z].
*/
template<typename Scalar>
unsigned int KdTree<Scalar>::split(int start, int end, unsigned int dim, float splitValue)
{
template<typename Scalar>
unsigned int KdTree<Scalar>::split(int start, int end, unsigned int dim, float splitValue)
{
int l(start), r(end-1);
for ( ; l<r ; ++l, --r)
{
@ -266,9 +395,9 @@ unsigned int KdTree<Scalar>::split(int start, int end, unsigned int dim, float s
}
//returns the index of the first element on the second part
return (mPoints[l][dim] < splitValue ? l+1 : l);
}
}
/** recursively builds the kdtree
/** recursively builds the kdtree
*
* The heuristic is the following:
* - if the number of points in the node is lower than targetCellsize then make a leaf
@ -285,9 +414,9 @@ unsigned int KdTree<Scalar>::split(int start, int end, unsigned int dim, float s
* to prune only about 10% of the leaves, but the overhead of this pruning (ball/ABBB intersection)
* is more expensive than the gain it provides and the memory consumption is x4 higher !
*/
template<typename Scalar>
void KdTree<Scalar>::createTree(unsigned int nodeId, unsigned int start, unsigned int end, unsigned int level, unsigned int targetCellSize, unsigned int targetMaxDepth)
{
template<typename Scalar>
void KdTree<Scalar>::createTree(unsigned int nodeId, unsigned int start, unsigned int end, unsigned int level, unsigned int targetCellSize, unsigned int targetMaxDepth)
{
//select the first node
Node& node = mNodes[nodeId];
AxisAlignedBoxType aabb;
@ -301,7 +430,12 @@ void KdTree<Scalar>::createTree(unsigned int nodeId, unsigned int start, unsigne
VectorType diag = aabb.max - aabb.min;
//the split "dim" is the dimension of the box with the biggest value
unsigned int dim = vcg::MaxCoeffId(diag);
unsigned int dim;
if (diag.X() > diag.Y())
dim = diag.X() > diag.Z() ? 0 : 2;
else
dim = diag.Y() > diag.Z() ? 1 : 2;
node.dim = dim;
//we divide the bounding box in 2 partitions, considering the average of the "dim" dimension
node.splitValue = Scalar(0.5*(aabb.max[dim] + aabb.min[dim]));
@ -346,7 +480,7 @@ void KdTree<Scalar>::createTree(unsigned int nodeId, unsigned int start, unsigne
createTree(childId, midId, end, level+1, targetCellSize, targetMaxDepth);
}
}
}
}
#endif