DYTSrouce/Tool/matlab/include/nav/planningcodegen_NearestNeighbor.hpp
2024-11-22 23:19:31 +08:00

324 lines
11 KiB
C++

// Copyright 2019-2021 The MathWorks, Inc.
/**
* @file
* @brief Utilities to find nearest neighbor in the tree for a given state
*/
#ifndef PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP
#define PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP
#include <vector>
#include <queue>
#include <algorithm>
#include <utility> // for std::pair
#include "planningcodegen_TreeNode.hpp"
#include "planningcodegen_DistanceMetric.hpp"
#include "planningcodegen_CommonCSMetric.hpp"
namespace nav
{
template <typename T>
class DistanceRadiusFunctor {
public:
DistanceRadiusFunctor() {}
bool operator()(const std::pair<T, TreeNode<T>*> &a, T val) const {
return a.first < val;
}
};
template <typename T>
class AscendingNodeFunctor {
public:
AscendingNodeFunctor() {}
bool operator()(const std::pair<T, TreeNode<T>*> &a,
const std::pair<T, TreeNode<T>*> &b) {
return a.first < b.first;
}
};
/// template class for nearest neighbor finding algorithms
/**
* @tparam T Data type
*/
template <typename T>
class NearestNeighborFinder
{
// tree is a friend
template <typename>
friend class SearchTree;
public:
/// NearestNeighborFinder constructor
NearestNeighborFinder(std::size_t nodeDim)
{
m_buildMetric = new CommonCSMetric<T>(nodeDim);
m_queryMetric = new CommonCSMetric<T>(nodeDim);
}
virtual ~NearestNeighborFinder()
{
purgeMetrics();
}
void purgeMetrics()
{
if (m_queryMetric != nullptr)
{
delete m_queryMetric;
m_queryMetric = nullptr;
}
if (m_buildMetric != nullptr)
{
delete m_buildMetric;
m_buildMetric = nullptr;
}
}
void setBuildMetric(DistanceMetric<T> * bm)
{
m_buildMetric = bm;
}
void setQueryMetric(DistanceMetric<T> * qm)
{
m_queryMetric = qm;
}
virtual void insert(TreeNode<T>* nodePtr) = 0;
virtual TreeNode<T>* nearestNeighbor(const std::vector<T>& state, T &distNN) = 0;
virtual std::vector<TreeNode<T>*> nearK(const std::vector<T>& state, std::size_t num) = 0;
virtual std::vector<TreeNode<T>*> near(const std::vector<T>& state, T dist) = 0;
DistanceMetric<T> * getBuildMetric() const
{
return m_buildMetric;
}
DistanceMetric<T> * getQueryMetric() const
{
return m_queryMetric;
}
protected:
DistanceMetric<T>* m_buildMetric;
DistanceMetric<T>* m_queryMetric;
};
template <typename T>
class ExhaustiveNN : public NearestNeighborFinder<T>
{
public:
ExhaustiveNN(std::size_t dim) : NearestNeighborFinder<T>(dim) { /*ExhaustiveNN constructor*/ }
void insert(TreeNode<T>* nodePtr) override
{
m_nodePtrs.push_back(nodePtr);
m_currentNodeStates.insert(m_currentNodeStates.end(), nodePtr->getState().begin(),
nodePtr->getState().end());
}
struct NaNDistanceComparator
{
bool operator()(const T& lhs, const T& rhs)
{
// any finite, infinite, or NaN distance is less than NaN
return (lhs < rhs) || std::isnan(rhs);
}
};
/// return the nearest node (as node pointer) to a given state
TreeNode<T>* nearestNeighbor(const std::vector<T>& state, T &distNN) override
{
std::vector<T> dists = this->m_queryMetric->distance(m_currentNodeStates, state);
auto minIter = std::min_element(dists.begin(), dists.end(), NaNDistanceComparator());
distNN = *minIter;
auto nodePtrIter = m_nodePtrs.begin() + std::distance(dists.begin(), minIter);
return *nodePtrIter;
}
/// return the nearest node (as node pointer) to a given state, alternative implementation
TreeNode<T>* nearestNeighbor_alternative(const std::vector<T>& state, T &distNN)
{
T d = static_cast<T>(0.0);
T dNN = std::numeric_limits<T>::max();
TreeNode<T>* out = nullptr;
for (auto&& nodePtr : m_nodePtrs)
{
std::vector<T> dists = this->m_queryMetric->distance(nodePtr->getState(), state);
d = dists[0];
if (d < dNN)
{
dNN = d;
out = nodePtr;
}
}
distNN = dNN;
return out;
}
/// return the nearest K nodes (as node pointers) around the given state
std::vector<TreeNode<T>*> nearK(const std::vector<T>& state, std::size_t num) override
{
std::vector<TreeNode<T>*> out;
//TODO
auto comparator = [this, &state](TreeNode<T> * node1, TreeNode<T> * node2)->bool
{
std::vector<T> d1 = this->m_queryMetric->distance(state, node1->getState());
std::vector<T> d2 = this->m_queryMetric->distance(state, node2->getState());
return d1[0] > d2[0];
};
std::priority_queue<TreeNode<T> *, std::vector<TreeNode<T> *>, decltype(comparator) > pq(comparator);
for (auto&& nodePtr : m_nodePtrs)
{
pq.push(nodePtr);
}
for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++)
{
out.push_back(pq.top());
pq.pop();
}
return out;
}
/// return the list of nodes (as node pointers) within a radius of a given state
std::vector<TreeNode<T>*> near(const std::vector<T>& state, T radius) override
{
// compute distance to all the existing nodes
std::vector<T> dists = this->m_queryMetric->distance(m_currentNodeStates, state );
std::size_t count = 0;
std::vector<std::pair<T, TreeNode<T>*>> distPairs{};
// convert to vector of (distance, node_pointer) pairs
//TODO
std::transform(dists.begin(), dists.end(), std::back_inserter(distPairs),
[this, &count](T d) -> std::pair<T, TreeNode<T>*> {
count++;
return std::make_pair(d, this->m_nodePtrs[count - 1]); });
// sort vector of pairs by distance (ascending)
std::sort(distPairs.begin(), distPairs.end(), nav::AscendingNodeFunctor<T>());
// retrieve those with distance smaller than radius
auto lbIt = std::lower_bound(distPairs.begin(), distPairs.end(), radius,
nav::DistanceRadiusFunctor<T>());
std::vector<TreeNode<T>*> out;
for (auto it = distPairs.begin(); it < lbIt; it++)
{
out.push_back((*it).second);
}
return out;
}
/// return the list of nodes (as node pointers) within a radius of a given state (alternative implementation)
std::vector<TreeNode<T>*> near_alternative(const std::vector<T>& state, T radius)
{
std::vector<TreeNode<T>*> out;
T defaultDist = static_cast<T>(-1);
//TODO
auto comparator = [this, &state](std::pair<TreeNode<T>*, T>& pair1, std::pair<TreeNode<T>*, T>& pair2)->bool
{
if (pair1.second < 0) // compute distance if it has not been done yet
{
std::vector<T> d1 = this->m_queryMetric->distance(pair1.first->getState(), state);
pair1.second = d1[0];
}
if (pair2.second < 0)
{
std::vector<T> d2 = this->m_queryMetric->distance(pair2.first->getState(), state);
pair2.second = d2[0];
}
return pair1.second > pair2.second;
};
// each pair contains the tree node pointer and the node's distance to given state
std::priority_queue<std::pair<TreeNode<T>*, T>, std::vector<std::pair<TreeNode<T>*, T>>, decltype(comparator) > pq(comparator);
for (auto&& nodePtr : m_nodePtrs)
{
pq.push(std::make_pair(nodePtr, defaultDist) );
}
while (!pq.empty() && pq.top().second < radius)
{
out.push_back(pq.top().first);
pq.pop();
}
return out;
}
/// find the ID of the nearest neighbor node, used for test only
std::size_t nearestNeighborIdx(const std::vector<T>& state, T& distNN)
{
T d{};
T dNN = std::numeric_limits<T>::max();
std::size_t idxNN = 0;
for (std::size_t i = 0; i < m_nodePtrs.size(); i++)
{
std::vector<T> dists = this->m_queryMetric->distance(m_nodePtrs[i]->getState(), state);
d = dists[0];
if (d < dNN)
{
dNN = d;
idxNN = i;
}
}
distNN = dNN;
return idxNN;
}
/// find the IDs of the near K neighboring nodes for a given state, used for test only
std::vector<std::size_t> nearKIdx(const std::vector<T>& state, const std::size_t& num)
{
std::vector<std::size_t> indices{};
//TODO
auto comparator = [this, &state](std::size_t id1, std::size_t id2)->bool
{
auto node1 = this->m_nodePtrs[id1];
auto node2 = this->m_nodePtrs[id2];
std::vector<T> d1 = this->m_queryMetric->distance(state, node1->getState());
std::vector<T> d2 = this->m_queryMetric->distance(state, node2->getState());
return d1[0] > d2[0];
};
std::priority_queue<std::size_t, std::vector<std::size_t>, decltype(comparator) > pq(comparator);
for (std::size_t i = 0; i < m_nodePtrs.size(); i++)
{
pq.push(i);
}
for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++)
{
indices.push_back(pq.top());
pq.pop();
}
return indices;
}
protected:
/// list of node pointers to facilitate nearest neighbor search
std::vector<TreeNode<T>*> m_nodePtrs;
/// states from all the nodes known by NN flattened into a 1-d vector
std::vector<T> m_currentNodeStates;
};
}
#endif