// Copyright 2019-2021 The MathWorks, Inc. /** * @file * @brief Utilities to find nearest neighbor in the tree for a given state */ #ifndef PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP #define PLANNINGCODEGEN_NEARESTNEIGHBOR_HPP #include #include #include #include // for std::pair #include "planningcodegen_TreeNode.hpp" #include "planningcodegen_DistanceMetric.hpp" #include "planningcodegen_CommonCSMetric.hpp" namespace nav { template class DistanceRadiusFunctor { public: DistanceRadiusFunctor() {} bool operator()(const std::pair*> &a, T val) const { return a.first < val; } }; template class AscendingNodeFunctor { public: AscendingNodeFunctor() {} bool operator()(const std::pair*> &a, const std::pair*> &b) { return a.first < b.first; } }; /// template class for nearest neighbor finding algorithms /** * @tparam T Data type */ template class NearestNeighborFinder { // tree is a friend template friend class SearchTree; public: /// NearestNeighborFinder constructor NearestNeighborFinder(std::size_t nodeDim) { m_buildMetric = new CommonCSMetric(nodeDim); m_queryMetric = new CommonCSMetric(nodeDim); } virtual ~NearestNeighborFinder() { purgeMetrics(); } void purgeMetrics() { if (m_queryMetric != nullptr) { delete m_queryMetric; m_queryMetric = nullptr; } if (m_buildMetric != nullptr) { delete m_buildMetric; m_buildMetric = nullptr; } } void setBuildMetric(DistanceMetric * bm) { m_buildMetric = bm; } void setQueryMetric(DistanceMetric * qm) { m_queryMetric = qm; } virtual void insert(TreeNode* nodePtr) = 0; virtual TreeNode* nearestNeighbor(const std::vector& state, T &distNN) = 0; virtual std::vector*> nearK(const std::vector& state, std::size_t num) = 0; virtual std::vector*> near(const std::vector& state, T dist) = 0; DistanceMetric * getBuildMetric() const { return m_buildMetric; } DistanceMetric * getQueryMetric() const { return m_queryMetric; } protected: DistanceMetric* m_buildMetric; DistanceMetric* m_queryMetric; }; template class ExhaustiveNN : public NearestNeighborFinder { public: ExhaustiveNN(std::size_t dim) : NearestNeighborFinder(dim) { /*ExhaustiveNN constructor*/ } void insert(TreeNode* nodePtr) override { m_nodePtrs.push_back(nodePtr); m_currentNodeStates.insert(m_currentNodeStates.end(), nodePtr->getState().begin(), nodePtr->getState().end()); } struct NaNDistanceComparator { bool operator()(const T& lhs, const T& rhs) { // any finite, infinite, or NaN distance is less than NaN return (lhs < rhs) || std::isnan(rhs); } }; /// return the nearest node (as node pointer) to a given state TreeNode* nearestNeighbor(const std::vector& state, T &distNN) override { std::vector dists = this->m_queryMetric->distance(m_currentNodeStates, state); auto minIter = std::min_element(dists.begin(), dists.end(), NaNDistanceComparator()); distNN = *minIter; auto nodePtrIter = m_nodePtrs.begin() + std::distance(dists.begin(), minIter); return *nodePtrIter; } /// return the nearest node (as node pointer) to a given state, alternative implementation TreeNode* nearestNeighbor_alternative(const std::vector& state, T &distNN) { T d = static_cast(0.0); T dNN = std::numeric_limits::max(); TreeNode* out = nullptr; for (auto&& nodePtr : m_nodePtrs) { std::vector dists = this->m_queryMetric->distance(nodePtr->getState(), state); d = dists[0]; if (d < dNN) { dNN = d; out = nodePtr; } } distNN = dNN; return out; } /// return the nearest K nodes (as node pointers) around the given state std::vector*> nearK(const std::vector& state, std::size_t num) override { std::vector*> out; //TODO auto comparator = [this, &state](TreeNode * node1, TreeNode * node2)->bool { std::vector d1 = this->m_queryMetric->distance(state, node1->getState()); std::vector d2 = this->m_queryMetric->distance(state, node2->getState()); return d1[0] > d2[0]; }; std::priority_queue *, std::vector *>, decltype(comparator) > pq(comparator); for (auto&& nodePtr : m_nodePtrs) { pq.push(nodePtr); } for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++) { out.push_back(pq.top()); pq.pop(); } return out; } /// return the list of nodes (as node pointers) within a radius of a given state std::vector*> near(const std::vector& state, T radius) override { // compute distance to all the existing nodes std::vector dists = this->m_queryMetric->distance(m_currentNodeStates, state ); std::size_t count = 0; std::vector*>> distPairs{}; // convert to vector of (distance, node_pointer) pairs //TODO std::transform(dists.begin(), dists.end(), std::back_inserter(distPairs), [this, &count](T d) -> std::pair*> { count++; return std::make_pair(d, this->m_nodePtrs[count - 1]); }); // sort vector of pairs by distance (ascending) std::sort(distPairs.begin(), distPairs.end(), nav::AscendingNodeFunctor()); // retrieve those with distance smaller than radius auto lbIt = std::lower_bound(distPairs.begin(), distPairs.end(), radius, nav::DistanceRadiusFunctor()); std::vector*> out; for (auto it = distPairs.begin(); it < lbIt; it++) { out.push_back((*it).second); } return out; } /// return the list of nodes (as node pointers) within a radius of a given state (alternative implementation) std::vector*> near_alternative(const std::vector& state, T radius) { std::vector*> out; T defaultDist = static_cast(-1); //TODO auto comparator = [this, &state](std::pair*, T>& pair1, std::pair*, T>& pair2)->bool { if (pair1.second < 0) // compute distance if it has not been done yet { std::vector d1 = this->m_queryMetric->distance(pair1.first->getState(), state); pair1.second = d1[0]; } if (pair2.second < 0) { std::vector d2 = this->m_queryMetric->distance(pair2.first->getState(), state); pair2.second = d2[0]; } return pair1.second > pair2.second; }; // each pair contains the tree node pointer and the node's distance to given state std::priority_queue*, T>, std::vector*, T>>, decltype(comparator) > pq(comparator); for (auto&& nodePtr : m_nodePtrs) { pq.push(std::make_pair(nodePtr, defaultDist) ); } while (!pq.empty() && pq.top().second < radius) { out.push_back(pq.top().first); pq.pop(); } return out; } /// find the ID of the nearest neighbor node, used for test only std::size_t nearestNeighborIdx(const std::vector& state, T& distNN) { T d{}; T dNN = std::numeric_limits::max(); std::size_t idxNN = 0; for (std::size_t i = 0; i < m_nodePtrs.size(); i++) { std::vector dists = this->m_queryMetric->distance(m_nodePtrs[i]->getState(), state); d = dists[0]; if (d < dNN) { dNN = d; idxNN = i; } } distNN = dNN; return idxNN; } /// find the IDs of the near K neighboring nodes for a given state, used for test only std::vector nearKIdx(const std::vector& state, const std::size_t& num) { std::vector indices{}; //TODO auto comparator = [this, &state](std::size_t id1, std::size_t id2)->bool { auto node1 = this->m_nodePtrs[id1]; auto node2 = this->m_nodePtrs[id2]; std::vector d1 = this->m_queryMetric->distance(state, node1->getState()); std::vector d2 = this->m_queryMetric->distance(state, node2->getState()); return d1[0] > d2[0]; }; std::priority_queue, decltype(comparator) > pq(comparator); for (std::size_t i = 0; i < m_nodePtrs.size(); i++) { pq.push(i); } for (std::size_t k = 0; k < std::min(num, m_nodePtrs.size()); k++) { indices.push_back(pq.top()); pq.pop(); } return indices; } protected: /// list of node pointers to facilitate nearest neighbor search std::vector*> m_nodePtrs; /// states from all the nodes known by NN flattened into a 1-d vector std::vector m_currentNodeStates; }; } #endif