// Copyright 2019-2021 The MathWorks, Inc.

/**
 * @file
 * @brief Utilities to find the nearest neighbor(s) in the tree for a given state
 */
|
||
|
|
||
|
|
||
|
#ifndef PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP
|
||
|
#define PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP
|
||
|
|
||
|
#include <algorithm>
#include <cmath>    // for std::isnan
#include <cstddef>  // for std::size_t
#include <iterator> // for std::distance
#include <limits>   // for std::numeric_limits
#include <queue>
#include <utility>  // for std::pair
#include <vector>

#include "planningcodegen_TreeNode.hpp"
#include "planningcodegen_DistanceMetric.hpp"
#include "planningcodegen_CommonCSMetric.hpp"
|
||
|
|
||
|
namespace nav
|
||
|
{
|
||
|
template <typename T>
|
||
|
class DistanceRadiusFunctor {
|
||
|
|
||
|
public:
|
||
|
DistanceRadiusFunctor() {}
|
||
|
|
||
|
bool operator()(const std::pair<T, TreeNode<T>*> &a, T val) const {
|
||
|
return a.first < val;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <typename T>
|
||
|
class AscendingNodeFunctor {
|
||
|
|
||
|
public:
|
||
|
AscendingNodeFunctor() {}
|
||
|
|
||
|
bool operator()(const std::pair<T, TreeNode<T>*> &a,
|
||
|
const std::pair<T, TreeNode<T>*> &b) {
|
||
|
return a.first < b.first;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// template class for nearest neighbor finding algorithms
|
||
|
/**
|
||
|
* @tparam T Data type
|
||
|
*/
|
||
|
template <typename T>
|
||
|
class NearestNeighborFinder
|
||
|
{
|
||
|
// tree is a friend
|
||
|
template <typename>
|
||
|
friend class SearchTree;
|
||
|
|
||
|
public:
|
||
|
/// NearestNeighborFinder constructor
|
||
|
NearestNeighborFinder(std::size_t nodeDim)
|
||
|
{
|
||
|
m_buildMetric = new CommonCSMetric<T>(nodeDim);
|
||
|
m_queryMetric = new CommonCSMetric<T>(nodeDim);
|
||
|
}
|
||
|
|
||
|
virtual ~NearestNeighborFinder()
|
||
|
{
|
||
|
purgeMetrics();
|
||
|
}
|
||
|
|
||
|
void purgeMetrics()
|
||
|
{
|
||
|
if (m_queryMetric != nullptr)
|
||
|
{
|
||
|
delete m_queryMetric;
|
||
|
m_queryMetric = nullptr;
|
||
|
}
|
||
|
|
||
|
if (m_buildMetric != nullptr)
|
||
|
{
|
||
|
delete m_buildMetric;
|
||
|
m_buildMetric = nullptr;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void setBuildMetric(DistanceMetric<T> * bm)
|
||
|
{
|
||
|
m_buildMetric = bm;
|
||
|
}
|
||
|
|
||
|
void setQueryMetric(DistanceMetric<T> * qm)
|
||
|
{
|
||
|
m_queryMetric = qm;
|
||
|
}
|
||
|
|
||
|
virtual void insert(TreeNode<T>* nodePtr) = 0;
|
||
|
virtual TreeNode<T>* nearestNeighbor(const std::vector<T>& state, T &distNN) = 0;
|
||
|
virtual std::vector<TreeNode<T>*> nearK(const std::vector<T>& state, std::size_t num) = 0;
|
||
|
virtual std::vector<TreeNode<T>*> near(const std::vector<T>& state, T dist) = 0;
|
||
|
|
||
|
DistanceMetric<T> * getBuildMetric() const
|
||
|
{
|
||
|
return m_buildMetric;
|
||
|
}
|
||
|
|
||
|
DistanceMetric<T> * getQueryMetric() const
|
||
|
{
|
||
|
return m_queryMetric;
|
||
|
}
|
||
|
|
||
|
protected:
|
||
|
DistanceMetric<T>* m_buildMetric;
|
||
|
DistanceMetric<T>* m_queryMetric;
|
||
|
};
|
||
|
|
||
|
template <typename T>
|
||
|
class ExhaustiveNN : public NearestNeighborFinder<T>
|
||
|
{
|
||
|
public:
|
||
|
|
||
|
ExhaustiveNN(std::size_t dim) : NearestNeighborFinder<T>(dim) { /*ExhaustiveNN constructor*/ }
|
||
|
void insert(TreeNode<T>* nodePtr) override
|
||
|
{
|
||
|
m_nodePtrs.push_back(nodePtr);
|
||
|
m_currentNodeStates.insert(m_currentNodeStates.end(), nodePtr->getState().begin(),
|
||
|
nodePtr->getState().end());
|
||
|
}
|
||
|
|
||
|
struct NaNDistanceComparator
|
||
|
{
|
||
|
bool operator()(const T& lhs, const T& rhs)
|
||
|
{
|
||
|
// any finite, infinite, or NaN distance is less than NaN
|
||
|
return (lhs < rhs) || std::isnan(rhs);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// return the nearest node (as node pointer) to a given state
|
||
|
TreeNode<T>* nearestNeighbor(const std::vector<T>& state, T &distNN) override
|
||
|
{
|
||
|
std::vector<T> dists = this->m_queryMetric->distance(m_currentNodeStates, state);
|
||
|
auto minIter = std::min_element(dists.begin(), dists.end(), NaNDistanceComparator());
|
||
|
|
||
|
distNN = *minIter;
|
||
|
auto nodePtrIter = m_nodePtrs.begin() + std::distance(dists.begin(), minIter);
|
||
|
return *nodePtrIter;
|
||
|
}
|
||
|
|
||
|
/// return the nearest node (as node pointer) to a given state, alternative implementation
|
||
|
TreeNode<T>* nearestNeighbor_alternative(const std::vector<T>& state, T &distNN)
|
||
|
{
|
||
|
T d = static_cast<T>(0.0);
|
||
|
T dNN = std::numeric_limits<T>::max();
|
||
|
TreeNode<T>* out = nullptr;
|
||
|
for (auto&& nodePtr : m_nodePtrs)
|
||
|
{
|
||
|
std::vector<T> dists = this->m_queryMetric->distance(nodePtr->getState(), state);
|
||
|
d = dists[0];
|
||
|
if (d < dNN)
|
||
|
{
|
||
|
dNN = d;
|
||
|
out = nodePtr;
|
||
|
}
|
||
|
}
|
||
|
distNN = dNN;
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
|
||
|
/// return the nearest K nodes (as node pointers) around the given state
|
||
|
std::vector<TreeNode<T>*> nearK(const std::vector<T>& state, std::size_t num) override
|
||
|
{
|
||
|
std::vector<TreeNode<T>*> out;
|
||
|
|
||
|
//TODO
|
||
|
auto comparator = [this, &state](TreeNode<T> * node1, TreeNode<T> * node2)->bool
|
||
|
{
|
||
|
std::vector<T> d1 = this->m_queryMetric->distance(state, node1->getState());
|
||
|
std::vector<T> d2 = this->m_queryMetric->distance(state, node2->getState());
|
||
|
return d1[0] > d2[0];
|
||
|
};
|
||
|
|
||
|
std::priority_queue<TreeNode<T> *, std::vector<TreeNode<T> *>, decltype(comparator) > pq(comparator);
|
||
|
for (auto&& nodePtr : m_nodePtrs)
|
||
|
{
|
||
|
pq.push(nodePtr);
|
||
|
}
|
||
|
|
||
|
for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++)
|
||
|
{
|
||
|
out.push_back(pq.top());
|
||
|
pq.pop();
|
||
|
}
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
/// return the list of nodes (as node pointers) within a radius of a given state
|
||
|
std::vector<TreeNode<T>*> near(const std::vector<T>& state, T radius) override
|
||
|
{
|
||
|
// compute distance to all the existing nodes
|
||
|
std::vector<T> dists = this->m_queryMetric->distance(m_currentNodeStates, state );
|
||
|
std::size_t count = 0;
|
||
|
std::vector<std::pair<T, TreeNode<T>*>> distPairs{};
|
||
|
|
||
|
// convert to vector of (distance, node_pointer) pairs
|
||
|
|
||
|
//TODO
|
||
|
std::transform(dists.begin(), dists.end(), std::back_inserter(distPairs),
|
||
|
[this, &count](T d) -> std::pair<T, TreeNode<T>*> {
|
||
|
count++;
|
||
|
return std::make_pair(d, this->m_nodePtrs[count - 1]); });
|
||
|
|
||
|
// sort vector of pairs by distance (ascending)
|
||
|
std::sort(distPairs.begin(), distPairs.end(), nav::AscendingNodeFunctor<T>());
|
||
|
|
||
|
// retrieve those with distance smaller than radius
|
||
|
auto lbIt = std::lower_bound(distPairs.begin(), distPairs.end(), radius,
|
||
|
nav::DistanceRadiusFunctor<T>());
|
||
|
|
||
|
std::vector<TreeNode<T>*> out;
|
||
|
for (auto it = distPairs.begin(); it < lbIt; it++)
|
||
|
{
|
||
|
out.push_back((*it).second);
|
||
|
}
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
/// return the list of nodes (as node pointers) within a radius of a given state (alternative implementation)
|
||
|
std::vector<TreeNode<T>*> near_alternative(const std::vector<T>& state, T radius)
|
||
|
{
|
||
|
std::vector<TreeNode<T>*> out;
|
||
|
T defaultDist = static_cast<T>(-1);
|
||
|
//TODO
|
||
|
auto comparator = [this, &state](std::pair<TreeNode<T>*, T>& pair1, std::pair<TreeNode<T>*, T>& pair2)->bool
|
||
|
{
|
||
|
if (pair1.second < 0) // compute distance if it has not been done yet
|
||
|
{
|
||
|
std::vector<T> d1 = this->m_queryMetric->distance(pair1.first->getState(), state);
|
||
|
pair1.second = d1[0];
|
||
|
}
|
||
|
if (pair2.second < 0)
|
||
|
{
|
||
|
std::vector<T> d2 = this->m_queryMetric->distance(pair2.first->getState(), state);
|
||
|
pair2.second = d2[0];
|
||
|
}
|
||
|
return pair1.second > pair2.second;
|
||
|
};
|
||
|
|
||
|
// each pair contains the tree node pointer and the node's distance to given state
|
||
|
std::priority_queue<std::pair<TreeNode<T>*, T>, std::vector<std::pair<TreeNode<T>*, T>>, decltype(comparator) > pq(comparator);
|
||
|
for (auto&& nodePtr : m_nodePtrs)
|
||
|
{
|
||
|
pq.push(std::make_pair(nodePtr, defaultDist) );
|
||
|
}
|
||
|
|
||
|
while (!pq.empty() && pq.top().second < radius)
|
||
|
{
|
||
|
out.push_back(pq.top().first);
|
||
|
pq.pop();
|
||
|
}
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
/// find the ID of the nearest neighbor node, used for test only
|
||
|
std::size_t nearestNeighborIdx(const std::vector<T>& state, T& distNN)
|
||
|
{
|
||
|
T d{};
|
||
|
T dNN = std::numeric_limits<T>::max();
|
||
|
std::size_t idxNN = 0;
|
||
|
for (std::size_t i = 0; i < m_nodePtrs.size(); i++)
|
||
|
{
|
||
|
std::vector<T> dists = this->m_queryMetric->distance(m_nodePtrs[i]->getState(), state);
|
||
|
d = dists[0];
|
||
|
if (d < dNN)
|
||
|
{
|
||
|
dNN = d;
|
||
|
idxNN = i;
|
||
|
}
|
||
|
}
|
||
|
distNN = dNN;
|
||
|
return idxNN;
|
||
|
}
|
||
|
|
||
|
|
||
|
/// find the IDs of the near K neighboring nodes for a given state, used for test only
|
||
|
std::vector<std::size_t> nearKIdx(const std::vector<T>& state, const std::size_t& num)
|
||
|
{
|
||
|
std::vector<std::size_t> indices{};
|
||
|
|
||
|
//TODO
|
||
|
auto comparator = [this, &state](std::size_t id1, std::size_t id2)->bool
|
||
|
{
|
||
|
auto node1 = this->m_nodePtrs[id1];
|
||
|
auto node2 = this->m_nodePtrs[id2];
|
||
|
std::vector<T> d1 = this->m_queryMetric->distance(state, node1->getState());
|
||
|
std::vector<T> d2 = this->m_queryMetric->distance(state, node2->getState());
|
||
|
return d1[0] > d2[0];
|
||
|
};
|
||
|
|
||
|
std::priority_queue<std::size_t, std::vector<std::size_t>, decltype(comparator) > pq(comparator);
|
||
|
for (std::size_t i = 0; i < m_nodePtrs.size(); i++)
|
||
|
{
|
||
|
pq.push(i);
|
||
|
}
|
||
|
for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++)
|
||
|
{
|
||
|
indices.push_back(pq.top());
|
||
|
pq.pop();
|
||
|
}
|
||
|
return indices;
|
||
|
}
|
||
|
|
||
|
|
||
|
protected:
|
||
|
/// list of node pointers to facilitate nearest neighbor search
|
||
|
std::vector<TreeNode<T>*> m_nodePtrs;
|
||
|
|
||
|
/// states from all the nodes known by NN flattened into a 1-d vector
|
||
|
std::vector<T> m_currentNodeStates;
|
||
|
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
#endif
|