DYT/Tool/matlab/include/nav/planningcodegen_SearchTree.hpp
2024-11-22 23:19:31 +08:00

457 lines
13 KiB
C++

// Copyright 2019-2021 The MathWorks, Inc.
/**
* @file
* @brief Search tree data structure to support RRT-like planning algorithms
*/
#ifndef PLANNINGCODEGEN_SEARCHTREE_HPP
#define PLANNINGCODEGEN_SEARCHTREE_HPP
#include <vector>
#include <stack>
#include <deque>
#include <queue>
#include <iostream>
#include <string>
#include <cstddef>
#include <limits>
#include "planningcodegen_TreeNode.hpp"
#include "planningcodegen_DistanceMetric.hpp"
#include "planningcodegen_NearestNeighbor.hpp"
namespace nav {
/// template class that represents the underlying tree data structure for RRT-like planning
/// algorithms
/**
* @tparam T Data type
*/
template <typename T>
class SearchTree {
public:
template <typename Derived>
class Iterator {
public:
Iterator(TreeNode<T>* ptr) {
m_ptr = ptr;
}
/// copy constructor
Iterator(const Iterator& it) {
m_ptr = it.m_ptr;
}
virtual ~Iterator() {
}
virtual Derived& operator++(int32_T x) = 0;
TreeNode<T>* operator->() {
return m_ptr;
}
TreeNode<T>& operator*() {
return *m_ptr;
}
TreeNode<T>*& operator()() {
return m_ptr;
}
virtual boolean_T isAtEnd() = 0;
protected:
TreeNode<T>* m_ptr;
};
/// Breadth-first iterator
class BFSIterator : public Iterator<BFSIterator> {
public:
BFSIterator(TreeNode<T>* ptr)
: Iterator<BFSIterator>(ptr) {
m_queue.push_front(ptr);
}
BFSIterator(const BFSIterator& bit)
: Iterator<BFSIterator>(bit) {
m_queue = bit.m_queue;
}
/// postfix ++ breadth-first search iterator
BFSIterator& operator++(int32_T x) override {
// Avoid unused parameter compiler warning
(void)x;
auto ptr = m_queue.front();
m_queue.pop_front();
for (auto& p : ptr->m_children) { //
m_queue.push_back(p);
}
if (!m_queue.empty()) {
this->m_ptr = m_queue.front();
}
return *this;
}
boolean_T isAtEnd() override {
return m_queue.empty();
}
protected:
std::deque<TreeNode<T>*> m_queue;
};
/// Depth-first iterator
class DFSIterator : public Iterator<DFSIterator> {
public:
DFSIterator(TreeNode<T>* ptr)
: Iterator<DFSIterator>(ptr) {
m_stack.push(ptr);
}
DFSIterator(const DFSIterator& dit)
: Iterator<DFSIterator>(dit) {
m_stack = dit.m_stack;
}
/// postfix ++ depth-first search iterator
DFSIterator& operator++(int32_T x) override {
// Avoid unused parameter compiler warning
(void)x;
auto ptr = m_stack.top();
m_stack.pop();
for (auto& p : ptr->m_children) { //
m_stack.push(p);
}
if (!m_stack.empty()) {
this->m_ptr = m_stack.top();
}
return *this;
}
boolean_T isAtEnd() override {
return m_stack.empty();
}
protected:
std::stack<TreeNode<T>*> m_stack;
};
protected:
std::size_t m_nodeDim;
TreeNode<T>* m_root;
NearestNeighborFinder<T>* m_nnFinder;
std::vector<TreeNode<T>*> m_nodePtrs;
T m_ballRadiusConstant;
T m_maxConnectionDistance;
public:
SearchTree(const std::vector<T>& state) {
m_nodeDim = state.size();
m_ballRadiusConstant = static_cast<T>(1);
m_maxConnectionDistance = static_cast<T>(0.1);
m_root = new TreeNode<T>(m_nodeDim);
m_root->setState(state);
m_nnFinder = new ExhaustiveNN<T>(m_nodeDim);
m_nnFinder->insert(m_root);
m_nodePtrs.push_back(m_root);
m_root->setNodeID(0);
}
/// destructor
~SearchTree() {
std::stack<TreeNode<T>*> stackOfNodePtrs;
for (BFSIterator it(getRoot()); !it.isAtEnd(); it++) {
stackOfNodePtrs.push(it()); // i.e. &*it
}
while (!stackOfNodePtrs.empty()) {
delete stackOfNodePtrs.top();
stackOfNodePtrs.pop();
}
delete m_nnFinder;
}
void setBallRadiusConstant(T rc) {
m_ballRadiusConstant = rc;
}
void setMaxConnectionDistance(T dist) {
m_maxConnectionDistance = dist;
}
T getBallRadiusConstant() {
return m_ballRadiusConstant;
}
T getMaxConnectionDistance() {
return m_maxConnectionDistance;
}
TreeNode<T>* getRoot() {
return m_root;
}
TreeNode<T>* getNode(std::size_t idx) {
if (idx < m_nodePtrs.size()) {
return m_nodePtrs[idx];
} else {
return nullptr;
}
}
std::size_t getNodeDim() {
return m_nodeDim;
}
std::size_t getNumNodes() {
return m_nodePtrs.size();
}
NearestNeighborFinder<T>* getNNFinder() {
return m_nnFinder;
}
TreeNode<T>* insertNode(TreeNode<T>* parent, const std::vector<T>& state) {
std::vector<T> dist = m_nnFinder->m_queryMetric->distance(
parent->getState(), state); // compute distance first, if it throws exception, the rest
// will not be executed.
T costFromParent = dist[0];
auto pNode = new TreeNode<T>(m_nodeDim);
parent->addToChildren(pNode);
pNode->setState(state);
m_nodePtrs.push_back(pNode);
pNode->setNodeID(m_nodePtrs.size() - 1);
m_nnFinder->insert(pNode);
pNode->m_costFromParent = costFromParent;
pNode->m_costFromRoot = parent->m_costFromRoot + pNode->m_costFromParent;
return pNode;
}
boolean_T insertNodeByID(std::size_t idx, const std::vector<T>& state, std::size_t& idxNew) {
idxNew = 0;
if (idx < m_nodePtrs.size() && idx >= 0) {
auto parent = m_nodePtrs[idx];
if (state.size() != m_nodeDim) { // mismatched state size
return false;
}
std::vector<T> dist = m_nnFinder->m_queryMetric->distance(
parent->getState(), state); // compute distance first, if it throws exception, the
// rest will not be executed.
T costFromParent = dist[0];
auto pNode = new TreeNode<T>(m_nodeDim);
parent->addToChildren(pNode);
pNode->setState(state);
m_nodePtrs.push_back(pNode);
pNode->setNodeID(m_nodePtrs.size() - 1);
m_nnFinder->insert(pNode);
pNode->m_costFromParent = costFromParent;
pNode->m_costFromRoot = parent->m_costFromRoot + pNode->m_costFromParent;
idxNew = pNode->getNodeID();
return true;
} else {
return false;
}
}
boolean_T insertNodeByIDWithPrecomputedCost(std::size_t idx,
const std::vector<T>& state,
T precomputedCost,
std::size_t& idxNew) {
idxNew = 0;
if (idx < m_nodePtrs.size() && idx >= 0) {
auto parent = m_nodePtrs[idx];
T costFromParent = precomputedCost;
auto pNode = new TreeNode<T>(m_nodeDim);
parent->addToChildren(pNode);
pNode->setState(state);
m_nodePtrs.push_back(pNode);
pNode->setNodeID(m_nodePtrs.size() - 1);
m_nnFinder->insert(pNode);
pNode->m_costFromParent = costFromParent;
pNode->m_costFromRoot = parent->m_costFromRoot + pNode->m_costFromParent;
idxNew = pNode->getNodeID();
return true;
} else {
return false;
}
}
/// rewire one node and its descendants under a new parent node
int32_T rewireNodeByID(std::size_t nodeID,
std::size_t newParentNodeID,
T distanceBetweenNodes = -99) {
if (nodeID < m_nodePtrs.size() && newParentNodeID < m_nodePtrs.size()) {
auto pNode = m_nodePtrs[nodeID];
auto pNewParentNode = m_nodePtrs[newParentNodeID];
// if node's parent is newParentNode, do nothing
if (pNode->m_parent == pNewParentNode) {
return 2;
}
// if node is newParentNode
if (pNode == pNewParentNode) {
return 3;
}
// if newParentNode is in fact a descendant of node, abort
if (isDescendantOf(pNewParentNode, pNode)) {
return 4;
}
// normal case: node's parent is not newParentNode, and newParentNode is not a
// descendant of node.
pNode->m_parent->removeChild(pNode);
pNewParentNode->addToChildren(pNode);
// propagate cost changes downstream
T newDist = static_cast<T>(0.0);
if (distanceBetweenNodes < 0) {
newDist = computeDistanceBetweenNodes(pNewParentNode, pNode);
} else {
newDist = distanceBetweenNodes;
}
auto distDiff = pNewParentNode->m_costFromRoot + newDist - pNode->m_costFromRoot;
pNode->m_costFromParent = newDist;
for (BFSIterator it(pNode); !it.isAtEnd(); it++) {
it->m_costFromRoot += distDiff;
}
return 0; // success
} else {
return 1; // node ID out of range
}
}
/// check if node 1 is a descendant of node 2
boolean_T isDescendantOf(TreeNode<T>* pNode1, TreeNode<T>* pNode2) {
boolean_T result = false;
TreeNode<T>* ptr = pNode1;
if (ptr->m_parent == nullptr) { // if pNode1 is already the root
return false;
}
while (ptr != nullptr) {
if (ptr->m_parent == pNode2) {
result = true;
break;
}
ptr = ptr->m_parent;
}
return result;
}
/// trace back to root node from a given node id
std::vector<T> tracebackToRoot(std::size_t idx) {
std::vector<T> route;
std::vector<T> currState;
if (idx < m_nodePtrs.size() && idx >= 0) {
TreeNode<T>* currentNode = m_nodePtrs[idx];
currState = currentNode->getState();
route.insert(route.end(), currState.begin(), currState.end());
while (currentNode->m_parent) {
currentNode = currentNode->m_parent;
currState = currentNode->getState();
route.insert(route.end(), currState.begin(), currState.end());
}
}
return route;
}
T computeDistanceBetweenNodes(TreeNode<T>* node1, TreeNode<T>* node2) {
std::vector<T> dist =
m_nnFinder->m_queryMetric->distance(node1->getState(), node2->getState());
return dist[0];
}
std::size_t nearestNeighborID(const std::vector<T>& state, T& distNN) {
TreeNode<T>* nodePtr = m_nnFinder->nearestNeighbor(state, distNN);
return nodePtr->getNodeID();
}
std::vector<std::size_t> nearKNeighborIDs(const std::vector<T>& state, const std::size_t& num) {
std::vector<std::size_t> indices;
std::vector<TreeNode<T>*> nodePtrs = m_nnFinder->nearK(state, num);
for (auto&& pn : nodePtrs) {
indices.push_back(pn->getNodeID());
}
return indices;
}
/// Returns the IDs of the nodes within the closed ball of radius r centered at the given state.
/// r is computed adaptively.
std::vector<std::size_t> nearNeighborIDs(const std::vector<T>& state) {
std::vector<std::size_t> indices;
std::vector<TreeNode<T>*> nodePtrs = m_nnFinder->near(state, computeBallRadius());
for (auto&& pn : nodePtrs) {
indices.push_back(pn->getNodeID());
}
return indices;
}
/// compute ball radius
T computeBallRadius() {
std::size_t numNodes = m_nodePtrs.size();
T d = std::pow(m_ballRadiusConstant * std::log(numNodes) / numNodes,
static_cast<T>(1.0) / m_nodeDim);
T radius = std::fmin(d, m_maxConnectionDistance);
return radius;
}
std::vector<T> inspect() {
std::vector<T> output;
std::vector<T> nanState(m_nodeDim, std::numeric_limits<T>::quiet_NaN());
for (BFSIterator it(getRoot()); !it.isAtEnd(); it++) {
if (it->m_parent) {
auto p = it->m_parent->getState();
output.insert(output.end(), p.begin(), p.end());
}
auto curr = it->getState();
output.insert(output.end(), curr.begin(), curr.end());
output.insert(output.end(), nanState.begin(), nanState.end());
}
return output;
}
};
} // namespace nav
#endif