// Repository path: DYTSrouce/Tool/matlab/include/nav/planningcodegen_DistanceMetric.hpp
// Web-view metadata at capture time: 90 lines, 2.6 KiB, C++; last change 2024-11-22 15:19:31 +00:00.
// Copyright 2019-2021 The MathWorks, Inc.
/**
* @file
* @brief Distance metric between two states
*/
#ifndef PLANNINGCODEGEN_DISTANCEMETRIC_HPP
#define PLANNINGCODEGEN_DISTANCEMETRIC_HPP
#include <vector>
#include <iostream>
#include <cmath>
namespace nav
{
template <typename T>
class DistanceMetric
{
public:
DistanceMetric() {}
virtual ~DistanceMetric() {}
virtual std::vector<T> distance(const std::vector<T>& treeStates, const std::vector<T>& queryState) = 0;
protected:
/// convert row major matrix to column major matrix
static std::vector<real64_T> rowMajorToColumnMajor(const std::vector<real64_T>& states, std::size_t dim) {
std::size_t numStates = states.size() / dim;
std::vector<real64_T> statesReordered(states.size(), 0.0);
for (std::size_t i = 0; i < dim; i++) {
for (std::size_t j = 0; j < numStates; j++) {
statesReordered[j + i * numStates] = states[i + j * dim];
}
}
return statesReordered;
}
protected:
/// dimension of the state
std::size_t m_dim;
/// direction for computing distance
/**
* from treeStates to queryState [false, outward] or the other way around [true, inward]
*/
boolean_T m_reverseDirection;
};
template <typename T>
class EuclideanMetric : public DistanceMetric<T>
{
public:
    /// Construct a Euclidean metric for states with @p dim values each.
    EuclideanMetric(std::size_t dim)
    {
        this->m_dim = dim;
    }

    /**
     * @brief Euclidean distance from every state in @p states to @p queryState.
     * @param states     states packed contiguously, m_dim values per state
     *                   (state k occupies indices [k*m_dim, (k+1)*m_dim)); trailing
     *                   values that do not form a complete state are ignored
     * @param queryState a single state; its first m_dim values are used
     * @return one distance per complete state in @p states, in order. Returns an empty
     *         vector when m_dim is 0 or when @p queryState holds fewer than m_dim values
     *         (both previously caused undefined behavior).
     */
    std::vector<T> distance(const std::vector<T>& states, const std::vector<T>& queryState)
    {
        std::vector<T> dists;
        if (this->m_dim == 0 || queryState.size() < this->m_dim)
        {
            // Avoid division by zero and out-of-bounds reads of queryState.
            return dists;
        }
        const std::size_t numStates = states.size() / this->m_dim;
        dists.reserve(numStates);
        const typename std::vector<T>::const_iterator queryIt = queryState.begin();
        typename std::vector<T>::const_iterator statesIt = states.begin();
        for (std::size_t i = 0; i < numStates; ++i)
        {
            dists.push_back(distanceInternal(statesIt, queryIt));
            statesIt += this->m_dim;
        }
        return dists;
    }

protected:
    /**
     * @brief Euclidean distance between two m_dim-length states.
     * @param state1It iterator to the first value of the first state
     * @param state2It iterator to the first value of the second state
     * @return sqrt(sum_i (state2[i] - state1[i])^2) over the first m_dim values
     */
    T distanceInternal(const typename std::vector<T>::const_iterator& state1It, const typename std::vector<T>::const_iterator& state2It)
    {
        T sum = static_cast<T>(0.0);
        for (std::size_t i = 0; i < this->m_dim; ++i)
        {
            // Compute each difference once instead of twice.
            const T diff = *(state2It + i) - *(state1It + i);
            sum += diff * diff;
        }
        return std::sqrt(sum);
    }
};
}
#endif