10584 lines
474 KiB
C++
10584 lines
474 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright 2016-2020 Intel Corporation
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*******************************************************************************/
|
||
|
|
||
|
/// @file
|
||
|
/// C++ API
|
||
|
|
||
|
#ifndef DNNL_HPP
|
||
|
#define DNNL_HPP
|
||
|
|
||
|
#include "dnnl_config.h"
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
#include <algorithm>
|
||
|
#include <cstdlib>
|
||
|
#include <iterator>
|
||
|
#include <memory>
|
||
|
#include <string>
|
||
|
#include <vector>
|
||
|
#include <unordered_map>
|
||
|
|
||
|
#include "dnnl.h"
|
||
|
|
||
|
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||
|
#include "dnnl_threadpool_iface.hpp"
|
||
|
#endif
|
||
|
|
||
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
|
||
|
#include <CL/cl.h>
|
||
|
#endif
|
||
|
/// @endcond
|
||
|
|
||
|
// Exception support detection.
//
// __cpp_exceptions is the standard feature-test macro, see
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
// gcc < 5 does not define __cpp_exceptions but defines __EXCEPTIONS instead.
// The Microsoft C++ compiler has no option to disable exceptions, so a
// compiler that defines _MSC_VER (and is not clang-based) is always treated
// as having exceptions enabled.
//
// A user-provided definition of DNNL_ENABLE_EXCEPTIONS takes precedence over
// this detection.
#ifndef DNNL_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS
#define DNNL_ENABLE_EXCEPTIONS 1
#elif defined(_MSC_VER) && !defined(__clang__)
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif
|
||
|
|
||
|
// DNNL_TRAP() stops the program after an error has been reported while
// exceptions are disabled (see DNNL_THROW_ERROR).
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
// Portable fallback for compilers not recognized above: instead of refusing
// to compile the header, terminate abnormally via std::abort().
#include <cstdlib>
#define DNNL_TRAP() std::abort()
#endif
|
||
|
|
||
|
// DNNL_THROW_ERROR(status, msg) reports a failure.
//  - With exceptions enabled it throws `error(status, msg)`; `error` is
//    looked up at the expansion site (dnnl::error inside this header).
//  - Without exceptions it writes `msg` followed by a newline to stderr and
//    then invokes DNNL_TRAP(); `status` is not evaluated in this mode.
// The no-exceptions form is wrapped in do/while(0) so it behaves as a single
// statement (e.g. as the body of an unbraced `if`).
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#else
#include <cstdio>
// <cstdio> only guarantees the std:: names, hence std::fputs / std::fputc.
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        std::fputs(msg, stderr); \
        std::fputc('\n', stderr); \
        DNNL_TRAP(); \
    } while (0)
#endif
|
||
|
|
||
|
/// @addtogroup dnnl_api oneDNN API
|
||
|
/// @{
|
||
|
|
||
|
/// oneDNN namespace
|
||
|
namespace dnnl {
|
||
|
|
||
|
/// @addtogroup dnnl_api_utils Utilities
|
||
|
/// Utility types and definitions.
|
||
|
/// @{
|
||
|
|
||
|
/// oneDNN exception class.
|
||
|
///
|
||
|
/// This class captures the status returned by a failed C API function and
|
||
|
/// the error message from the call site.
|
||
|
struct error : public std::exception {
|
||
|
dnnl_status_t status;
|
||
|
const char *message;
|
||
|
|
||
|
/// Constructs an instance of an exception class.
|
||
|
///
|
||
|
/// @param status The error status returned by a C API function.
|
||
|
/// @param message The error message.
|
||
|
error(dnnl_status_t status, const char *message)
|
||
|
: status(status), message(message) {}
|
||
|
|
||
|
/// Returns the explanatory string.
|
||
|
const char *what() const noexcept override { return message; }
|
||
|
|
||
|
/// A convenience function for wrapping calls to C API functions. Checks
|
||
|
/// the return status and throws an dnnl::error in case of failure.
|
||
|
///
|
||
|
/// @param status The error status returned by a C API function.
|
||
|
/// @param message The error message.
|
||
|
static void wrap_c_api(dnnl_status_t status, const char *message) {
|
||
|
if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <typename T>
|
||
|
void validate_container_size(const T &v, const char *error_message,
|
||
|
int min_size = 1, int max_size = -1) {
|
||
|
const int size = (int)v.size();
|
||
|
if (size < min_size || (max_size >= 0 && size > max_size))
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
|
||
|
}
|
||
|
/// @endcond
|
||
|
|
||
|
/// Traits class that supplies the destructor for a oneDNN C API handle type.
///
/// The primary template is intentionally empty: every supported C API handle
/// type provides an explicit specialization with a static `destructor`.
template <class T>
struct handle_traits {};
|
||
|
|
||
|
/// oneDNN C API handle wrapper class.
|
||
|
///
|
||
|
/// This class is used as the base class for primitive (dnnl::primitive),
|
||
|
/// engine (dnnl::engine), and stream (dnnl::stream) classes, as well as
|
||
|
/// others. An object of the dnnl::handle class can be passed by value.
|
||
|
///
|
||
|
/// A handle can be weak, in which case it follows std::weak_ptr semantics.
|
||
|
/// Otherwise, it follows `std::shared_ptr` semantics.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The implementation stores oneDNN C API handles in a `std::shared_ptr`
|
||
|
/// with deleter set to a dummy function in the weak mode.
|
||
|
///
|
||
|
template <typename T, typename traits = handle_traits<T>>
|
||
|
struct handle {
|
||
|
private:
|
||
|
static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
|
||
|
std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
|
||
|
|
||
|
protected:
|
||
|
bool operator==(const T other) const { return other == data_.get(); }
|
||
|
bool operator!=(const T other) const { return !(*this == other); }
|
||
|
|
||
|
public:
|
||
|
/// Constructs an empty handle object.
|
||
|
///
|
||
|
/// @warning
|
||
|
/// Uninitialized object cannot be used in most library calls and is
|
||
|
/// equivalent to a null pointer. Any attempt to use its methods, or
|
||
|
/// passing it to the other library function, will cause an exception
|
||
|
/// to be thrown.
|
||
|
handle() = default;
|
||
|
|
||
|
/// Copy constructor.
|
||
|
handle(const handle<T, traits> &) = default;
|
||
|
/// Assignment operator.
|
||
|
handle<T, traits> &operator=(const handle<T, traits> &) = default;
|
||
|
/// Move constructor.
|
||
|
handle(handle<T, traits> &&) = default;
|
||
|
/// Move assignment operator.
|
||
|
handle<T, traits> &operator=(handle<T, traits> &&) = default;
|
||
|
|
||
|
/// Constructs a handle wrapper object from a C API handle.
|
||
|
///
|
||
|
/// @param t The C API handle to wrap.
|
||
|
/// @param weak A flag specifying whether to construct a weak wrapper;
|
||
|
/// defaults to @c false.
|
||
|
explicit handle(T t, bool weak = false) { reset(t, weak); }
|
||
|
|
||
|
/// Resets the handle wrapper objects to wrap a new C API handle.
|
||
|
///
|
||
|
/// @param t The new value of the C API handle.
|
||
|
/// @param weak A flag specifying whether the wrapper should be weak;
|
||
|
/// defaults to @c false.
|
||
|
void reset(T t, bool weak = false) {
|
||
|
data_.reset(t, weak ? &dummy_destructor : traits::destructor);
|
||
|
}
|
||
|
|
||
|
/// Returns the underlying C API handle.
|
||
|
///
|
||
|
/// @param allow_empty A flag signifying whether the method is allowed to
|
||
|
/// return an empty (null) object without throwing an exception.
|
||
|
/// @returns The underlying C API handle.
|
||
|
T get(bool allow_empty = false) const {
|
||
|
T result = data_.get();
|
||
|
if (allow_empty == false && result == nullptr)
|
||
|
DNNL_THROW_ERROR(
|
||
|
dnnl_invalid_arguments, "object is not initialized");
|
||
|
return result;
|
||
|
}
|
||
|
|
||
|
/// Converts a handle to the underlying C API handle type. Does not throw
|
||
|
/// and returns `nullptr` if the object is empty.
|
||
|
///
|
||
|
/// @returns The underlying C API handle.
|
||
|
explicit operator T() const { return get(true); }
|
||
|
|
||
|
/// Checks whether the object is empty.
|
||
|
///
|
||
|
/// @returns Whether the object is empty.
|
||
|
explicit operator bool() const { return get(true) != nullptr; }
|
||
|
|
||
|
/// Equality operator.
|
||
|
///
|
||
|
/// @param other Another handle wrapper.
|
||
|
/// @returns @c true if this and the other handle wrapper manage the same
|
||
|
/// underlying C API handle, and @c false otherwise. Empty handle
|
||
|
/// objects are considered to be equal.
|
||
|
bool operator==(const handle<T, traits> &other) const {
|
||
|
return other.data_.get() == data_.get();
|
||
|
}
|
||
|
|
||
|
/// Inequality operator.
|
||
|
///
|
||
|
/// @param other Another handle wrapper.
|
||
|
/// @returns @c true if this and the other handle wrapper manage different
|
||
|
/// underlying C API handles, and @c false otherwise. Empty handle
|
||
|
/// objects are considered to be equal.
|
||
|
bool operator!=(const handle &other) const { return !(*this == other); }
|
||
|
};
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_memory_t> {
|
||
|
static dnnl_status_t destructor(dnnl_memory_t p) {
|
||
|
return dnnl_memory_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_primitive_desc_t> {
|
||
|
static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
|
||
|
return dnnl_primitive_desc_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_primitive_t> {
|
||
|
static dnnl_status_t destructor(dnnl_primitive_t p) {
|
||
|
return dnnl_primitive_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_primitive_desc_iterator_t> {
|
||
|
static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
|
||
|
return dnnl_primitive_desc_iterator_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
/// @endcond
|
||
|
|
||
|
/// @} dnnl_api_utils
|
||
|
|
||
|
// Forward declarations of classes referenced before their definitions.
struct error;
struct memory;
struct primitive_desc;
struct stream;
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives Primitives
|
||
|
/// Compute primitives
|
||
|
/// @sa @ref dev_guide_basic_concepts
|
||
|
/// @{
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives_common Common
|
||
|
/// Common operations to create, destroy and inspect primitives
|
||
|
/// @{
|
||
|
|
||
|
/// Base class for all computational primitives.
|
||
|
struct primitive : public handle<dnnl_primitive_t> {
|
||
|
friend struct error;
|
||
|
friend struct stream;
|
||
|
|
||
|
/// Kinds of primitives supported by the library.
|
||
|
enum class kind {
|
||
|
/// Undefined primitive
|
||
|
undef = dnnl_undefined_primitive,
|
||
|
/// A reorder primitive.
|
||
|
reorder = dnnl_reorder,
|
||
|
/// A shuffle primitive.
|
||
|
shuffle = dnnl_shuffle,
|
||
|
/// A (out-of-place) tensor concatenation primitive.
|
||
|
concat = dnnl_concat,
|
||
|
/// A summation primitive.
|
||
|
sum = dnnl_sum,
|
||
|
/// A convolution primitive.
|
||
|
convolution = dnnl_convolution,
|
||
|
/// A deconvolution primitive.
|
||
|
deconvolution = dnnl_deconvolution,
|
||
|
/// An element-wise primitive.
|
||
|
eltwise = dnnl_eltwise,
|
||
|
/// A softmax primitive.
|
||
|
softmax = dnnl_softmax,
|
||
|
/// A pooling primitive.
|
||
|
pooling = dnnl_pooling,
|
||
|
/// An LRN primitive.
|
||
|
lrn = dnnl_lrn,
|
||
|
/// A batch normalization primitive.
|
||
|
batch_normalization = dnnl_batch_normalization,
|
||
|
/// A layer normalization primitive.
|
||
|
layer_normalization = dnnl_layer_normalization,
|
||
|
/// An inner product primitive.
|
||
|
inner_product = dnnl_inner_product,
|
||
|
/// A rnn primitive.
|
||
|
rnn = dnnl_rnn,
|
||
|
/// A binary primitive.
|
||
|
binary = dnnl_binary,
|
||
|
/// A logsoftmax primitive.
|
||
|
logsoftmax = dnnl_logsoftmax,
|
||
|
/// A matmul (matrix multiplication) primitive.
|
||
|
matmul = dnnl_matmul,
|
||
|
/// A resampling primitive.
|
||
|
resampling = dnnl_resampling,
|
||
|
};
|
||
|
|
||
|
using handle::handle;
|
||
|
|
||
|
/// Default constructor. Constructs an empty object.
|
||
|
primitive() = default;
|
||
|
|
||
|
/// Constructs a primitive from a C API primitive descriptor.
|
||
|
///
|
||
|
/// @param c_pd C API primitive descriptor.
|
||
|
primitive(const_dnnl_primitive_desc_t c_pd);
|
||
|
|
||
|
/// Constructs a primitive from a primitive descriptor.
|
||
|
///
|
||
|
/// @param pd Primitive descriptor.
|
||
|
primitive(const primitive_desc &pd);
|
||
|
|
||
|
/// Returns the C API primitive descriptor of the underlying C API
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @returns The underlying C API primitive descriptor.
|
||
|
inline const_dnnl_primitive_desc_t get_primitive_desc() const;
|
||
|
|
||
|
/// Returns the kind of the primitive.
|
||
|
///
|
||
|
/// @returns The primitive kind.
|
||
|
inline kind get_kind() const;
|
||
|
|
||
|
/// Executes computations specified by the primitive in a specified stream.
|
||
|
///
|
||
|
/// Arguments are passed via an arguments map containing <index,
|
||
|
/// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
|
||
|
/// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
|
||
|
/// matching the one returned by
|
||
|
/// primitive_desc::query_md(#query::exec_arg_md, index) unless using
|
||
|
/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
|
||
|
///
|
||
|
/// @param stream Stream object. The stream must belong to the same engine
|
||
|
/// as the primitive.
|
||
|
/// @param args Arguments map.
|
||
|
void execute(const stream &stream,
|
||
|
const std::unordered_map<int, memory> &args) const;
|
||
|
};
|
||
|
|
||
|
/// Converts primitive kind enum value from C++ API to C API type.
|
||
|
///
|
||
|
/// @param kind C++ API primitive kind enum value.
|
||
|
/// @returns Corresponding C API primitive kind enum value.
|
||
|
inline dnnl_primitive_kind_t convert_to_c(primitive::kind kind) {
|
||
|
return static_cast<dnnl_primitive_kind_t>(kind);
|
||
|
}
|
||
|
|
||
|
// Queries the C API primitive descriptor that the wrapped C API primitive was
// created from; a failing query is reported through error::wrap_c_api.
const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
    // Initialized so that an indeterminate value is never returned, even if
    // the error handler returns control (e.g. a trap resumed in a debugger).
    const_dnnl_primitive_desc_t pd = nullptr;
    error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
            "could not get a primitive descriptor from a primitive");
    return pd;
}
|
||
|
|
||
|
// Returns the kind of the primitive by querying the C API primitive
// descriptor it was created from; a failing query is reported through
// error::wrap_c_api.
dnnl::primitive::kind primitive::get_kind() const {
    const_dnnl_primitive_desc_t pd = get_primitive_desc();
    // The kind is obtained via the generic C API query because
    // get_primitive_desc() returns a C API handle, not a C++ wrapper.
    // Initialized so that an indeterminate value is never converted, even if
    // the error handler returns control.
    dnnl_primitive_kind_t kind = dnnl_undefined_primitive;
    error::wrap_c_api(dnnl_primitive_desc_query(pd, dnnl_query_primitive_kind,
                              0, static_cast<void *>(&kind)),
            "could not get a primitive kind from a primitive descriptor");
    return static_cast<dnnl::primitive::kind>(kind);
}
|
||
|
|
||
|
/// @} dnnl_api_primitives_common
|
||
|
|
||
|
/// @addtogroup dnnl_api_attributes
|
||
|
///
|
||
|
/// A container for parameters that extend primitives behavior.
|
||
|
///
|
||
|
/// Attributes can also contain Post-ops, which are computations executed
|
||
|
/// after the primitive.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_attributes
|
||
|
/// @sa @ref dev_guide_attributes_post_ops
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Scratchpad mode
|
||
|
enum class scratchpad_mode {
|
||
|
/// The library manages the scratchpad allocation according to the policy
|
||
|
/// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
|
||
|
/// [build option](@ref dev_guide_build_options) (default).
|
||
|
///
|
||
|
/// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
|
||
|
/// scratchpad is common to all primitives to reduce the memory footprint.
|
||
|
/// This configuration comes with limited thread-safety properties, namely
|
||
|
/// primitives can be created and executed in parallel but cannot migrate
|
||
|
/// between threads (in other words, each primitive should be executed in
|
||
|
/// the same thread it was created in).
|
||
|
///
|
||
|
/// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
|
||
|
/// private to each primitive. The memory footprint is larger than when
|
||
|
/// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
|
||
|
/// created and run concurrently (the same primitive cannot be run
|
||
|
/// concurrently from two different threads though).
|
||
|
library = dnnl_scratchpad_mode_library,
|
||
|
/// The user manages the scratchpad allocation by querying and providing
|
||
|
/// the scratchpad memory to primitives. This mode is thread-safe as long
|
||
|
/// as the scratchpad buffers are not used concurrently by two primitive
|
||
|
/// executions.
|
||
|
user = dnnl_scratchpad_mode_user,
|
||
|
};
|
||
|
|
||
|
/// Converts a scratchpad mode enum value from C++ API to C API type.
|
||
|
///
|
||
|
/// @param mode C++ API scratchpad mode enum value.
|
||
|
/// @returns Corresponding C API scratchpad mode enum value.
|
||
|
inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
|
||
|
return static_cast<dnnl_scratchpad_mode_t>(mode);
|
||
|
}
|
||
|
|
||
|
/// Propagation kind.
|
||
|
enum class prop_kind {
|
||
|
/// Undefined propagation kind.
|
||
|
undef = dnnl_prop_kind_undef,
|
||
|
/// Forward data propagation (training mode). In this mode, primitives
|
||
|
/// perform computations necessary for subsequent backward propagation.
|
||
|
forward_training = dnnl_forward_training,
|
||
|
/// Forward data propagation (inference mode). In this mode, primitives
|
||
|
/// perform only computations that are necessary for inference and omit
|
||
|
/// computations that are necessary only for backward propagation.
|
||
|
forward_inference = dnnl_forward_inference,
|
||
|
/// Forward data propagation,
|
||
|
/// alias for #dnnl::prop_kind::forward_inference.
|
||
|
forward_scoring = dnnl_forward_scoring,
|
||
|
/// Forward data propagation,
|
||
|
/// alias for #dnnl::prop_kind::forward_training.
|
||
|
forward = dnnl_forward,
|
||
|
/// Backward propagation (with respect to all parameters).
|
||
|
backward = dnnl_backward,
|
||
|
/// Backward data propagation.
|
||
|
backward_data = dnnl_backward_data,
|
||
|
/// Backward weights propagation.
|
||
|
backward_weights = dnnl_backward_weights,
|
||
|
/// Backward bias propagation.
|
||
|
backward_bias = dnnl_backward_bias
|
||
|
};
|
||
|
|
||
|
/// Converts propagation kind enum value from C++ API to C API type.
|
||
|
///
|
||
|
/// @param kind C++ API propagation kind enum value.
|
||
|
/// @returns Corresponding C API propagation kind enum value.
|
||
|
inline dnnl_prop_kind_t convert_to_c(prop_kind kind) {
|
||
|
return static_cast<dnnl_prop_kind_t>(kind);
|
||
|
}
|
||
|
|
||
|
/// Kinds of algorithms.
|
||
|
enum class algorithm {
|
||
|
/// Undefined algorithm
|
||
|
undef = dnnl_alg_kind_undef,
|
||
|
/// Convolution algorithm that is chosen to be either direct or Winograd
|
||
|
/// automatically
|
||
|
convolution_auto = dnnl_convolution_auto,
|
||
|
/// Direct convolution
|
||
|
convolution_direct = dnnl_convolution_direct,
|
||
|
/// Winograd convolution
|
||
|
convolution_winograd = dnnl_convolution_winograd,
|
||
|
/// Direct deconvolution
|
||
|
deconvolution_direct = dnnl_deconvolution_direct,
|
||
|
/// Winograd deconvolution
|
||
|
deconvolution_winograd = dnnl_deconvolution_winograd,
|
||
|
/// Elementwise: rectified linear unit (ReLU)
|
||
|
eltwise_relu = dnnl_eltwise_relu,
|
||
|
/// Elementwise: hyperbolic tangent non-linearity (tanh)
|
||
|
eltwise_tanh = dnnl_eltwise_tanh,
|
||
|
/// Elementwise: exponential linear unit (ELU)
|
||
|
eltwise_elu = dnnl_eltwise_elu,
|
||
|
/// Elementwise: square
|
||
|
eltwise_square = dnnl_eltwise_square,
|
||
|
/// Elementwise: abs
|
||
|
eltwise_abs = dnnl_eltwise_abs,
|
||
|
/// Elementwise: square root
|
||
|
eltwise_sqrt = dnnl_eltwise_sqrt,
|
||
|
/// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
|
||
|
eltwise_swish = dnnl_eltwise_swish,
|
||
|
/// Elementwise: linear
|
||
|
eltwise_linear = dnnl_eltwise_linear,
|
||
|
/// Elementwise: bounded_relu
|
||
|
eltwise_bounded_relu = dnnl_eltwise_bounded_relu,
|
||
|
/// Elementwise: soft_relu
|
||
|
eltwise_soft_relu = dnnl_eltwise_soft_relu,
|
||
|
/// Elementwise: logistic
|
||
|
eltwise_logistic = dnnl_eltwise_logistic,
|
||
|
/// Elementwise: exponent
|
||
|
eltwise_exp = dnnl_eltwise_exp,
|
||
|
/// Elementwise: gelu
|
||
|
/// alias for #dnnl::algorithm::eltwise_gelu_tanh
|
||
|
eltwise_gelu = dnnl_eltwise_gelu,
|
||
|
/// Elementwise: tanh-based gelu
|
||
|
eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
|
||
|
/// Elementwise: erf-based gelu
|
||
|
eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
|
||
|
/// Elementwise: natural logarithm
|
||
|
eltwise_log = dnnl_eltwise_log,
|
||
|
/// Elementwise: clip
|
||
|
eltwise_clip = dnnl_eltwise_clip,
|
||
|
/// Elementwise: pow
|
||
|
eltwise_pow = dnnl_eltwise_pow,
|
||
|
/// Elementwise: rectified linar unit (ReLU) (dst for backward)
|
||
|
eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
|
||
|
/// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
|
||
|
eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
|
||
|
/// Elementwise: exponential linear unit (ELU) (dst for backward)
|
||
|
eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
|
||
|
/// Elementwise: square root (dst for backward)
|
||
|
eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
|
||
|
/// Elementwise: logistic (dst for backward)
|
||
|
eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
|
||
|
/// Elementwise: exponent (dst for backward)
|
||
|
eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
|
||
|
/// Local response normalization (LRN) across multiple channels
|
||
|
lrn_across_channels = dnnl_lrn_across_channels,
|
||
|
/// LRN within a single channel
|
||
|
lrn_within_channel = dnnl_lrn_within_channel,
|
||
|
/// Max pooling
|
||
|
pooling_max = dnnl_pooling_max,
|
||
|
/// Average pooling exclude padding,
|
||
|
/// alias for #dnnl::algorithm::pooling_avg_include_padding
|
||
|
pooling_avg = dnnl_pooling_avg,
|
||
|
/// Average pooling include padding
|
||
|
pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
|
||
|
/// Average pooling exclude padding
|
||
|
pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
|
||
|
/// RNN cell
|
||
|
vanilla_rnn = dnnl_vanilla_rnn,
|
||
|
/// LSTM cell
|
||
|
vanilla_lstm = dnnl_vanilla_lstm,
|
||
|
/// GRU cell
|
||
|
vanilla_gru = dnnl_vanilla_gru,
|
||
|
/// GRU cell with linear before reset. Differs from the vanilla GRU
|
||
|
/// in how the new memory gate is calculated:
|
||
|
/// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
|
||
|
/// LRB GRU expects 4 bias tensors on input:
|
||
|
/// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
|
||
|
lbr_gru = dnnl_lbr_gru,
|
||
|
/// Binary add
|
||
|
binary_add = dnnl_binary_add,
|
||
|
/// Binary mul
|
||
|
binary_mul = dnnl_binary_mul,
|
||
|
/// Binary max
|
||
|
binary_max = dnnl_binary_max,
|
||
|
/// Binary min
|
||
|
binary_min = dnnl_binary_min,
|
||
|
/// Nearest Neighbor resampling method
|
||
|
resampling_nearest = dnnl_resampling_nearest,
|
||
|
/// Linear (Bilinear, Trilinear) resampling method
|
||
|
resampling_linear = dnnl_resampling_linear,
|
||
|
};
|
||
|
|
||
|
/// Converts algorithm kind enum value from C++ API to C API type.
|
||
|
/// @param algorithm C++ API algorithm kind enum value.
|
||
|
/// @returns Corresponding C API algorithm kind enum value.
|
||
|
inline dnnl_alg_kind_t convert_to_c(algorithm algorithm) {
|
||
|
return static_cast<dnnl_alg_kind_t>(algorithm);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_attributes
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives_common
|
||
|
/// @{
|
||
|
|
||
|
/// Flags for normalization primitives.
|
||
|
enum class normalization_flags : unsigned {
|
||
|
/// Use no normalization flags. If specified, the library computes mean and
|
||
|
/// variance on forward propagation for training and inference, outputs them
|
||
|
/// on forward propagation for training, and computes the respective
|
||
|
/// derivatives on backward propagation.
|
||
|
none = dnnl_normalization_flags_none,
|
||
|
|
||
|
/// Use global statistics. If specified, the library uses mean and
|
||
|
/// variance provided by the user as an input on forward propagation and
|
||
|
/// does not compute their derivatives on backward propagation. Otherwise,
|
||
|
/// the library computes mean and variance on forward propagation for
|
||
|
/// training and inference, outputs them on forward propagation for
|
||
|
/// training, and computes the respective derivatives on backward
|
||
|
/// propagation.
|
||
|
use_global_stats = dnnl_use_global_stats,
|
||
|
|
||
|
/// Use scale and shift parameters. If specified, the user is expected to
|
||
|
/// pass scale and shift as inputs on forward propagation. On backward
|
||
|
/// propagation of type #dnnl::prop_kind::backward, the library computes
|
||
|
/// their derivatives. If not specified, the scale and shift parameters
|
||
|
/// are not used by the library in any way.
|
||
|
use_scale_shift = dnnl_use_scaleshift,
|
||
|
|
||
|
/// Fuse normalization with ReLU. On training, normalization will require
|
||
|
/// the workspace to implement backward propagation. On inference, the
|
||
|
/// workspace is not required and behavior is the same as when normalization
|
||
|
/// is fused with ReLU using the post-ops API.
|
||
|
fuse_norm_relu = dnnl_fuse_norm_relu
|
||
|
};
|
||
|
|
||
|
/// Converts normalization flags enum value from C++ API to C API type.
|
||
|
/// @param flags C++ API normalization flags enum value.
|
||
|
/// @returns Corresponding C API normalization flags enum value.
|
||
|
inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
|
||
|
return static_cast<dnnl_normalization_flags_t>(flags);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_primitives_common
|
||
|
|
||
|
/// @addtogroup dnnl_api_rnn
|
||
|
/// @{
|
||
|
|
||
|
/// RNN cell flags.
|
||
|
enum class rnn_flags : unsigned {
|
||
|
/// Undefined RNN flags
|
||
|
undef = dnnl_rnn_flags_undef
|
||
|
};
|
||
|
|
||
|
/// Converts RNN cell flags enum value from C++ API to C API type.
|
||
|
/// @param flags C++ API RNN cell flags enum value.
|
||
|
/// @returns Corresponding C API RNN cell flags enum value.
|
||
|
inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
|
||
|
return static_cast<dnnl_rnn_flags_t>(flags);
|
||
|
}
|
||
|
|
||
|
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
|
||
|
inline enum_name operator|(enum_name lhs, enum_name rhs) { \
|
||
|
return static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name operator&(enum_name lhs, enum_name rhs) { \
|
||
|
return static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name operator^(enum_name lhs, enum_name rhs) { \
|
||
|
return static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
|
||
|
lhs = static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
|
||
|
return lhs; \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
|
||
|
lhs = static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
|
||
|
return lhs; \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
|
||
|
lhs = static_cast<enum_name>( \
|
||
|
static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
|
||
|
return lhs; \
|
||
|
} \
|
||
|
\
|
||
|
inline enum_name operator~(enum_name rhs) { \
|
||
|
return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
|
||
|
}
|
||
|
|
||
|
DNNL_DEFINE_BITMASK_OPS(normalization_flags)
|
||
|
DNNL_DEFINE_BITMASK_OPS(rnn_flags)
|
||
|
|
||
|
/// A direction of RNN primitive execution
|
||
|
enum class rnn_direction {
|
||
|
/// Unidirectional execution of RNN primitive from left to right.
|
||
|
unidirectional_left2right = dnnl_unidirectional_left2right,
|
||
|
/// Unidirectional execution of RNN primitive from right to left.
|
||
|
unidirectional_right2left = dnnl_unidirectional_right2left,
|
||
|
/// Bidirectional execution of RNN primitive with concatenation of the
|
||
|
/// results.
|
||
|
bidirectional_concat = dnnl_bidirectional_concat,
|
||
|
/// Bidirectional execution of RNN primitive with summation of the
|
||
|
/// results.
|
||
|
bidirectional_sum = dnnl_bidirectional_sum,
|
||
|
/// Alias for #dnnl::rnn_direction::unidirectional_left2right
|
||
|
unidirectional = dnnl_unidirectional,
|
||
|
};
|
||
|
|
||
|
/// Converts RNN direction enum value from C++ API to C API type.
|
||
|
/// @param dir C++ API RNN direction enum value.
|
||
|
/// @returns Corresponding C API RNN direction enum value.
|
||
|
inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
|
||
|
return static_cast<dnnl_rnn_direction_t>(dir);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_rnn
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives_common
|
||
|
/// @{
|
||
|
|
||
|
/// Primitive descriptor query specification.
|
||
|
///
|
||
|
/// In general, queries are not used with the C++ API because most queries are
|
||
|
/// implemented as class members.
|
||
|
///
|
||
|
/// See @ref dnnl_query_t for more information.
|
||
|
enum class query {
|
||
|
/// no query
|
||
|
undef = dnnl_query_undef,
|
||
|
|
||
|
/// execution engine
|
||
|
engine = dnnl_query_engine,
|
||
|
/// primitive kind
|
||
|
primitive_kind = dnnl_query_primitive_kind,
|
||
|
|
||
|
/// number of inputs expected
|
||
|
num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
|
||
|
/// number of outputs expected
|
||
|
num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
|
||
|
|
||
|
/// runtime estimation (seconds), unimplemented
|
||
|
time_estimate_f64 = dnnl_query_time_estimate_f64,
|
||
|
/// memory consumption (bytes)
|
||
|
///
|
||
|
/// extra (scratch) memory, additional to all inputs and outputs memory
|
||
|
///
|
||
|
/// @sa @ref dev_guide_attributes_scratchpad
|
||
|
memory_consumption_s64 = dnnl_query_memory_consumption_s64,
|
||
|
|
||
|
/// scratchpad engine
|
||
|
///
|
||
|
/// engine to be used for creating scratchpad memory
|
||
|
scratchpad_engine = dnnl_query_scratchpad_engine,
|
||
|
|
||
|
/// reorder source engine
|
||
|
reorder_src_engine = dnnl_query_reorder_src_engine,
|
||
|
/// reorder destination engine
|
||
|
reorder_dst_engine = dnnl_query_reorder_dst_engine,
|
||
|
|
||
|
/// implementation name
|
||
|
impl_info_str = dnnl_query_impl_info_str,
|
||
|
|
||
|
/// propagation kind
|
||
|
prop_kind = dnnl_query_prop_kind,
|
||
|
|
||
|
/// operation descriptor
|
||
|
op_d = dnnl_query_op_d,
|
||
|
/// convolution descriptor
|
||
|
convolution_d = dnnl_query_convolution_d,
|
||
|
/// deconvolution descriptor
|
||
|
deconvolution_d = dnnl_query_deconvolution_d,
|
||
|
/// shuffle descriptor
|
||
|
shuffle_d = dnnl_query_shuffle_d,
|
||
|
/// eltwise descriptor
|
||
|
eltwise_d = dnnl_query_eltwise_d,
|
||
|
/// softmax descriptor
|
||
|
softmax_d = dnnl_query_softmax_d,
|
||
|
/// pooling descriptor
|
||
|
pooling_d = dnnl_query_pooling_d,
|
||
|
/// lrn descriptor
|
||
|
lrn_d = dnnl_query_lrn_d,
|
||
|
/// batch normalization descriptor
|
||
|
batch_normalization_d = dnnl_query_batch_normalization_d,
|
||
|
/// layer normalization descriptor
|
||
|
layer_normalization_d = dnnl_query_layer_normalization_d,
|
||
|
/// inner product descriptor
|
||
|
inner_product_d = dnnl_query_inner_product_d,
|
||
|
/// rnn descriptor
|
||
|
rnn_d = dnnl_query_rnn_d,
|
||
|
/// binary descriptor
|
||
|
binary_d = dnnl_query_binary_d,
|
||
|
/// logsoftmax descriptor
|
||
|
logsoftmax_d = dnnl_query_logsoftmax_d,
|
||
|
/// matmul descriptor
|
||
|
matmul_d = dnnl_query_matmul_d,
|
||
|
/// resampling descriptor
|
||
|
resampling_d = dnnl_query_resampling_d,
|
||
|
|
||
|
/// source memory desc
|
||
|
src_md = dnnl_query_src_md,
|
||
|
/// source gradient (diff) memory desc
|
||
|
diff_src_md = dnnl_query_diff_src_md,
|
||
|
/// weights memory desc
|
||
|
weights_md = dnnl_query_weights_md,
|
||
|
/// weights gradient (diff) memory desc
|
||
|
diff_weights_md = dnnl_query_diff_weights_md,
|
||
|
/// destination memory desc
|
||
|
dst_md = dnnl_query_dst_md,
|
||
|
/// destination gradient (diff) memory desc
|
||
|
diff_dst_md = dnnl_query_diff_dst_md,
|
||
|
/// workspace memory desc
|
||
|
workspace_md = dnnl_query_workspace_md,
|
||
|
/// scratchpad memory desc
|
||
|
scratchpad_md = dnnl_query_scratchpad_md,
|
||
|
/// memory desc of an execute argument
|
||
|
exec_arg_md = dnnl_query_exec_arg_md,
|
||
|
};
|
||
|
|
||
|
/// Converts query enum value from C++ API to C API type.
|
||
|
/// @param query C++ API query enum value.
|
||
|
/// @returns Corresponding C API query enum value.
|
||
|
inline dnnl_query_t convert_to_c(query query) {
|
||
|
return static_cast<dnnl_query_t>(query);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_primitives_common
|
||
|
|
||
|
/// @} dnnl_api_primitives
|
||
|
|
||
|
/// @addtogroup dnnl_api_engine Engine
|
||
|
///
|
||
|
/// An abstraction of a computational device: a CPU, a specific GPU
|
||
|
/// card in the system, etc. Most primitives are created to execute
|
||
|
/// computations on one specific engine. The only exceptions are reorder
|
||
|
/// primitives that transfer data between two different engines.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_basic_concepts
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_engine_t> {
|
||
|
static dnnl_status_t destructor(dnnl_engine_t p) {
|
||
|
return dnnl_engine_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
/// @endcond
|
||
|
|
||
|
/// An execution engine.
|
||
|
struct engine : public handle<dnnl_engine_t> {
|
||
|
friend struct primitive;
|
||
|
friend struct reorder;
|
||
|
|
||
|
/// Kinds of engines.
|
||
|
enum class kind {
|
||
|
/// An unspecified engine
|
||
|
any = dnnl_any_engine,
|
||
|
/// CPU engine
|
||
|
cpu = dnnl_cpu,
|
||
|
/// GPU engine
|
||
|
gpu = dnnl_gpu,
|
||
|
};
|
||
|
|
||
|
using handle::handle;
|
||
|
|
||
|
/// Constructs an empty engine. An empty engine cannot be used in any
|
||
|
/// operations.
|
||
|
engine() = default;
|
||
|
|
||
|
/// Returns the number of engines of a certain kind.
|
||
|
///
|
||
|
/// @param kind The kind of engines to count.
|
||
|
/// @returns The number of engines of the specified kind.
|
||
|
static size_t get_count(kind kind) {
|
||
|
return dnnl_engine_get_count(convert_to_c(kind));
|
||
|
}
|
||
|
|
||
|
/// Constructs an engine.
|
||
|
///
|
||
|
/// @param kind The kind of engine to construct.
|
||
|
/// @param index The index of the engine. Must be less than the value
|
||
|
/// returned by #get_count() for this particular kind of engine.
|
||
|
engine(kind kind, size_t index) {
|
||
|
dnnl_engine_t engine;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_engine_create(&engine, convert_to_c(kind), index),
|
||
|
"could not create an engine");
|
||
|
reset(engine);
|
||
|
}
|
||
|
|
||
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
|
||
|
/// Constructs an engine from OpenCL device and context objects.
|
||
|
///
|
||
|
/// @param kind The kind of engine to construct.
|
||
|
/// @param device The OpenCL device that this engine will encapsulate.
|
||
|
/// @param context The OpenCL context (containing the device) that this
|
||
|
/// engine will use for all operations.
|
||
|
engine(kind kind, cl_device_id device, cl_context context) {
|
||
|
dnnl_engine_t engine;
|
||
|
error::wrap_c_api(dnnl_engine_create_ocl(
|
||
|
&engine, convert_to_c(kind), device, context),
|
||
|
"could not create an engine");
|
||
|
reset(engine);
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
/// Constructs an engine based on a primitive from the primitive
|
||
|
/// descriptor @p pd by querying its engine.
|
||
|
///
|
||
|
/// @param pd The primitive descriptor to query.
|
||
|
engine(const handle<dnnl_primitive_desc_t> &pd) {
|
||
|
dnnl_engine_t c_engine;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_desc_query(pd.get(),
|
||
|
dnnl::convert_to_c(dnnl::query::engine), 0, &c_engine),
|
||
|
"could not get an engine from a primitive_desc");
|
||
|
reset(c_engine, true);
|
||
|
}
|
||
|
|
||
|
/// Returns the kind of the engine.
|
||
|
/// @returns The kind of the engine.
|
||
|
kind get_kind() const {
|
||
|
dnnl_engine_kind_t kind;
|
||
|
error::wrap_c_api(dnnl_engine_get_kind(get(), &kind),
|
||
|
"could not get kind of an engine");
|
||
|
return static_cast<engine::kind>(kind);
|
||
|
}
|
||
|
|
||
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
|
||
|
/// Returns the OpenCL context associated with the engine.
|
||
|
/// @returns OpenCL context.
|
||
|
cl_context get_ocl_context() const {
|
||
|
cl_context context = nullptr;
|
||
|
error::wrap_c_api(dnnl_engine_get_ocl_context(get(), &context),
|
||
|
"could not get an OpenCL context fron an engine");
|
||
|
return context;
|
||
|
}
|
||
|
|
||
|
/// Returns the OpenCL device associated with the engine.
|
||
|
/// @returns OpenCL device.
|
||
|
cl_device_id get_ocl_device() const {
|
||
|
cl_device_id device = nullptr;
|
||
|
error::wrap_c_api(dnnl_engine_get_ocl_device(get(), &device),
|
||
|
"could not get an OpenCL device fron an engine");
|
||
|
return device;
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
/// Returns the engine of a primitive descriptor.
|
||
|
///
|
||
|
/// @param pd The primitive descriptor to query.
|
||
|
/// @returns A weak handle to the engine that the primitive descriptor was
|
||
|
/// created with.
|
||
|
template <typename primitive_desc>
|
||
|
static engine query(const primitive_desc &pd) {
|
||
|
return query(pd, dnnl::query::engine);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
static dnnl_engine_kind_t convert_to_c(kind kind) {
|
||
|
return static_cast<dnnl_engine_kind_t>(kind);
|
||
|
}
|
||
|
|
||
|
template <typename primitive_desc>
|
||
|
static engine query(const primitive_desc &pd, dnnl::query what) {
|
||
|
dnnl_engine_t c_engine;
|
||
|
error::wrap_c_api(dnnl_primitive_desc_query(pd.get(),
|
||
|
dnnl::convert_to_c(what), 0, &c_engine),
|
||
|
"could not get an engine from a primitive_desc");
|
||
|
return engine(c_engine, true);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Converts engine kind enum value from C++ API to C API type.
|
||
|
///
|
||
|
/// @param kind C++ API engine kind enum value.
|
||
|
/// @returns Corresponding C API engine kind enum value.
|
||
|
inline dnnl_engine_kind_t convert_to_c(engine::kind kind) {
|
||
|
return static_cast<dnnl_engine_kind_t>(kind);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_engine
|
||
|
|
||
|
/// @addtogroup dnnl_api_stream Stream
|
||
|
///
|
||
|
/// An encapsulation of execution context tied to a particular engine.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_basic_concepts
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_stream_t> {
|
||
|
static dnnl_status_t destructor(dnnl_stream_t p) {
|
||
|
return dnnl_stream_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_stream_attr_t> {
|
||
|
static dnnl_status_t destructor(dnnl_stream_attr_t p) {
|
||
|
return dnnl_stream_attr_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
/// @endcond
|
||
|
|
||
|
/// A container for stream attributes.
|
||
|
struct stream_attr : public handle<dnnl_stream_attr_t> {
|
||
|
using handle::handle;
|
||
|
|
||
|
/// Constructs default (empty) stream attributes.
|
||
|
stream_attr() = default;
|
||
|
|
||
|
/// Constructs stream attributes for a stream that runs on an engine of a
|
||
|
/// particular kind.
|
||
|
///
|
||
|
/// @param kind Target engine kind.
|
||
|
stream_attr(engine::kind kind) {
|
||
|
dnnl_stream_attr_t attr;
|
||
|
error::wrap_c_api(dnnl_stream_attr_create(&attr, convert_to_c(kind)),
|
||
|
"could not create stream attributes");
|
||
|
reset(attr);
|
||
|
}
|
||
|
|
||
|
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||
|
/// Sets the threadpool attribute. Always throws unless oneDNN is built with
|
||
|
/// threadpool runtime.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_threadpool
|
||
|
///
|
||
|
/// @param threadpool A pointer to an instance of a class that implements
|
||
|
/// the dnnl::threadpool_iface interface.
|
||
|
void set_threadpool(threadpool_iface *threadpool) {
|
||
|
error::wrap_c_api(dnnl_stream_attr_set_threadpool(get(), threadpool),
|
||
|
"could not set stream threadpool attribute");
|
||
|
}
|
||
|
|
||
|
/// Returns the threadpool attribute. Always throws unless oneDNN is built
|
||
|
/// with threadpool runtime.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_threadpool
|
||
|
///
|
||
|
threadpool_iface *get_threadpool() {
|
||
|
threadpool_iface *tp;
|
||
|
error::wrap_c_api(dnnl_stream_attr_get_threadpool(get(), (void **)&tp),
|
||
|
"could not set stream threadpool attribute");
|
||
|
return tp;
|
||
|
}
|
||
|
#endif
|
||
|
};
|
||
|
|
||
|
/// An execution stream.
|
||
|
struct stream : public handle<dnnl_stream_t> {
|
||
|
using handle::handle;
|
||
|
|
||
|
/// Stream flags. Can be combined using the bitwise OR operator.
|
||
|
enum class flags : unsigned {
|
||
|
/// Default order execution. Either in-order or out-of-order depending
|
||
|
/// on the engine runtime.
|
||
|
default_order = dnnl_stream_default_order,
|
||
|
/// In-order execution.
|
||
|
in_order = dnnl_stream_default_order,
|
||
|
/// Out-of-order execution.
|
||
|
out_of_order = dnnl_stream_out_of_order,
|
||
|
/// Default stream configuration.
|
||
|
default_flags = dnnl_stream_default_flags,
|
||
|
};
|
||
|
|
||
|
/// Constructs an empty stream. An empty stream cannot be used in any
|
||
|
/// operations.
|
||
|
stream() = default;
|
||
|
|
||
|
/// Constructs a stream for the specified engine and with behavior
|
||
|
/// controlled by the specified flags.
|
||
|
///
|
||
|
/// @param engine Engine to create the stream on.
|
||
|
/// @param flags Flags controlling stream behavior.
|
||
|
/// @param attr Stream attributes.
|
||
|
stream(const engine &engine, flags flags = flags::default_flags,
|
||
|
const stream_attr &attr = stream_attr()) {
|
||
|
dnnl_stream_t stream;
|
||
|
error::wrap_c_api(dnnl_stream_create_v2(&stream, engine.get(),
|
||
|
static_cast<dnnl_stream_flags_t>(flags),
|
||
|
attr.get(true)),
|
||
|
"could not create a stream");
|
||
|
reset(stream);
|
||
|
}
|
||
|
|
||
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
|
||
|
/// Constructs a stream for the specified engine and the OpenCL queue.
|
||
|
///
|
||
|
/// @param engine Engine to create the stream on.
|
||
|
/// @param queue OpenCL queue to use for the stream.
|
||
|
stream(const engine &engine, cl_command_queue queue) {
|
||
|
dnnl_stream_t stream;
|
||
|
error::wrap_c_api(dnnl_stream_create_ocl(&stream, engine.get(), queue),
|
||
|
"could not create a stream");
|
||
|
reset(stream);
|
||
|
}
|
||
|
|
||
|
/// Returns the underlying OpenCL queue object.
|
||
|
/// @returns OpenCL queue.
|
||
|
cl_command_queue get_ocl_command_queue() const {
|
||
|
cl_command_queue queue = nullptr;
|
||
|
error::wrap_c_api(dnnl_stream_get_ocl_command_queue(get(), &queue),
|
||
|
"could not get an OpenCL command queue from a stream");
|
||
|
return queue;
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
/// Waits for all primitives executing in the stream to finish.
|
||
|
/// @returns The stream itself.
|
||
|
stream &wait() {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_stream_wait(get()), "could not wait on a stream");
|
||
|
return *this;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// Presumably generates the bitmask operators (such as bitwise OR) for
// stream::flags so that flag values can be combined; the macro itself is
// defined elsewhere in this header.
DNNL_DEFINE_BITMASK_OPS(stream::flags)
|
||
|
|
||
|
/// @} dnnl_api_stream
|
||
|
|
||
|
/// @addtogroup dnnl_api_memory Memory
|
||
|
///
|
||
|
/// A container that describes and stores data. Memory objects can contain
|
||
|
/// data of various types and formats. There are two levels of abstraction:
|
||
|
///
|
||
|
/// 1. **Memory descriptor** -- engine-agnostic logical description of data
|
||
|
/// (number of dimensions, dimension sizes, and data type), and,
|
||
|
/// optionally, the information about the physical format of data in
|
||
|
/// memory. If this information is not known yet, a memory descriptor can
|
||
|
/// be created with #dnnl::memory::format_tag::any. This allows
|
||
|
/// compute-intensive primitives to choose the best format for
|
||
|
/// computation. The user is responsible for reordering the data into the
|
||
|
/// chosen format when formats do not match.
|
||
|
///
|
||
|
/// A memory descriptor can be initialized either by specifying dimensions
|
||
|
/// and a memory format tag or strides for each of them, or by
|
||
|
/// manipulating the dnnl_memory_desc_t structure directly.
|
||
|
///
|
||
|
/// @warning
|
||
|
/// The latter approach requires understanding how the physical data
|
||
|
/// representation is mapped to the structure and is discouraged. This
|
||
|
/// topic is discussed in @ref dev_guide_understanding_memory_formats.
|
||
|
///
|
||
|
/// The user can query the amount of memory required by a memory
|
||
|
/// descriptor using the #dnnl::memory::desc::get_size() function. The
|
||
|
/// size of data in general cannot be computed as the product of
|
||
|
/// dimensions multiplied by the size of the data type. So users are
|
||
|
/// required to use this function for better code portability.
|
||
|
///
|
||
|
/// Two memory descriptors can be compared using the equality and
|
||
|
/// inequality operators. The comparison is especially useful when
|
||
|
/// checking whether it is necessary to reorder data from the user's data
|
||
|
/// format to a primitive's format.
|
||
|
///
|
||
|
/// 2. **Memory object** -- an engine-specific object that handles the data
|
||
|
/// and its description (a memory descriptor). For the CPU engine, the
|
||
|
/// data handle is simply a pointer to @c void. The data handle can be
|
||
|
/// queried using #dnnl::memory::get_data_handle() and set using
|
||
|
/// #dnnl::memory::set_data_handle(). A memory object can also be
|
||
|
/// queried for the underlying memory descriptor and for its engine using
|
||
|
/// #dnnl::memory::get_desc() and #dnnl::memory::get_engine().
|
||
|
///
|
||
|
/// Along with ordinary memory descriptors with all dimensions being positive,
|
||
|
/// the library supports *zero-volume* memory descriptors with one or more
|
||
|
/// dimensions set to zero. This is used to support the NumPy\* convention.
|
||
|
/// If a zero-volume memory is passed to a primitive, the primitive typically
|
||
|
/// does not perform any computations with this memory. For example:
|
||
|
///
|
||
|
/// - A concatenation primitive would ignore all memory objects with zeroes in
|
||
|
/// the concat dimension / axis.
|
||
|
///
|
||
|
/// - A forward convolution with a source memory object with zero in the
|
||
|
/// minibatch dimension would always produce a destination memory object
|
||
|
/// with a zero in the minibatch dimension and perform no computations.
|
||
|
///
|
||
|
/// - However, a forward convolution with a zero in one of the weights
|
||
|
/// dimensions is ill-defined and is considered to be an error by the
|
||
|
/// library because there is no clear definition of what the output values
|
||
|
/// should be.
|
||
|
///
|
||
|
/// Data handle of a zero-volume memory is never accessed.
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Memory object.
|
||
|
///
|
||
|
/// A memory object encapsulates a handle to a memory buffer allocated on a
|
||
|
/// specific engine, tensor dimensions, data type, and memory format, which is
|
||
|
/// the way tensor indices map to offsets in linear memory space. Memory
|
||
|
/// objects are passed to primitives during execution.
|
||
|
struct memory : public handle<dnnl_memory_t> {
|
||
|
/// Integer type for representing dimension sizes and indices.
|
||
|
typedef dnnl_dim_t dim;
|
||
|
/// Vector of dimensions. Implementations are free to force a limit on the
|
||
|
/// vector's length.
|
||
|
typedef std::vector<dim> dims;
|
||
|
|
||
|
/// Helper function that validates that an `std::vector` of dimensions can
|
||
|
/// be safely converted to the C API array ::dnnl_dims_t. Throws if
|
||
|
/// validation fails.
|
||
|
///
|
||
|
/// @param v Vector of dimensions.
|
||
|
/// @param min_size Minimum expected size of the vector.
|
||
|
template <typename T>
|
||
|
static void validate_dims(const std::vector<T> &v, int min_size = 0) {
|
||
|
validate_container_size(
|
||
|
v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
|
||
|
}
|
||
|
|
||
|
/// Data type specification.
|
||
|
enum class data_type {
|
||
|
/// Undefined data type (used for empty memory descriptors).
|
||
|
undef = dnnl_data_type_undef,
|
||
|
/// 16-bit/half-precision floating point.
|
||
|
f16 = dnnl_f16,
|
||
|
/// non-standard 16-bit floating point with 7-bit mantissa.
|
||
|
bf16 = dnnl_bf16,
|
||
|
/// 32-bit/single-precision floating point.
|
||
|
f32 = dnnl_f32,
|
||
|
/// 32-bit signed integer.
|
||
|
s32 = dnnl_s32,
|
||
|
/// 8-bit signed integer.
|
||
|
s8 = dnnl_s8,
|
||
|
/// 8-bit unsigned integer.
|
||
|
u8 = dnnl_u8,
|
||
|
};
|
||
|
|
||
|
/// Memory format kind
|
||
|
enum class format_kind {
|
||
|
/// Undefined memory format kind, used for empty memory descriptors.
|
||
|
undef = dnnl_format_kind_undef,
|
||
|
/// Unspecified format kind.
|
||
|
/// The primitive selects a format automatically.
|
||
|
any = dnnl_format_kind_any,
|
||
|
/// A tensor in a generic format described by the stride and blocking
|
||
|
/// values in each dimension. See @ref dnnl_blocking_desc_t for more
|
||
|
/// information.
|
||
|
blocked = dnnl_blocked,
|
||
|
/// Weights format used in 8bit Winograd convolution.
|
||
|
wino = dnnl_format_kind_wino,
|
||
|
/// Packed weights format used in RNN.
|
||
|
packed = dnnl_format_kind_rnn_packed,
|
||
|
};
|
||
|
|
||
|
/// Memory format tag specification.
|
||
|
///
|
||
|
/// Memory format tags can be further divided into two categories:
|
||
|
///
|
||
|
/// - Domain-agnostic names, i.e. names that do not depend on the tensor
|
||
|
/// usage in the specific primitive. These names use letters from `a`
|
||
|
/// to `l` to denote logical dimensions and form the order in which the
|
||
|
/// dimensions are laid in memory. For example,
|
||
|
/// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
|
||
|
/// second logical dimension (denoted as `b`) is the innermost, i.e.
|
||
|
/// has stride = 1, and the first logical dimension (`a`) is laid out in
|
||
|
/// memory with stride equal to the size of the second dimension. On the
|
||
|
/// other hand, #dnnl::memory::format_tag::ba is the transposed version
|
||
|
/// of the same tensor: the outermost dimension (`a`) becomes the
|
||
|
/// innermost one.
|
||
|
///
|
||
|
/// - Domain-specific names, i.e. names that make sense only in the
|
||
|
/// context of a certain domain, such as CNN. These names are
|
||
|
/// aliases to the corresponding domain-agnostic tags and used mostly
|
||
|
/// for convenience. For example, #dnnl::memory::format_tag::nc
|
||
|
/// is used to denote 2D CNN activations tensor memory format, where
|
||
|
/// the channels dimension is the innermost one and the batch dimension
|
||
|
/// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
|
||
|
/// an alias for #dnnl::memory::format_tag::ab, because for
|
||
|
/// CNN primitives the logical dimensions of activations tensors come
|
||
|
/// in order: batch, channels, spatial. In other words, batch
|
||
|
/// corresponds to the first logical dimension (`a`), and channels
|
||
|
/// correspond to the second one (`b`).
|
||
|
///
|
||
|
/// The following domain-specific notation applies to memory format tags:
|
||
|
/// - @c 'n' denotes the mini-batch dimension
|
||
|
/// - @c 'c' denotes a channels dimension
|
||
|
/// - When there are multiple channel dimensions (for example,
|
||
|
/// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
|
||
|
/// of input and output channels
|
||
|
/// - @c 'g' denotes a groups dimension for convolution weights
|
||
|
/// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
|
||
|
/// respectively
|
||
|
///
|
||
|
/// See @ref dnnl_format_tag_t for a detailed description.
|
||
|
enum class format_tag {
|
||
|
/// Undefined memory format tag
|
||
|
undef = dnnl_format_tag_undef,
|
||
|
/// Placeholder memory format tag. Used to instruct the primitive to
|
||
|
/// select a format automatically.
|
||
|
any = dnnl_format_tag_any,
|
||
|
|
||
|
/// plain 1D tensor
|
||
|
a = dnnl_a,
|
||
|
|
||
|
/// plain 2D tensor
|
||
|
ab = dnnl_ab,
|
||
|
/// permuted 2D tensor
|
||
|
ba = dnnl_ba,
|
||
|
|
||
|
/// plain 3D tensor
|
||
|
abc = dnnl_abc,
|
||
|
/// permuted 3D tensor
|
||
|
acb = dnnl_acb,
|
||
|
/// permuted 3D tensor
|
||
|
bac = dnnl_bac,
|
||
|
/// permuted 3D tensor
|
||
|
bca = dnnl_bca,
|
||
|
/// permuted 3D tensor
|
||
|
cba = dnnl_cba,
|
||
|
|
||
|
/// plain 4D tensor
|
||
|
abcd = dnnl_abcd,
|
||
|
/// permuted 4D tensor
|
||
|
abdc = dnnl_abdc,
|
||
|
/// permuted 4D tensor
|
||
|
acdb = dnnl_acdb,
|
||
|
/// permuted 4D tensor
|
||
|
bacd = dnnl_bacd,
|
||
|
/// permuted 4D tensor
|
||
|
bcda = dnnl_bcda,
|
||
|
/// permuted 4D tensor
|
||
|
cdba = dnnl_cdba,
|
||
|
/// permuted 4D tensor
|
||
|
dcab = dnnl_dcab,
|
||
|
|
||
|
/// plain 5D tensor
|
||
|
abcde = dnnl_abcde,
|
||
|
/// permuted 5D tensor
|
||
|
abdec = dnnl_abdec,
|
||
|
/// permuted 5D tensor
|
||
|
acbde = dnnl_acbde,
|
||
|
/// permuted 5D tensor
|
||
|
acdeb = dnnl_acdeb,
|
||
|
/// permuted 5D tensor
|
||
|
bcdea = dnnl_bcdea,
|
||
|
/// permuted 5D tensor
|
||
|
cdeba = dnnl_cdeba,
|
||
|
/// permuted 5D tensor
|
||
|
decab = dnnl_decab,
|
||
|
/// plain 6D tensor
|
||
|
abcdef = dnnl_abcdef,
|
||
|
/// permuted 6D tensor
|
||
|
acbdef = dnnl_acbdef,
|
||
|
/// permuted 6D tensor
|
||
|
defcab = dnnl_defcab,
|
||
|
|
||
|
/// 1D tensor; an alias for #dnnl::memory::format_tag::a
|
||
|
x = a,
|
||
|
/// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
|
||
|
nc = ab,
|
||
|
/// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
|
||
|
cn = ba,
|
||
|
/// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
|
||
|
tn = ab,
|
||
|
/// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
|
||
|
nt = ba,
|
||
|
/// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
|
||
|
ncw = abc,
|
||
|
/// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
|
||
|
nwc = acb,
|
||
|
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
|
||
|
nchw = abcd,
|
||
|
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
|
||
|
nhwc = acdb,
|
||
|
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
|
||
|
chwn = bcda,
|
||
|
/// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
|
||
|
ncdhw = abcde,
|
||
|
/// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
|
||
|
ndhwc = acdeb,
|
||
|
|
||
|
/// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
|
||
|
oi = ab,
|
||
|
/// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
|
||
|
io = ba,
|
||
|
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
|
||
|
oiw = abc,
|
||
|
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
|
||
|
owi = acb,
|
||
|
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
|
||
|
wio = cba,
|
||
|
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
|
||
|
iwo = bca,
|
||
|
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
|
||
|
oihw = abcd,
|
||
|
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
|
||
|
hwio = cdba,
|
||
|
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
|
||
|
ohwi = acdb,
|
||
|
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
|
||
|
ihwo = bcda,
|
||
|
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
|
||
|
iohw = bacd,
|
||
|
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
|
||
|
oidhw = abcde,
|
||
|
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
|
||
|
dhwio = cdeba,
|
||
|
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
|
||
|
odhwi = acdeb,
|
||
|
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
|
||
|
idhwo = bcdea,
|
||
|
|
||
|
/// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
|
||
|
goiw = abcd,
|
||
|
/// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
|
||
|
wigo = dcab,
|
||
|
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
|
||
|
goihw = abcde,
|
||
|
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
|
||
|
hwigo = decab,
|
||
|
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
|
||
|
giohw = acbde,
|
||
|
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
|
||
|
goidhw = abcdef,
|
||
|
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
|
||
|
giodhw = acbdef,
|
||
|
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
|
||
|
dhwigo = defcab,
|
||
|
|
||
|
/// 3D RNN data tensor in the format (seq_length, batch, input channels).
|
||
|
tnc = abc,
|
||
|
/// 3D RNN data tensor in the format (batch, seq_length, input channels).
|
||
|
ntc = bac,
|
||
|
/// 4D RNN states tensor in the format (num_layers, num_directions,
|
||
|
/// batch, state channels).
|
||
|
ldnc = abcd,
|
||
|
/// 5D RNN weights tensor in the format (num_layers, num_directions,
|
||
|
/// input_channels, num_gates, output_channels).
|
||
|
///
|
||
|
/// - For LSTM cells, the gates order is input, forget, candidate
|
||
|
/// and output gate.
|
||
|
/// - For GRU cells, the gates order is update, reset and output gate.
|
||
|
ldigo = abcde,
|
||
|
/// 5D RNN weights tensor in the format (num_layers, num_directions,
|
||
|
/// num_gates, output_channels, input_channels).
|
||
|
///
|
||
|
/// - For LSTM cells, the gates order is input, forget, candidate
|
||
|
/// and output gate.
|
||
|
/// - For GRU cells, the gates order is update, reset and output gate.
|
||
|
ldgoi = abdec,
|
||
|
/// 4D LSTM projection tensor in the format (num_layers, num_directions,
|
||
|
/// num_channels_in_hidden_state, num_channels_in_recurrent_projection).
|
||
|
ldio = abcd,
|
||
|
/// 4D LSTM projection tensor in the format (num_layers, num_directions,
|
||
|
/// num_channels_in_recurrent_projection, num_channels_in_hidden_state).
|
||
|
ldoi = abdc,
|
||
|
/// 4D RNN bias tensor in the format (num_layers, num_directions,
|
||
|
/// num_gates, output_channels).
|
||
|
///
|
||
|
/// - For LSTM cells, the gates order is input, forget, candidate
|
||
|
/// and output gate.
|
||
|
/// - For GRU cells, the gates order is update, reset and output gate.
|
||
|
ldgo = abcd,
|
||
|
|
||
|
// Opaque blocked formats
|
||
|
|
||
|
Abc16a = dnnl_Abc16a,
|
||
|
ABc16a16b = dnnl_ABc16a16b,
|
||
|
ABc4a4b = dnnl_ABc4a4b,
|
||
|
aBc16b = dnnl_aBc16b,
|
||
|
ABc16b16a = dnnl_ABc16b16a,
|
||
|
Abc4a = dnnl_Abc4a,
|
||
|
aBc4b = dnnl_aBc4b,
|
||
|
ABc4b16a4b = dnnl_ABc4b16a4b,
|
||
|
ABc2b8a4b = dnnl_ABc2b8a4b,
|
||
|
ABc4b4a = dnnl_ABc4b4a,
|
||
|
ABc8a16b2a = dnnl_ABc8a16b2a,
|
||
|
ABc8a8b = dnnl_ABc8a8b,
|
||
|
aBc8b = dnnl_aBc8b,
|
||
|
ABc8b16a2b = dnnl_ABc8b16a2b,
|
||
|
ABc8b8a = dnnl_ABc8b8a,
|
||
|
Abcd16a = dnnl_Abcd16a,
|
||
|
ABcd16a16b = dnnl_ABcd16a16b,
|
||
|
aBcd16b = dnnl_aBcd16b,
|
||
|
ABcd16b16a = dnnl_ABcd16b16a,
|
||
|
aBCd16b16c = dnnl_aBCd16b16c,
|
||
|
aBCd16c16b = dnnl_aBCd16c16b,
|
||
|
Abcd4a = dnnl_Abcd4a,
|
||
|
aBcd4b = dnnl_aBcd4b,
|
||
|
ABcd4b16a4b = dnnl_ABcd4b16a4b,
|
||
|
ABcd2b8a4b = dnnl_ABcd2b8a4b,
|
||
|
ABcd4b4a = dnnl_ABcd4b4a,
|
||
|
ABcd4a4b = dnnl_ABcd4a4b,
|
||
|
aBCd4c16b4c = dnnl_aBCd4c16b4c,
|
||
|
aBCd2c8b4c = dnnl_aBCd2c8b4c,
|
||
|
aBCd4c4b = dnnl_aBCd4c4b,
|
||
|
aBCd4b4c = dnnl_aBCd4b4c,
|
||
|
ABcd8a16b2a = dnnl_ABcd8a16b2a,
|
||
|
ABcd8a8b = dnnl_ABcd8a8b,
|
||
|
/// 4D tensor blocked by 2nd dimension with block size 8
|
||
|
aBcd8b = dnnl_aBcd8b,
|
||
|
ABcd8b16a2b = dnnl_ABcd8b16a2b,
|
||
|
aBCd8b16c2b = dnnl_aBCd8b16c2b,
|
||
|
/// 4D tensor blocked by 1st and 2nd dimension with block size 8
|
||
|
ABcd8b8a = dnnl_ABcd8b8a,
|
||
|
aBCd8b8c = dnnl_aBCd8b8c,
|
||
|
aBCd8c16b2c = dnnl_aBCd8c16b2c,
|
||
|
aBCd8c8b = dnnl_aBCd8c8b,
|
||
|
Abcde16a = dnnl_Abcde16a,
|
||
|
ABcde16a16b = dnnl_ABcde16a16b,
|
||
|
aBcde16b = dnnl_aBcde16b,
|
||
|
ABcde16b16a = dnnl_ABcde16b16a,
|
||
|
aBCde16b16c = dnnl_aBCde16b16c,
|
||
|
aBCde16c16b = dnnl_aBCde16c16b,
|
||
|
aBCde2c8b4c = dnnl_aBCde2c8b4c,
|
||
|
Abcde4a = dnnl_Abcde4a,
|
||
|
aBcde4b = dnnl_aBcde4b,
|
||
|
ABcde4b4a = dnnl_ABcde4b4a,
|
||
|
ABcde4a4b = dnnl_ABcde4a4b,
|
||
|
aBCde4b4c = dnnl_aBCde4b4c,
|
||
|
aBCde4c16b4c = dnnl_aBCde4c16b4c,
|
||
|
aBCde4c4b = dnnl_aBCde4c4b,
|
||
|
Abcde8a = dnnl_Abcde8a,
|
||
|
ABcde8a8b = dnnl_ABcde8a8b,
|
||
|
aBcde8b = dnnl_aBcde8b,
|
||
|
ABcde8b16a2b = dnnl_ABcde8b16a2b,
|
||
|
ABcde4b16a4b = dnnl_ABcde4b16a4b,
|
||
|
ABcde2b8a4b = dnnl_ABcde2b8a4b,
|
||
|
aBCde8b16c2b = dnnl_aBCde8b16c2b,
|
||
|
ABcde8b8a = dnnl_ABcde8b8a,
|
||
|
aBCde8b8c = dnnl_aBCde8b8c,
|
||
|
ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
|
||
|
ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
|
||
|
aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
|
||
|
aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
|
||
|
aBCde8c16b2c = dnnl_aBCde8c16b2c,
|
||
|
aBCde8c8b = dnnl_aBCde8c8b,
|
||
|
aBcdef16b = dnnl_aBcdef16b,
|
||
|
aBCdef16b16c = dnnl_aBCdef16b16c,
|
||
|
aBCdef16c16b = dnnl_aBCdef16c16b,
|
||
|
aBcdef4b = dnnl_aBcdef4b,
|
||
|
aBCdef4c4b = dnnl_aBCdef4c4b,
|
||
|
aBCdef4b4c = dnnl_aBCdef4b4c,
|
||
|
aBCdef8b8c = dnnl_aBCdef8b8c,
|
||
|
aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
|
||
|
aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
|
||
|
aBCdef8c8b = dnnl_aBCdef8c8b,
|
||
|
aBdc16b = dnnl_aBdc16b,
|
||
|
aBdc4b = dnnl_aBdc4b,
|
||
|
aBdc8b = dnnl_aBdc8b,
|
||
|
aBdec16b = dnnl_aBdec16b,
|
||
|
aBdec4b = dnnl_aBdec4b,
|
||
|
aBdec8b = dnnl_aBdec8b,
|
||
|
aBdefc16b = dnnl_aBdefc16b,
|
||
|
aCBdef16c16b = dnnl_aCBdef16c16b,
|
||
|
aCBdef16b16c = dnnl_aCBdef16b16c,
|
||
|
aBdefc4b = dnnl_aBdefc4b,
|
||
|
aBdefc8b = dnnl_aBdefc8b,
|
||
|
Acb16a = dnnl_Acb16a,
|
||
|
Acb4a = dnnl_Acb4a,
|
||
|
Acb8a = dnnl_Acb8a,
|
||
|
aCBd16b16c = dnnl_aCBd16b16c,
|
||
|
aCBd16c16b = dnnl_aCBd16c16b,
|
||
|
aCBde16b16c = dnnl_aCBde16b16c,
|
||
|
aCBde16c16b = dnnl_aCBde16c16b,
|
||
|
Acdb16a = dnnl_Acdb16a,
|
||
|
Acdb4a = dnnl_Acdb4a,
|
||
|
Acdb8a = dnnl_Acdb8a,
|
||
|
Acdeb16a = dnnl_Acdeb16a,
|
||
|
Acdeb4a = dnnl_Acdeb4a,
|
||
|
Acdeb8a = dnnl_Acdeb8a,
|
||
|
BAc16a16b = dnnl_BAc16a16b,
|
||
|
BAc16b16a = dnnl_BAc16b16a,
|
||
|
BAcd16a16b = dnnl_BAcd16a16b,
|
||
|
BAcd16b16a = dnnl_BAcd16b16a,
|
||
|
ABcd32a32b = dnnl_ABcd32a32b,
|
||
|
BAcde16b16a = dnnl_BAcde16b16a,
|
||
|
BAcde16a16b = dnnl_BAcde16a16b,
|
||
|
aBdec32b = dnnl_aBdec32b,
|
||
|
Abcdef16a = dnnl_Abcdef16a,
|
||
|
Acdb32a = dnnl_Acdb32a,
|
||
|
aBCd2b4c2b = dnnl_aBCd2b4c2b,
|
||
|
aBCde2b4c2b = dnnl_aBCde2b4c2b,
|
||
|
aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
|
||
|
aBCd2c4b2c = dnnl_aBCd2c4b2c,
|
||
|
aBCde2c4b2c = dnnl_aBCde2c4b2c,
|
||
|
aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
|
||
|
aBCd4b8c2b = dnnl_aBCd4b8c2b,
|
||
|
aBCde4b8c2b = dnnl_aBCde4b8c2b,
|
||
|
aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
|
||
|
aBCd4c8b2c = dnnl_aBCd4c8b2c,
|
||
|
aBCde4c8b2c = dnnl_aBCde4c8b2c,
|
||
|
aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
|
||
|
|
||
|
format_tag_last = dnnl_format_tag_last,
|
||
|
|
||
|
nCdhw16c = dnnl_nCdhw16c,
|
||
|
nCdhw4c = dnnl_nCdhw4c,
|
||
|
nCdhw8c = dnnl_nCdhw8c,
|
||
|
nChw16c = dnnl_nChw16c,
|
||
|
nChw4c = dnnl_nChw4c,
|
||
|
nChw8c = dnnl_nChw8c,
|
||
|
nCw16c = dnnl_nCw16c,
|
||
|
nCw4c = dnnl_nCw4c,
|
||
|
nCw8c = dnnl_nCw8c,
|
||
|
NCw16n16c = dnnl_NCw16n16c,
|
||
|
NChw16n16c = dnnl_NChw16n16c,
|
||
|
NCdhw16n16c = dnnl_NCdhw16n16c,
|
||
|
NChw32n32c = dnnl_NChw32n32c,
|
||
|
IOhw16i16o = dnnl_IOhw16i16o,
|
||
|
Ohwi32o = dnnl_Ohwi32o,
|
||
|
IOdhw16i16o = dnnl_IOdhw16i16o,
|
||
|
gIOhw16i16o = dnnl_gIOhw16i16o,
|
||
|
gOhwi32o = dnnl_gOhwi32o,
|
||
|
Goidhw16g = dnnl_Goidhw16g,
|
||
|
IOw16o16i = dnnl_IOw16o16i,
|
||
|
OIw16i16o = dnnl_OIw16i16o,
|
||
|
IOw16i16o = dnnl_IOw16i16o,
|
||
|
gIOw16i16o = dnnl_gIOw16i16o,
|
||
|
OIw16o16i = dnnl_OIw16o16i,
|
||
|
Oiw16o = dnnl_Oiw16o,
|
||
|
OIw4i16o4i = dnnl_OIw4i16o4i,
|
||
|
OIw2i8o4i = dnnl_OIw2i8o4i,
|
||
|
OIw4i4o = dnnl_OIw4i4o,
|
||
|
OIw4o4i = dnnl_OIw4o4i,
|
||
|
Oiw4o = dnnl_Oiw4o,
|
||
|
OIw8i16o2i = dnnl_OIw8i16o2i,
|
||
|
OIw8i8o = dnnl_OIw8i8o,
|
||
|
OIw8o16i2o = dnnl_OIw8o16i2o,
|
||
|
OIw8o8i = dnnl_OIw8o8i,
|
||
|
Owi16o = dnnl_Owi16o,
|
||
|
OwI16o2i = dnnl_OwI16o2i,
|
||
|
Owi4o = dnnl_Owi4o,
|
||
|
Owi8o = dnnl_Owi8o,
|
||
|
IOhw16o16i = dnnl_IOhw16o16i,
|
||
|
Ohwi16o = dnnl_Ohwi16o,
|
||
|
OhwI16o2i = dnnl_OhwI16o2i,
|
||
|
Ohwi4o = dnnl_Ohwi4o,
|
||
|
Ohwi8o = dnnl_Ohwi8o,
|
||
|
OIhw16i16o = dnnl_OIhw16i16o,
|
||
|
OIhw16o16i = dnnl_OIhw16o16i,
|
||
|
Oihw16o = dnnl_Oihw16o,
|
||
|
OIhw4i16o4i = dnnl_OIhw4i16o4i,
|
||
|
OIhw4i4o = dnnl_OIhw4i4o,
|
||
|
OIhw4o4i = dnnl_OIhw4o4i,
|
||
|
Oihw4o = dnnl_Oihw4o,
|
||
|
OIhw8i16o2i = dnnl_OIhw8i16o2i,
|
||
|
OIhw8i8o = dnnl_OIhw8i8o,
|
||
|
OIhw8o16i2o = dnnl_OIhw8o16i2o,
|
||
|
OIhw8o8i = dnnl_OIhw8o8i,
|
||
|
OIhw2i8o4i = dnnl_OIhw2i8o4i,
|
||
|
IOdhw16o16i = dnnl_IOdhw16o16i,
|
||
|
Odhwi16o = dnnl_Odhwi16o,
|
||
|
OdhwI16o2i = dnnl_OdhwI16o2i,
|
||
|
Odhwi4o = dnnl_Odhwi4o,
|
||
|
Odhwi8o = dnnl_Odhwi8o,
|
||
|
OIdhw16i16o = dnnl_OIdhw16i16o,
|
||
|
OIdhw16o16i = dnnl_OIdhw16o16i,
|
||
|
Oidhw16o = dnnl_Oidhw16o,
|
||
|
OIdhw4i4o = dnnl_OIdhw4i4o,
|
||
|
OIdhw4o4i = dnnl_OIdhw4o4i,
|
||
|
Oidhw4o = dnnl_Oidhw4o,
|
||
|
OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
|
||
|
OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
|
||
|
OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
|
||
|
OIdhw8i8o = dnnl_OIdhw8i8o,
|
||
|
OIdhw8o8i = dnnl_OIdhw8o8i,
|
||
|
gIOw16o16i = dnnl_gIOw16o16i,
|
||
|
gOIw16i16o = dnnl_gOIw16i16o,
|
||
|
gOIw16o16i = dnnl_gOIw16o16i,
|
||
|
gOiw16o = dnnl_gOiw16o,
|
||
|
gOIw4i16o4i = dnnl_gOIw4i16o4i,
|
||
|
gOIw2i8o4i = dnnl_gOIw2i8o4i,
|
||
|
gOIw4i4o = dnnl_gOIw4i4o,
|
||
|
gOIw4o4i = dnnl_gOIw4o4i,
|
||
|
gOiw4o = dnnl_gOiw4o,
|
||
|
gOIw8i16o2i = dnnl_gOIw8i16o2i,
|
||
|
gOIw8i8o = dnnl_gOIw8i8o,
|
||
|
gOIw8o16i2o = dnnl_gOIw8o16i2o,
|
||
|
gOIw8o8i = dnnl_gOIw8o8i,
|
||
|
gOwi16o = dnnl_gOwi16o,
|
||
|
gOwI16o2i = dnnl_gOwI16o2i,
|
||
|
gOwi4o = dnnl_gOwi4o,
|
||
|
gOwi8o = dnnl_gOwi8o,
|
||
|
Goiw8g = dnnl_Goiw8g,
|
||
|
Goiw16g = dnnl_Goiw16g,
|
||
|
gIOhw16o16i = dnnl_gIOhw16o16i,
|
||
|
gOhwi16o = dnnl_gOhwi16o,
|
||
|
gOhwI16o2i = dnnl_gOhwI16o2i,
|
||
|
gOhwi4o = dnnl_gOhwi4o,
|
||
|
gOhwi8o = dnnl_gOhwi8o,
|
||
|
Goihw16g = dnnl_Goihw16g,
|
||
|
gOIhw16i16o = dnnl_gOIhw16i16o,
|
||
|
gOIhw16o16i = dnnl_gOIhw16o16i,
|
||
|
gOihw16o = dnnl_gOihw16o,
|
||
|
gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
|
||
|
gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
|
||
|
gOIhw4i4o = dnnl_gOIhw4i4o,
|
||
|
gOIhw4o4i = dnnl_gOIhw4o4i,
|
||
|
gOihw4o = dnnl_gOihw4o,
|
||
|
Goihw8g = dnnl_Goihw8g,
|
||
|
gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
|
||
|
gOIhw8i8o = dnnl_gOIhw8i8o,
|
||
|
gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
|
||
|
OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
|
||
|
OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
|
||
|
gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
|
||
|
gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
|
||
|
gOIhw8o8i = dnnl_gOIhw8o8i,
|
||
|
gIOdhw16i16o = dnnl_gIOdhw16i16o,
|
||
|
gIOdhw16o16i = dnnl_gIOdhw16o16i,
|
||
|
gOdhwi16o = dnnl_gOdhwi16o,
|
||
|
gOdhwI16o2i = dnnl_gOdhwI16o2i,
|
||
|
gOdhwi4o = dnnl_gOdhwi4o,
|
||
|
gOdhwi8o = dnnl_gOdhwi8o,
|
||
|
gOIdhw16i16o = dnnl_gOIdhw16i16o,
|
||
|
gOIdhw16o16i = dnnl_gOIdhw16o16i,
|
||
|
gOidhw16o = dnnl_gOidhw16o,
|
||
|
gOIdhw4i4o = dnnl_gOIdhw4i4o,
|
||
|
gOIdhw4o4i = dnnl_gOIdhw4o4i,
|
||
|
gOidhw4o = dnnl_gOidhw4o,
|
||
|
gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
|
||
|
gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
|
||
|
gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
|
||
|
gOIdhw8i8o = dnnl_gOIdhw8i8o,
|
||
|
gOIdhw8o8i = dnnl_gOIdhw8o8i,
|
||
|
gOIw2i4o2i = dnnl_gOIw2i4o2i,
|
||
|
gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
|
||
|
gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
|
||
|
gOIw2o4i2o = dnnl_gOIw2o4i2o,
|
||
|
gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
|
||
|
gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
|
||
|
gOIw4i8o2i = dnnl_gOIw4i8o2i,
|
||
|
gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
|
||
|
gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
|
||
|
gOIw4o8i2o = dnnl_gOIw4o8i2o,
|
||
|
gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
|
||
|
gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
|
||
|
};
|
||
|
|
||
|
/// A memory descriptor.
|
||
|
struct desc {
|
||
|
friend struct memory;
|
||
|
/// The underlying C API data structure.
|
||
|
dnnl_memory_desc_t data;
|
||
|
|
||
|
/// Constructs a zero (empty) memory descriptor. Such a memory
|
||
|
/// descriptor can be used to indicate absence of an argument.
|
||
|
desc() : data() {}
|
||
|
|
||
|
/// Constructs a memory descriptor.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The logical order of dimensions corresponds to the `abc...`
|
||
|
/// format tag, and the physical meaning of the dimensions depends
|
||
|
/// both on the primitive that would operate on this memory and
|
||
|
/// the operation context.
|
||
|
///
|
||
|
/// @param dims Tensor dimensions.
|
||
|
/// @param data_type Data precision/type.
|
||
|
/// @param format_tag Memory format tag.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case a
|
||
|
/// zero memory descriptor will be constructed. This flag is
|
||
|
/// optional and defaults to false.
|
||
|
desc(const memory::dims &dims, data_type data_type,
|
||
|
format_tag format_tag, bool allow_empty = false)
|
||
|
: data() {
|
||
|
validate_dims(dims);
|
||
|
dnnl_status_t status = dnnl_memory_desc_init_by_tag(&data,
|
||
|
(int)dims.size(), dims.data(), convert_to_c(data_type),
|
||
|
convert_to_c(format_tag));
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(status,
|
||
|
"could not construct a memory descriptor using a "
|
||
|
"format tag");
|
||
|
}
|
||
|
|
||
|
/// Constructs a memory descriptor by strides.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The logical order of dimensions corresponds to the `abc...`
|
||
|
/// format tag, and the physical meaning of the dimensions depends
|
||
|
/// both on the primitive that would operate on this memory and
|
||
|
/// the operation context.
|
||
|
///
|
||
|
/// @param dims Tensor dimensions.
|
||
|
/// @param data_type Data precision/type.
|
||
|
/// @param strides Strides for each dimension.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case a
|
||
|
/// zero memory descriptor will be constructed. This flag is
|
||
|
/// optional and defaults to false.
|
||
|
desc(const memory::dims &dims, data_type data_type,
|
||
|
const memory::dims &strides, bool allow_empty = false)
|
||
|
: data() {
|
||
|
validate_dims(dims);
|
||
|
if (!strides.empty()) validate_dims(strides, (int)dims.size());
|
||
|
dnnl_status_t status = dnnl_memory_desc_init_by_strides(&data,
|
||
|
(int)dims.size(), dims.data(), convert_to_c(data_type),
|
||
|
strides.empty() ? nullptr : &strides[0]);
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(status,
|
||
|
"could not construct a memory descriptor using "
|
||
|
"strides");
|
||
|
}
|
||
|
|
||
|
/// Constructs a memory descriptor from a C API data structure.
|
||
|
///
|
||
|
/// @param data A C API ::dnnl_memory_desc_t structure.
|
||
|
desc(const dnnl_memory_desc_t &data) : data(data) {}
|
||
|
|
||
|
/// Constructs a memory descriptor for a region inside an area
|
||
|
/// described by this memory descriptor.
|
||
|
//
|
||
|
/// @param dims Sizes of the region.
|
||
|
/// @param offsets Offsets to the region from the encompassing
|
||
|
/// memory object in each dimension.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case a
|
||
|
/// zero memory descriptor will be returned. This flag is optional
|
||
|
/// and defaults to false.
|
||
|
/// @returns A memory descriptor for the region.
|
||
|
desc submemory_desc(const memory::dims &dims,
|
||
|
const memory::dims &offsets, bool allow_empty = false) const {
|
||
|
validate_dims(dims, data.ndims);
|
||
|
validate_dims(offsets, data.ndims);
|
||
|
dnnl_memory_desc_t sub_md = dnnl_memory_desc_t();
|
||
|
dnnl_status_t status = dnnl_memory_desc_init_submemory(
|
||
|
&sub_md, &data, dims.data(), offsets.data());
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(status, "could not construct a sub-memory");
|
||
|
return desc(sub_md);
|
||
|
}
|
||
|
|
||
|
/// Constructs a memory descriptor by reshaping an existing one. The
|
||
|
/// new memory descriptor inherits the data type. This operation is
|
||
|
/// valid only for memory descriptors that have format_kind set to
|
||
|
/// #dnnl::memory::format_kind::blocked or
|
||
|
/// #dnnl::memory::format_kind::any.
|
||
|
///
|
||
|
/// The operation ensures that the transformation of the physical memory
|
||
|
/// format corresponds to the transformation of the logical dimensions.
|
||
|
/// If such transformation is impossible, the function either throws an
|
||
|
/// exception (default) or returns a zero memory descriptor depending on
|
||
|
/// the `allow_empty` flag.
|
||
|
///
|
||
|
/// The reshape operation can be described as a combination of the
|
||
|
/// following basic operations:
|
||
|
/// 1. Add a dimension of size `1`. This is always possible.
|
||
|
/// 2. Remove a dimension of size `1`. This is possible only if the
|
||
|
/// dimension has no padding (i.e.
|
||
|
/// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
|
||
|
/// 3. Split a dimension into multiple ones. This is possible only if
|
||
|
/// the size of the dimension is exactly equal to the product of the
|
||
|
/// split ones and the dimension does not have padding (i.e.
|
||
|
/// `padded_dims[dim] = dims[dim]`).
|
||
|
/// 4. Joining multiple consecutive dimensions into a single one. As in
|
||
|
/// the cases above, this requires that the dimensions do not have
|
||
|
/// padding and that the memory format is such that in physical
|
||
|
/// memory these dimensions are dense and have the same order as
|
||
|
/// their logical counterparts. This also assumes that these
|
||
|
/// dimensions are not blocked.
|
||
|
/// - Here, dense means:
|
||
|
/// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
|
||
|
/// - And same order means:
|
||
|
/// `i < j <=> stride for dim[i] < stride for dim[j]`.
|
||
|
///
|
||
|
/// @warning
|
||
|
/// Some combinations of physical memory layout and/or offsets or
|
||
|
/// dimensions may result in a failure to make a reshape.
|
||
|
///
|
||
|
/// @param dims New dimensions. The product of dimensions must
|
||
|
/// remain constant.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case a
|
||
|
/// zero memory descriptor will be returned. This flag is optional
|
||
|
/// and defaults to false.
|
||
|
/// @returns A new memory descriptor with new dimensions.
|
||
|
desc reshape(const memory::dims &dims, bool allow_empty = false) const {
|
||
|
if (data.ndims) validate_dims(dims, 1);
|
||
|
dnnl_memory_desc_t out_md = dnnl_memory_desc_t();
|
||
|
dnnl_status_t status = dnnl_memory_desc_reshape(
|
||
|
&out_md, &data, (int)dims.size(), dims.data());
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(
|
||
|
status, "could not reshape a memory descriptor");
|
||
|
return desc(out_md);
|
||
|
}
|
||
|
|
||
|
/// Constructs a memory descriptor by permuting axes in an existing
|
||
|
/// one.
|
||
|
///
|
||
|
/// The physical memory layout representation is adjusted accordingly
|
||
|
/// to maintain the consistency between the logical and physical parts
|
||
|
/// of the memory descriptor.
|
||
|
///
|
||
|
/// The new memory descriptor inherits the data type. This operation is
|
||
|
/// valid only for memory descriptors that have format_kind set to
|
||
|
/// #dnnl::memory::format_kind::blocked or
|
||
|
/// #dnnl::memory::format_kind::any.
|
||
|
///
|
||
|
/// The logical axes will be permuted in the following manner:
|
||
|
/// ```
|
||
|
/// for (i: 0 .. ndims())
|
||
|
/// new_desc.dims()[permutation[i]] = dims()[i];
|
||
|
/// ```
|
||
|
///
|
||
|
/// Example:
|
||
|
/// @code
|
||
|
/// std::vector<int> permutation = {1, 0}; // swap the first and
|
||
|
/// // the second axes
|
||
|
/// dnnl::memory::desc in_md(
|
||
|
/// {2, 3}, data_type, memory::format_tag::ab);
|
||
|
/// dnnl::memory::desc expect_out_md(
|
||
|
/// {3, 2}, data_type, memory::format_tag::ba);
|
||
|
///
|
||
|
/// assert(in_md.permute_axes(permutation) == expect_out_md);
|
||
|
/// @endcode
|
||
|
///
|
||
|
/// @param permutation Axes permutation.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case a
|
||
|
/// zero memory descriptor will be returned. This flag is optional
|
||
|
/// and defaults to false.
|
||
|
/// @returns A new memory descriptor with new dimensions.
|
||
|
desc permute_axes(const std::vector<int> &permutation,
|
||
|
bool allow_empty = false) const {
|
||
|
validate_dims(permutation, data.ndims);
|
||
|
dnnl_memory_desc_t out_md = dnnl_memory_desc_t();
|
||
|
dnnl_status_t status = dnnl_memory_desc_permute_axes(
|
||
|
&out_md, &data, permutation.data());
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(status,
|
||
|
"could not permute axes of a memory descriptor");
|
||
|
return desc(out_md);
|
||
|
}
|
||
|
|
||
|
/// Returns dimensions of the memory descriptor.
|
||
|
///
|
||
|
/// Potentially expensive due to the data copy involved.
|
||
|
/// @returns A copy of the dimensions vector.
|
||
|
memory::dims dims() const {
|
||
|
return memory::dims(data.dims, data.dims + data.ndims);
|
||
|
}
|
||
|
|
||
|
/// Returns the data type of the memory descriptor.
|
||
|
/// @returns The data type.
|
||
|
memory::data_type data_type() const {
|
||
|
return static_cast<memory::data_type>(data.data_type);
|
||
|
}
|
||
|
|
||
|
/// Returns size of the memory descriptor in bytes.
|
||
|
/// @returns The number of bytes required to allocate a memory buffer
|
||
|
/// for the memory object described by this memory descriptor
|
||
|
/// including the padding area.
|
||
|
size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
|
||
|
|
||
|
/// Checks whether the memory descriptor is zero (empty).
|
||
|
/// @returns @c true if the memory descriptor describes an empty
|
||
|
/// memory and @c false otherwise.
|
||
|
bool is_zero() const { return data.ndims == 0; }
|
||
|
|
||
|
/// An equality operator.
|
||
|
/// @param other Another memory descriptor.
|
||
|
/// @returns Whether this and the other memory descriptors have
|
||
|
/// the same format tag, dimensions, strides, blocking, etc.
|
||
|
bool operator==(const desc &other) const {
|
||
|
return dnnl_memory_desc_equal(&data, &other.data) != 0;
|
||
|
}
|
||
|
|
||
|
/// An inequality operator.
|
||
|
/// @param other Another memory descriptor.
|
||
|
/// @returns Whether this and the other memory descriptors describe
|
||
|
/// different memory.
|
||
|
bool operator!=(const desc &other) const { return !operator==(other); }
|
||
|
};
|
||
|
|
||
|
// Default constructor.
|
||
|
//
|
||
|
// Constructs an empty memory object, which can be used to indicate absence
|
||
|
// of a parameter.
|
||
|
memory() = default;
|
||
|
|
||
|
/// Constructs a memory object.
|
||
|
///
|
||
|
/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
|
||
|
/// object will have the underlying buffer set. In this case, the buffer
|
||
|
/// will be initialized as if #dnnl::memory::set_data_handle() had been
|
||
|
/// called.
|
||
|
///
|
||
|
/// @sa memory::set_data_handle()
|
||
|
///
|
||
|
/// @param md Memory descriptor.
|
||
|
/// @param engine Engine to store the data on.
|
||
|
/// @param handle Handle of the memory buffer to use as an underlying
|
||
|
/// storage.
|
||
|
/// - A pointer to the user-allocated buffer. In this case the library
|
||
|
/// doesn't own the buffer.
|
||
|
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
||
|
/// allocate the buffer for the memory object. In this case the
|
||
|
/// library owns the buffer.
|
||
|
/// - DNNL_MEMORY_NONE to create dnnl_memory without an underlying
|
||
|
/// buffer.
|
||
|
memory(const desc &md, const engine &engine, void *handle) {
|
||
|
dnnl_memory_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_memory_create(&result, &md.data, engine.get(), handle),
|
||
|
"could not create a memory object");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a memory object.
|
||
|
///
|
||
|
/// The underlying storage for the memory will be allocated by the library.
|
||
|
///
|
||
|
/// @param md Memory descriptor.
|
||
|
/// @param engine Engine to store the data on.
|
||
|
memory(const desc &md, const engine &engine)
|
||
|
: memory(md, engine, DNNL_MEMORY_ALLOCATE) {}
|
||
|
|
||
|
/// Returns the associated memory descriptor.
|
||
|
desc get_desc() const {
|
||
|
const dnnl_memory_desc_t *cdesc;
|
||
|
error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
|
||
|
"could not get a memory descriptor from a memory object");
|
||
|
return desc(*cdesc);
|
||
|
}
|
||
|
|
||
|
/// Returns the associated engine.
|
||
|
engine get_engine() const {
|
||
|
dnnl_engine_t c_engine;
|
||
|
error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
|
||
|
"could not get an engine from a memory object");
|
||
|
return engine(c_engine, true);
|
||
|
}
|
||
|
|
||
|
/// Returns the underlying memory buffer.
|
||
|
///
|
||
|
/// On the CPU engine this is a pointer to the allocated memory.
|
||
|
void *get_data_handle() const {
|
||
|
void *handle;
|
||
|
error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle),
|
||
|
"could not get a native handle from a memory object");
|
||
|
return handle;
|
||
|
}
|
||
|
|
||
|
/// Sets data handle.
|
||
|
///
|
||
|
/// This function may write zero values to the memory specified by the @p
|
||
|
/// handle if the memory object has a zero padding area. This may be time
|
||
|
/// consuming and happens each time this function is called. The
|
||
|
/// operation is always blocking and the stream parameter is a hint.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The zero padding is required by memory objects created with
|
||
|
/// blocked memory format tags like #dnnl_aBcd8b when any of the
|
||
|
/// dimensions is not a multiple of the corresponding block size. For
|
||
|
/// "plain" formats like #dnnl::memory::format_tag::nchw or
|
||
|
/// #dnnl::memory::format_tag::nhwc zero padding area needs to be set
|
||
|
/// up explicitly when creating the corresponding memory descriptors.
|
||
|
/// See @ref dev_guide_understanding_memory_formats for more details.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Even when the memory object is used to hold values that stay
|
||
|
/// constant during the execution of the program (pre-packed weights
|
||
|
/// during inference, for example), the function will still write
|
||
|
/// zeroes to the padding area if it exists. Hence, the @p handle
|
||
|
/// parameter cannot and does not have a const qualifier.
|
||
|
///
|
||
|
/// @param handle Data handle. For the CPU engine, the data handle
|
||
|
/// is a pointer to the actual data. For OpenCL it is a cl_mem.
|
||
|
/// @param stream Stream to use to execute padding in.
|
||
|
void set_data_handle(void *handle, const stream &stream) const {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_memory_set_data_handle_v2(get(), handle, stream.get(true)),
|
||
|
"could not set native handle of a memory object");
|
||
|
}
|
||
|
|
||
|
/// Sets data handle.
|
||
|
///
|
||
|
/// See documentation for
|
||
|
/// #dnnl::memory::set_data_handle(void *, const stream &) const
|
||
|
/// for more information.
|
||
|
///
|
||
|
/// @param handle Data handle. For the CPU engine, the data handle
|
||
|
/// is a pointer to the actual data. For OpenCL it is a cl_mem.
|
||
|
void set_data_handle(void *handle) const {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_memory_set_data_handle_v2(get(), handle, nullptr),
|
||
|
"could not set native handle of a memory object");
|
||
|
}
|
||
|
|
||
|
/// Maps a memory object and returns a host-side pointer to a memory
|
||
|
/// buffer with a copy of its contents.
|
||
|
///
|
||
|
/// Mapping enables read/write directly from/to the memory contents for
|
||
|
/// engines that do not support direct memory access.
|
||
|
///
|
||
|
/// Mapping is an exclusive operation - a memory object cannot be used in
|
||
|
/// other operations until it is unmapped via memory::unmap_data() call.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Any primitives working with the memory should be completed before
|
||
|
/// the memory is mapped. Use stream::wait() to synchronize the
|
||
|
/// corresponding execution stream.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The map_data and unmap_data functions are provided mainly for
|
||
|
/// debug and testing purposes and their performance may be suboptimal.
|
||
|
///
|
||
|
/// @tparam T Data type to return a pointer to.
|
||
|
/// @returns Pointer to the mapped memory.
|
||
|
template <typename T = void>
|
||
|
T *map_data() const {
|
||
|
void *mapped_ptr;
|
||
|
error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
|
||
|
"could not map memory object data");
|
||
|
return static_cast<T *>(mapped_ptr);
|
||
|
}
|
||
|
|
||
|
/// Unmaps a memory object and writes back any changes made to the
|
||
|
/// previously mapped memory buffer. The pointer to the mapped buffer
|
||
|
/// must be obtained via the dnnl::memory::map_data() call.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The map_data and unmap_data functions are provided mainly for
|
||
|
/// debug and testing purposes and their performance may be suboptimal.
|
||
|
///
|
||
|
/// @param mapped_ptr A pointer previously returned by map_data().
|
||
|
void unmap_data(void *mapped_ptr) const {
|
||
|
error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
|
||
|
"could not unmap memory object data");
|
||
|
}
|
||
|
|
||
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
|
||
|
/// Returns the OpenCL memory object associated with the memory.
|
||
|
cl_mem get_ocl_mem_object() const {
|
||
|
cl_mem mem_object;
|
||
|
error::wrap_c_api(dnnl_memory_get_ocl_mem_object(get(), &mem_object),
|
||
|
"could not get OpenCL buffer object from a memory object");
|
||
|
return mem_object;
|
||
|
}
|
||
|
|
||
|
/// Sets the OpenCL memory object @p mem_object associated with the memory.
|
||
|
///
|
||
|
/// For behavioral details see memory::set_data_handle().
|
||
|
///
|
||
|
/// @param mem_object OpenCL cl_mem object to use as the underlying
|
||
|
/// storage. It must have at least get_desc().get_size() bytes
|
||
|
/// allocated.
|
||
|
void set_ocl_mem_object(cl_mem mem_object) {
|
||
|
error::wrap_c_api(dnnl_memory_set_ocl_mem_object(get(), mem_object),
|
||
|
"could not set OpenCL buffer object from a memory object");
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
static dnnl_data_type_t convert_to_c(data_type data_type) {
|
||
|
return static_cast<dnnl_data_type_t>(data_type);
|
||
|
}
|
||
|
static dnnl_format_tag_t convert_to_c(format_tag format) {
|
||
|
return static_cast<dnnl_format_tag_t>(format);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
|
||
|
return a == memory::convert_to_c(b);
|
||
|
}
|
||
|
inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
|
||
|
return !(a == b);
|
||
|
}
|
||
|
inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
|
||
|
return b == a;
|
||
|
}
|
||
|
inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
|
||
|
return !(a == b);
|
||
|
}
|
||
|
|
||
|
inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
|
||
|
return a == memory::convert_to_c(b);
|
||
|
}
|
||
|
inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
|
||
|
return !(a == b);
|
||
|
}
|
||
|
inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
|
||
|
return b == a;
|
||
|
}
|
||
|
inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
|
||
|
return !(a == b);
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_memory
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives
|
||
|
/// @{
|
||
|
/// @addtogroup dnnl_api_attributes Attributes
|
||
|
///
|
||
|
/// A container for parameters that extend primitives behavior.
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_post_ops_t> {
|
||
|
static dnnl_status_t destructor(dnnl_post_ops_t p) {
|
||
|
return dnnl_post_ops_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
/// @endcond
|
||
|
|
||
|
/// Post-ops.
|
||
|
///
|
||
|
/// Post-ops are computations executed after the main primitive computations
|
||
|
/// and are attached to the primitive via primitive attributes.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_attributes_post_ops
|
||
|
///
|
||
|
struct post_ops : public handle<dnnl_post_ops_t> {
|
||
|
using handle<dnnl_post_ops_t>::handle;
|
||
|
|
||
|
/// Constructs an empty sequence of post-ops.
|
||
|
post_ops() {
|
||
|
dnnl_post_ops_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_post_ops_create(&result), "could not create post-ops");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Returns the number of post-ops entries.
|
||
|
int len() const { return dnnl_post_ops_len(get()); }
|
||
|
|
||
|
/// Returns the primitive kind of post-op at entry with a certain index.
|
||
|
/// @param index Index of the post-op to return the kind for.
|
||
|
/// @returns Primitive kind of the post-op at the specified index.
|
||
|
primitive::kind kind(int index) const {
|
||
|
error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
|
||
|
"post-ops index is out of range");
|
||
|
return static_cast<primitive::kind>(
|
||
|
dnnl_post_ops_get_kind(get(), index));
|
||
|
}
|
||
|
|
||
|
/// Appends an accumulation (sum) post-op. Prior to accumulating the
|
||
|
/// result, the previous value would be multiplied by a scaling factor
|
||
|
/// @p scale.
|
||
|
///
|
||
|
/// The kind of this post-op is #dnnl::primitive::kind::sum.
|
||
|
///
|
||
|
/// This feature may improve performance for cases like residual learning
|
||
|
/// blocks, where the result of convolution is accumulated to the
|
||
|
/// previously computed activations. The parameter @p scale may be used
|
||
|
/// for the integer-based computations when the result and previous
|
||
|
/// activations have different logical scaling factors.
|
||
|
///
|
||
|
/// In the simplest case when the accumulation is the only post-op,
|
||
|
/// the computations would be:
|
||
|
///
|
||
|
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
|
||
|
///
|
||
|
/// @note
|
||
|
/// This post-op executes in-place and does not change the
|
||
|
/// destination layout.
|
||
|
///
|
||
|
/// @param scale Scaling factor.
|
||
|
void append_sum(float scale = 1.) {
|
||
|
error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale),
|
||
|
"could not append a sum post-op");
|
||
|
}
|
||
|
|
||
|
/// Returns the parameters of an accumulation (sum) post-op.
|
||
|
///
|
||
|
/// @param index Index of the sum post-op.
|
||
|
/// @param scale Scaling factor of the sum post-op.
|
||
|
void get_params_sum(int index, float &scale) const {
|
||
|
error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale),
|
||
|
"could not get parameters of a sum post-op");
|
||
|
}
|
||
|
|
||
|
/// Appends an elementwise post-op.
|
||
|
///
|
||
|
/// The kind of this post-op is #dnnl::primitive::kind::eltwise.
|
||
|
///
|
||
|
/// In the simplest case when the elementwise is the only post-op, the
|
||
|
/// computations would be:
|
||
|
///
|
||
|
/// dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
|
||
|
///
|
||
|
/// where eltwise_op is configured with the given parameters.
|
||
|
///
|
||
|
/// @param scale Scaling factor.
|
||
|
/// @param algorithm Elementwise algorithm.
|
||
|
/// @param alpha Alpha parameter for the elementwise algorithm.
|
||
|
/// @param beta Beta parameter for the elementwise algorithm.
|
||
|
void append_eltwise(
|
||
|
float scale, algorithm algorithm, float alpha, float beta) {
|
||
|
error::wrap_c_api(dnnl_post_ops_append_eltwise(get(), scale,
|
||
|
convert_to_c(algorithm), alpha, beta),
|
||
|
"could not append an elementwise post-op");
|
||
|
}
|
||
|
|
||
|
/// Returns parameters of an elementwise post-up.
|
||
|
///
|
||
|
/// @param index Index of the post-op.
|
||
|
/// @param scale Output scaling factor.
|
||
|
/// @param algorithm Output elementwise algorithm kind.
|
||
|
/// @param alpha Output alpha parameter for the elementwise algorithm.
|
||
|
/// @param beta Output beta parameter for the elementwise algorithm.
|
||
|
void get_params_eltwise(int index, float &scale, algorithm &algorithm,
|
||
|
float &alpha, float &beta) const {
|
||
|
dnnl_alg_kind_t c_alg;
|
||
|
error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
|
||
|
get(), index, &scale, &c_alg, &alpha, &beta),
|
||
|
"could not get parameters of an elementwise post-op");
|
||
|
algorithm = static_cast<dnnl::algorithm>(c_alg);
|
||
|
}
|
||
|
|
||
|
/// Appends a depthwise post-op convolution with stride 1.
|
||
|
///
|
||
|
/// This post-op can only be fused with a 2D 1x1 convolution (convolution
|
||
|
/// with weights spatial dimension equal to 1 i.e., kh=kw=1).
|
||
|
///
|
||
|
/// The kind of this post-op is #dnnl_convolution.
|
||
|
///
|
||
|
/// The number of outputs for primitive remain same as before. The output
|
||
|
/// size remain same as the original primitive due to stride=1.
|
||
|
///
|
||
|
/// The Post-op can be defined as:
|
||
|
///
|
||
|
/// dst[:] <- scales * (conv_dw(conv_1x1))
|
||
|
///
|
||
|
/// See @ref dev_guide_attributes_post_ops_depthwise and
|
||
|
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
|
||
|
///
|
||
|
/// @param weights_data_type Weights data type of depthwise post-op
|
||
|
/// @param bias_data_type Bias data type of depthwise post-op
|
||
|
/// @param dst_data_type Output data type of depthwise post-op
|
||
|
/// @param mask Output scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the
|
||
|
/// @p scales array. The set i-th bit indicates that a dedicated output
|
||
|
/// scaling factor is used for each index along that dimension. The mask
|
||
|
/// value of 0 implies a common scaling factor for the whole output
|
||
|
/// tensor.
|
||
|
/// @param scales Output pointer to a constant array of float scaling
|
||
|
/// factors.
|
||
|
void append_dw_k3s1p1(memory::data_type weights_data_type,
|
||
|
memory::data_type bias_data_type, memory::data_type dst_data_type,
|
||
|
int mask, const std::vector<float> &scales) {
|
||
|
|
||
|
error::wrap_c_api(dnnl_post_ops_append_dw_k3s1p1(get(),
|
||
|
memory::convert_to_c(weights_data_type),
|
||
|
memory::convert_to_c(bias_data_type),
|
||
|
memory::convert_to_c(dst_data_type),
|
||
|
scales.size(), mask, &scales[0]),
|
||
|
"could not append depthwise post-op");
|
||
|
}
|
||
|
|
||
|
/// Returns the parameters of an depthwise post-op with stride 1.
|
||
|
///
|
||
|
/// @param index Index of the elementwise post-op.
|
||
|
/// @param weights_data_type Weights data type of depthwise post-op
|
||
|
/// @param bias_data_type Bias data type of depthwise post-op
|
||
|
/// @param dst_data_type Output data type of depthwise post-op
|
||
|
/// @param mask Output scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the
|
||
|
/// @p scales array. The set i-th bit indicates that a dedicated output
|
||
|
/// scaling factor is used for each index along that dimension. The mask
|
||
|
/// value of 0 implies a common scaling factor for the whole output
|
||
|
/// tensor.
|
||
|
/// @param scales Output pointer to a constant array of float scaling
|
||
|
/// factors.
|
||
|
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
|
||
|
memory::data_type &bias_data_type, memory::data_type &dst_data_type,
|
||
|
int &mask, std::vector<float> &scales) const {
|
||
|
|
||
|
dnnl_data_type_t c_weights_data_type;
|
||
|
dnnl_data_type_t c_bias_data_type;
|
||
|
dnnl_data_type_t c_dst_data_type;
|
||
|
dnnl_dim_t count;
|
||
|
int c_mask;
|
||
|
const float *c_scales;
|
||
|
error::wrap_c_api(dnnl_post_ops_get_params_dw_k3s1p1(get(), index,
|
||
|
&c_weights_data_type, &c_bias_data_type,
|
||
|
&c_dst_data_type, &count, &c_mask, &c_scales),
|
||
|
"could not get parameters of depthwise post-op");
|
||
|
|
||
|
weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
|
||
|
bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
|
||
|
dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
|
||
|
scales.resize(count);
|
||
|
|
||
|
mask = c_mask;
|
||
|
for (dnnl_dim_t c = 0; c < count; ++c)
|
||
|
scales[c] = c_scales[c];
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/// Appends a depthwise post-op convolution with stride 2.
|
||
|
///
|
||
|
/// This post-op can only be fused with a 2D 1x1 convolution (convolution
|
||
|
/// with weights spatial dimension equal to 1 i.e., kh=kw=1).
|
||
|
///
|
||
|
/// The kind of this post-op is #dnnl_convolution.
|
||
|
///
|
||
|
/// The number of outputs for primitive remain same as before. The output
|
||
|
/// spatial size can be derived as below:
|
||
|
///
|
||
|
/// output_height = ceil(output_height_1x1_convolution, stride)
|
||
|
/// output_width = ceil(output_width_1x1_convolution, stride)
|
||
|
///
|
||
|
/// The Post-op can be defined as:
|
||
|
///
|
||
|
/// dst[:] <- scales * (conv_dw(conv_1x1))
|
||
|
///
|
||
|
/// See @ref dev_guide_attributes_post_ops_depthwise and
|
||
|
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
|
||
|
///
|
||
|
/// @param weights_data_type Weights data type of depthwise post-op
|
||
|
/// @param bias_data_type Bias data type of depthwise post-op
|
||
|
/// @param dst_data_type Output data type of depthwise post-op
|
||
|
/// @param mask Output scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the
|
||
|
/// @p scales array. The set i-th bit indicates that a dedicated output
|
||
|
/// scaling factor is used for each index along that dimension. The mask
|
||
|
/// value of 0 implies a common scaling factor for the whole output
|
||
|
/// tensor.
|
||
|
/// @param scales Output pointer to a constant array of float scaling
|
||
|
/// factors.
|
||
|
/// @returns #dnnl_success on success and a status describing the error
|
||
|
/// otherwise
|
||
|
void append_dw_k3s2p1(memory::data_type weights_data_type,
|
||
|
memory::data_type bias_data_type, memory::data_type dst_data_type,
|
||
|
int mask, const std::vector<float> &scales) {
|
||
|
|
||
|
error::wrap_c_api(dnnl_post_ops_append_dw_k3s2p1(get(),
|
||
|
memory::convert_to_c(weights_data_type),
|
||
|
memory::convert_to_c(bias_data_type),
|
||
|
memory::convert_to_c(dst_data_type),
|
||
|
scales.size(), mask, &scales[0]),
|
||
|
"could not append depthwise post-op");
|
||
|
}
|
||
|
|
||
|
/// Returns the parameters of an depthwise post-op with stride 2.
|
||
|
///
|
||
|
/// @param index Index of the elementwise post-op.
|
||
|
/// @param weights_data_type Weights data type of depthwise post-op
|
||
|
/// @param bias_data_type Bias data type of depthwise post-op
|
||
|
/// @param dst_data_type Output data type of depthwise post-op
|
||
|
/// @param mask Output scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the
|
||
|
/// @p scales array. The set i-th bit indicates that a dedicated output
|
||
|
/// scaling factor is used for each index along that dimension. The mask
|
||
|
/// value of 0 implies a common scaling factor for the whole output
|
||
|
/// tensor.
|
||
|
/// @param scales Output pointer to a constant array of float scaling
|
||
|
/// factors.
|
||
|
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
|
||
|
memory::data_type &bias_data_type, memory::data_type &dst_data_type,
|
||
|
int &mask, std::vector<float> &scales) const {
|
||
|
|
||
|
dnnl_data_type_t c_weights_data_type;
|
||
|
dnnl_data_type_t c_bias_data_type;
|
||
|
dnnl_data_type_t c_dst_data_type;
|
||
|
dnnl_dim_t count;
|
||
|
int c_mask;
|
||
|
const float *c_scales;
|
||
|
error::wrap_c_api(dnnl_post_ops_get_params_dw_k3s2p1(get(), index,
|
||
|
&c_weights_data_type, &c_bias_data_type,
|
||
|
&c_dst_data_type, &count, &c_mask, &c_scales),
|
||
|
"could not get parameters of depthwise post-op");
|
||
|
|
||
|
weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
|
||
|
bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
|
||
|
dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
|
||
|
scales.resize(count);
|
||
|
|
||
|
mask = c_mask;
|
||
|
for (dnnl_dim_t c = 0; c < count; ++c)
|
||
|
scales[c] = c_scales[c];
|
||
|
return;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
template <>
|
||
|
struct handle_traits<dnnl_primitive_attr_t> {
|
||
|
static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
|
||
|
return dnnl_primitive_attr_destroy(p);
|
||
|
}
|
||
|
};
|
||
|
/// @endcond
|
||
|
|
||
|
/// Primitive attributes
|
||
|
///
|
||
|
/// @sa @ref dev_guide_attributes
|
||
|
struct primitive_attr : public handle<dnnl_primitive_attr_t> {
|
||
|
using handle<dnnl_primitive_attr_t>::handle;
|
||
|
|
||
|
/// Constructs default (empty) primitive attributes.
|
||
|
primitive_attr() {
|
||
|
dnnl_primitive_attr_t result;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_create(&result),
|
||
|
"could not create primitive attribute");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
|
||
|
/// handle. The resulting handle is not weak and the C handle will be
|
||
|
/// destroyed during the destruction of the C++ object.
|
||
|
///
|
||
|
/// @param attr The C API primitive attributes.
|
||
|
primitive_attr(dnnl_primitive_attr_t attr)
|
||
|
: handle<dnnl_primitive_attr_t>(attr) {}
|
||
|
|
||
|
/// Returns the scratchpad mode.
|
||
|
scratchpad_mode get_scratchpad_mode() const {
|
||
|
dnnl_scratchpad_mode_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
|
||
|
"could not get scratchpad mode primitive attribute");
|
||
|
return scratchpad_mode(result);
|
||
|
}
|
||
|
|
||
|
/// Sets scratchpad mode.
|
||
|
///
|
||
|
/// @param mode Specified scratchpad mode.
|
||
|
void set_scratchpad_mode(scratchpad_mode mode) {
|
||
|
error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
|
||
|
get(), dnnl::convert_to_c(mode)),
|
||
|
"could not set scratchpad mode primitive attribute");
|
||
|
}
|
||
|
|
||
|
/// Returns output scaling factors correspondence mask and values.
|
||
|
///
|
||
|
/// @param mask Scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the @p
|
||
|
/// scales vector. The set i-th bit indicates that a dedicated output
|
||
|
/// scaling factor is used for each index along that dimension. The
|
||
|
/// mask value of 0 implies a common output scaling factor for the
|
||
|
/// whole output tensor.
|
||
|
/// @param scales Vector of output scaling factors.
|
||
|
void get_output_scales(int &mask, std::vector<float> &scales) const {
|
||
|
dnnl_dim_t count;
|
||
|
int c_mask;
|
||
|
const float *c_scales;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_get_output_scales(
|
||
|
get(), &count, &c_mask, &c_scales),
|
||
|
"could not get output scales primitive attribute");
|
||
|
scales.resize(count);
|
||
|
|
||
|
mask = c_mask;
|
||
|
for (dnnl_dim_t c = 0; c < count; ++c)
|
||
|
scales[c] = c_scales[c];
|
||
|
}
|
||
|
|
||
|
/// Sets output scaling factors correspondence mask and values.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The order of dimensions does not depend on how elements are laid
|
||
|
/// out in memory. For example:
|
||
|
/// - for a 2D CNN activations tensor the order is always (n, c)
|
||
|
/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
|
||
|
/// - for a 5D CNN weights tensor the order is always
|
||
|
/// (g, oc, ic, kh, kw)
|
||
|
///
|
||
|
/// Example usage:
|
||
|
/// @code
|
||
|
/// int mb = 32, oc = 32,
|
||
|
/// oh = 14, ow = 14; // convolution output params
|
||
|
/// // unique output scales per output channel
|
||
|
/// vector<float> scales = { ... };
|
||
|
/// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
|
||
|
///
|
||
|
/// // construct a convolution descriptor
|
||
|
/// dnnl::convolution::desc conv_d;
|
||
|
///
|
||
|
/// dnnl::primitive_attr attr;
|
||
|
/// attr.set_output_scales(attr, oc, 1 << oc_dim, scales);
|
||
|
///
|
||
|
/// dnnl::primitive_desc conv_pd(conv_d, attr, engine);
|
||
|
/// @endcode
|
||
|
///
|
||
|
/// @param mask Defines the correspondence between the output tensor
|
||
|
/// dimensions and the @p scales vector. The set i-th bit indicates
|
||
|
/// that a dedicated scaling factor is used for each index along that
|
||
|
/// dimension. Set the mask to 0 to use a common output scaling factor
|
||
|
/// for the whole output tensor.
|
||
|
/// @param scales Constant vector of output scaling factors. If the
|
||
|
/// scaling factors are known at the time of this call, the following
|
||
|
/// equality must hold:
|
||
|
/// \f[scales.size() = \prod\limits_{d \in mask} output.dims[d].\f]
|
||
|
/// Violations can only be detected when the attributes
|
||
|
/// are used to create a primitive descriptor.
|
||
|
/// If the scaling factors are not known at the time of the call,
|
||
|
/// this vector must contain a single #DNNL_RUNTIME_F32_VAL value and
|
||
|
/// the output scaling factors must be passed at execution time as an
|
||
|
/// argument with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
|
||
|
void set_output_scales(int mask, const std::vector<float> &scales) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_attr_set_output_scales(
|
||
|
get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
|
||
|
"could not set output scales primitive attribute");
|
||
|
}
|
||
|
|
||
|
/// Returns scaling factors correspondence mask and values for a given
|
||
|
/// memory argument.
|
||
|
///
|
||
|
/// @param arg Parameter argument index as passed to the
|
||
|
/// primitive::execute() call.
|
||
|
/// @param mask Scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the @p
|
||
|
/// scales vector. The set i-th bit indicates that a dedicated scaling
|
||
|
/// factor is used for each index along that dimension. Set the mask to
|
||
|
/// 0 to use a common scaling factor for the whole output tensor.
|
||
|
/// @param scales Output vector of scaling factors.
|
||
|
void get_scales(int arg, int &mask, std::vector<float> &scales) const {
|
||
|
dnnl_dim_t count;
|
||
|
int c_mask;
|
||
|
const float *c_scales;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_get_scales(
|
||
|
get(), arg, &count, &c_mask, &c_scales),
|
||
|
"could not get scales primitive attributes");
|
||
|
scales.resize(count);
|
||
|
|
||
|
mask = c_mask;
|
||
|
for (dnnl_dim_t c = 0; c < count; ++c)
|
||
|
scales[c] = c_scales[c];
|
||
|
}
|
||
|
|
||
|
/// Sets scaling factors for primitive operations for a given memory
|
||
|
/// argument.
|
||
|
///
|
||
|
/// @sa dnnl_primitive_attr_set_scales
|
||
|
/// @sa dnnl::primitive_attr::set_output_scales
|
||
|
///
|
||
|
/// @param arg Parameter argument index as passed to the
|
||
|
/// primitive::execute() call.
|
||
|
/// @param mask Scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the tensor dimensions and the @p scales
|
||
|
/// vector. The set i-th bit indicates that a dedicated scaling factor
|
||
|
/// is used for each index along that dimension. Set the mask to 0 to
|
||
|
/// use a common scaling factor for the whole output tensor.
|
||
|
/// @param scales Constant vector of scaling factors. The following equality
|
||
|
/// must hold:
|
||
|
/// \f[scales.size() = \prod\limits_{d \in mask} argument.dims[d].\f]
|
||
|
void set_scales(int arg, int mask, const std::vector<float> &scales) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_attr_set_scales(get(), arg,
|
||
|
(dnnl_dim_t)scales.size(), mask, scales.data()),
|
||
|
"could not set scales primitive attribute");
|
||
|
}
|
||
|
|
||
|
/// Returns zero points correspondence mask and values.
|
||
|
///
|
||
|
/// @param arg Parameter argument index as passed to the
|
||
|
/// primitive::execute() call.
|
||
|
/// @param mask Zero points correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the @p
|
||
|
/// zero_points vector. The set i-th bit indicates that a dedicated
|
||
|
/// zero point is used for each index along that dimension. Set the
|
||
|
/// mask to 0 to use a common zero point for the whole output tensor.
|
||
|
/// @param zero_points Output vector of zero points.
|
||
|
void get_zero_points(
|
||
|
int arg, int &mask, std::vector<int32_t> &zero_points) const {
|
||
|
dnnl_dim_t count;
|
||
|
int c_mask;
|
||
|
const int32_t *c_zero_points;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_get_zero_points(
|
||
|
get(), arg, &count, &c_mask, &c_zero_points),
|
||
|
"could not get zero points primitive attribute");
|
||
|
zero_points.resize(count);
|
||
|
|
||
|
mask = c_mask;
|
||
|
for (dnnl_dim_t c = 0; c < count; ++c)
|
||
|
zero_points[c] = c_zero_points[c];
|
||
|
}
|
||
|
|
||
|
/// Sets zero points for primitive operations for a given memory argument.
|
||
|
///
|
||
|
/// @sa dnnl_primitive_attr_set_zero_points
|
||
|
/// @sa dnnl::primitive_attr::set_output_scales
|
||
|
///
|
||
|
/// @param arg Parameter argument index as passed to the
|
||
|
/// primitive::execute() call.
|
||
|
/// @param mask Zero point correspondence mask that defines the
|
||
|
/// correspondence between the tensor dimensions and the @p
|
||
|
/// zero_points vector. The set i-th bit indicates that a dedicated
|
||
|
/// zero point is used for each index along that dimension. Set the
|
||
|
/// mask to 0 to use a common zero point for the whole output tensor.
|
||
|
/// @param zero_points Constant vector of zero points. If the zero points
|
||
|
/// are known at the time of this call, the following equality must
|
||
|
/// hold:
|
||
|
/// \f[zero_points.size() = \prod\limits_{d \in mask} argument.dims[d].\f]
|
||
|
/// If the zero points are not known at the time of the call, this
|
||
|
/// vector must contain a single #DNNL_RUNTIME_F32_VAL value and the
|
||
|
/// zero points must be passed at execution time as an argument with
|
||
|
/// index #DNNL_ARG_ATTR_ZERO_POINTS.
|
||
|
void set_zero_points(
|
||
|
int arg, int mask, const std::vector<int32_t> &zero_points) {
|
||
|
error::wrap_c_api(dnnl_primitive_attr_set_zero_points(get(), arg,
|
||
|
(dnnl_dim_t)zero_points.size(), mask,
|
||
|
zero_points.data()),
|
||
|
"could not set zero points primitive attribute");
|
||
|
}
|
||
|
|
||
|
/// Returns post-ops previously set via set_post_ops().
|
||
|
///
|
||
|
/// @returns Post-ops.
|
||
|
const post_ops get_post_ops() const {
|
||
|
post_ops result;
|
||
|
const_dnnl_post_ops_t c_result;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_get_post_ops(get(), &c_result),
|
||
|
"could not get post-ops primitive attribute");
|
||
|
result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
|
||
|
return result;
|
||
|
}
|
||
|
|
||
|
/// Sets post-ops.
|
||
|
///
|
||
|
/// @note
|
||
|
/// There is no way to check whether the post-ops would be supported
|
||
|
/// by the target primitive. Any error will be reported
|
||
|
/// by the respective primitive descriptor constructor.
|
||
|
///
|
||
|
/// @param ops Post-ops object to copy post-ops from.
|
||
|
void set_post_ops(const post_ops ops) {
|
||
|
error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
|
||
|
"could not set post-ops primitive attribute");
|
||
|
}
|
||
|
|
||
|
/// Sets quantization scale and shift parameters for RNN data tensors.
|
||
|
///
|
||
|
/// For performance reasons, the low-precision configuration of the RNN
|
||
|
/// primitives expect input activations to have the unsigned 8-bit integer
|
||
|
/// data type. The scale and shift parameters are used to quantize
|
||
|
/// floating-point data to unsigned integer and must be passed to the RNN
|
||
|
/// primitive using attributes.
|
||
|
///
|
||
|
/// The quantization formula is `scale * (data + shift)`.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Quantization scale and shift are common for src_layer, src_iter,
|
||
|
/// dst_iter, and dst_layer.
|
||
|
///
|
||
|
/// Example usage:
|
||
|
/// @code
|
||
|
/// // RNN parameters
|
||
|
/// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
|
||
|
/// // Activations quantization parameters
|
||
|
/// float scale = ..., shift = ..;
|
||
|
///
|
||
|
/// primitive_attr attr;
|
||
|
///
|
||
|
/// // Set scale and shift for int8 quantization of activation
|
||
|
/// attr.set_rnn_data_qparams(scale, shift);
|
||
|
///
|
||
|
/// // Create and configure rnn op_desc
|
||
|
/// vanilla_rnn_forward::desc rnn_d(...);
|
||
|
/// vanilla_rnn_forward::primitive_desc rnn_d(rnn_d, attr, engine);
|
||
|
/// @endcode
|
||
|
///
|
||
|
/// @param scale The value to scale the data by.
|
||
|
/// @param shift The value to shift the data by.
|
||
|
void set_rnn_data_qparams(float scale, float shift) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
|
||
|
"could not get RNN data quantization parameters primitive "
|
||
|
"attribute");
|
||
|
}
|
||
|
|
||
|
/// Sets quantization scaling factors for RNN weights tensors. The
|
||
|
/// low-precision configuration of the RNN primitives expect input weights
|
||
|
/// to use the signed 8-bit integer data type. The scaling factors are
|
||
|
/// used to quantize floating-point data to signed integer and must be
|
||
|
/// passed to RNN primitives using attributes.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The dimension order is always native and does not depend on the
|
||
|
/// actual layout used. For example, five-dimensional weights always
|
||
|
/// have (l, d, i, g, o) logical dimension ordering.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Quantization scales are common for weights_layer and
|
||
|
/// weights_iteration
|
||
|
///
|
||
|
/// @param mask Scaling factors correspondence mask that defines the
|
||
|
/// correspondence between the output tensor dimensions and the @p
|
||
|
/// scales vector. The set i-th bit indicates that a dedicated scaling
|
||
|
/// factor should be used each index along that dimension. Set the
|
||
|
/// mask to 0 to use a common scaling factor for the whole output
|
||
|
/// tensor.
|
||
|
/// @param scales Constant vector of output scaling factors. The following
|
||
|
/// equality must hold:
|
||
|
/// \f[scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f]
|
||
|
/// Violations can only be detected when the attributes are used to
|
||
|
/// create a primitive descriptor.
|
||
|
void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
|
||
|
error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
|
||
|
(int)scales.size(), mask, scales.data()),
|
||
|
"could not get RNN weights quantization parameters primitive "
|
||
|
"attribute");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_attributes
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives_common
|
||
|
/// @{
|
||
|
|
||
|
/// Base class for all primitive descriptors.
|
||
|
struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
|
||
|
using handle<dnnl_primitive_desc_t>::handle;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc_base() = default;
|
||
|
|
||
|
/// Returns the engine of the primitive descriptor.
|
||
|
/// @returns The engine of the primitive descriptor.
|
||
|
engine get_engine() const { return engine::query(*this); }
|
||
|
|
||
|
/// Returns implementation name.
|
||
|
/// @returns The implementation name.
|
||
|
const char *impl_info_str() const {
|
||
|
const char *res;
|
||
|
error::wrap_c_api(dnnl_primitive_desc_query(
|
||
|
get(), dnnl_query_impl_info_str, 0, &res),
|
||
|
"could not retrieve implementation info string from a "
|
||
|
"primitive descriptor");
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
/// Returns a memory::dim value (same as int64_t).
|
||
|
/// @param what The value to query.
|
||
|
/// @returns The result of the query.
|
||
|
memory::dim query_s64(query what) const {
|
||
|
memory::dim res;
|
||
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
||
|
get(), dnnl::convert_to_c(what), 0, &res);
|
||
|
return status == dnnl_success ? res : 0;
|
||
|
}
|
||
|
|
||
|
/// Returns a memory descriptor.
|
||
|
///
|
||
|
/// @note
|
||
|
/// See also the convenience methods
|
||
|
/// dnnl::primitive_desc_base::src_desc(),
|
||
|
/// dnnl::primitive_desc_base::dst_desc(), and others.
|
||
|
///
|
||
|
/// @param what The kind of parameter to query; can be
|
||
|
/// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
|
||
|
/// @param idx Index of the parameter. For example, convolution bias can
|
||
|
/// be queried with what = #dnnl::query::weights_md and idx = 1.
|
||
|
/// @returns The requested memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// parameter of the specified kind or index.
|
||
|
memory::desc query_md(query what, int idx = 0) const {
|
||
|
std::vector<query> valid_q {query::src_md, query::diff_src_md,
|
||
|
query::weights_md, query::diff_weights_md, query::dst_md,
|
||
|
query::diff_dst_md, query::workspace_md, query::scratchpad_md,
|
||
|
query::exec_arg_md};
|
||
|
if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
|
||
|
[=](query q) { return what == q; }))
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
||
|
"memory descriptor query is invalid");
|
||
|
|
||
|
const dnnl_memory_desc_t *cdesc = dnnl_primitive_desc_query_md(
|
||
|
get(), dnnl::convert_to_c(what), idx);
|
||
|
return cdesc ? memory::desc(*cdesc) : memory::desc();
|
||
|
}
|
||
|
|
||
|
/// Returns a source memory descriptor.
|
||
|
/// @param idx Source index.
|
||
|
/// @returns Source memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// source parameter with index @p pdx.
|
||
|
memory::desc src_desc(int idx) const {
|
||
|
return query_md(query::src_md, idx);
|
||
|
}
|
||
|
|
||
|
/// Returns a destination memory descriptor.
|
||
|
/// @param idx Destination index.
|
||
|
/// @returns Destination memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// destination parameter with index @p pdx.
|
||
|
memory::desc dst_desc(int idx) const {
|
||
|
return query_md(query::dst_md, idx);
|
||
|
}
|
||
|
|
||
|
/// Returns a weights memory descriptor.
|
||
|
/// @param idx Weights index.
|
||
|
/// @returns Weights memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// weights parameter with index @p pdx.
|
||
|
memory::desc weights_desc(int idx) const {
|
||
|
return query_md(query::weights_md, idx);
|
||
|
}
|
||
|
|
||
|
/// Returns a diff source memory descriptor.
|
||
|
/// @param idx Diff source index.
|
||
|
/// @returns Diff source memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff source parameter with index @p pdx.
|
||
|
memory::desc diff_src_desc(int idx) const {
|
||
|
return query_md(query::diff_src_md, idx);
|
||
|
}
|
||
|
|
||
|
/// Returns a diff destination memory descriptor.
|
||
|
/// @param idx Diff destination index.
|
||
|
/// @returns Diff destination memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff destination parameter with index @p pdx.
|
||
|
memory::desc diff_dst_desc(int idx) const {
|
||
|
return query_md(query::diff_dst_md, idx);
|
||
|
}
|
||
|
|
||
|
/// Returns a diff weights memory descriptor.
|
||
|
/// @param idx Diff weights index.
|
||
|
/// @returns Diff weights memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff weights parameter with index @p pdx.
|
||
|
memory::desc diff_weights_desc(int idx) const {
|
||
|
return query_md(query::diff_weights_md, idx);
|
||
|
}
|
||
|
|
||
|
// Separate versions without the index argument for documentation
|
||
|
// purposes.
|
||
|
|
||
|
/// Returns a source memory descriptor.
|
||
|
/// @returns Source memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// source parameter.
|
||
|
memory::desc src_desc() const { return src_desc(0); }
|
||
|
|
||
|
/// Returns a destination memory descriptor.
|
||
|
/// @returns Destination memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// destination parameter.
|
||
|
memory::desc dst_desc() const { return dst_desc(0); }
|
||
|
|
||
|
/// Returns a weights memory descriptor.
|
||
|
/// @returns Weights memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// weights parameter.
|
||
|
memory::desc weights_desc() const { return weights_desc(0); }
|
||
|
|
||
|
/// Returns a diff source memory descriptor.
|
||
|
/// @returns Diff source memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff source memory with.
|
||
|
memory::desc diff_src_desc() const { return diff_src_desc(0); }
|
||
|
|
||
|
/// Returns a diff destination memory descriptor.
|
||
|
/// @returns Diff destination memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff destination parameter.
|
||
|
memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
|
||
|
|
||
|
/// Returns a diff weights memory descriptor.
|
||
|
/// @returns Diff weights memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff weights parameter.
|
||
|
memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
|
||
|
|
||
|
/// Returns the workspace memory descriptor.
|
||
|
/// @returns Workspace memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not require
|
||
|
/// workspace parameter.
|
||
|
memory::desc workspace_desc() const {
|
||
|
return query_md(query::workspace_md, 0);
|
||
|
}
|
||
|
|
||
|
/// Returns the scratchpad memory descriptor.
|
||
|
/// @returns scratchpad memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not require
|
||
|
/// scratchpad parameter.
|
||
|
/// @sa @ref dev_guide_attributes_scratchpad
|
||
|
memory::desc scratchpad_desc() const {
|
||
|
return query_md(query::scratchpad_md, 0);
|
||
|
}
|
||
|
|
||
|
/// Returns the engine on which the scratchpad memory is located.
|
||
|
/// @returns The engine on which the scratchpad memory is located.
|
||
|
engine scratchpad_engine() const {
|
||
|
dnnl_engine_t c_engine;
|
||
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
||
|
dnnl::convert_to_c(query::scratchpad_engine),
|
||
|
0, &c_engine),
|
||
|
"could not retrieve scratchpad engine from a primitive "
|
||
|
"descriptor");
|
||
|
return engine(c_engine, true);
|
||
|
}
|
||
|
|
||
|
/// Returns the primitive attributes.
|
||
|
/// @returns The primitive attributes.
|
||
|
primitive_attr get_primitive_attr() const {
|
||
|
const_dnnl_primitive_attr_t const_c_attr;
|
||
|
error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
|
||
|
"could not get attributes from a primitive descriptor");
|
||
|
dnnl_primitive_attr_t c_attr;
|
||
|
error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
|
||
|
"could not clone primitive attributes");
|
||
|
return primitive_attr(c_attr);
|
||
|
}
|
||
|
|
||
|
/// Returns the kind of the primitive descriptor.
|
||
|
/// @returns The kind of the primitive descriptor.
|
||
|
dnnl::primitive::kind get_kind() const {
|
||
|
dnnl_primitive_kind_t kind;
|
||
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
||
|
dnnl_query_primitive_kind, 0, (void *)&kind),
|
||
|
"could not get primitive kind from a primitive descriptor");
|
||
|
return static_cast<dnnl::primitive::kind>(kind);
|
||
|
}
|
||
|
|
||
|
protected:
/// Resets the value of the handle to a clone of a C API primitive
/// descriptor.
///
/// The original @p pd is only read; the handle is reset to a freshly made
/// clone. Clone failures are reported through error::wrap_c_api before the
/// handle is touched.
///
/// @param pd A C API primitive descriptor to clone.
void reset_with_clone(const_dnnl_primitive_desc_t pd) {
    dnnl_primitive_desc_t cloned_pd = nullptr;
    const auto status = dnnl_primitive_desc_clone(&cloned_pd, pd);
    error::wrap_c_api(status, "could not clone a primitive descriptor");
    // Store the clone in this handle.
    reset(cloned_pd);
}
|
||
|
|
||
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
||
|
/// primitive descriptor after verifying that it is what the caller
|
||
|
/// expects.
|
||
|
///
|
||
|
/// @note
|
||
|
/// The @p prim_kind should map to a primitive that does not have
|
||
|
/// different values of propagation kind (e.g. #dnnl::binary).
|
||
|
/// @note
|
||
|
/// Primitive descriptor base constructed this way does not support
|
||
|
/// next_impl() (will throw).
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor to clone.
|
||
|
/// @param prim_kind Expected primitive kind.
|
||
|
primitive_desc_base(
|
||
|
dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
|
||
|
: primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
||
|
/// primitive descriptor after verifying that it is what the caller
|
||
|
/// expects.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Primitive descriptor base constructed this way does not support
|
||
|
/// next_impl() (will throw).
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor to clone.
|
||
|
/// @param prim_kind Expected primitive kind.
|
||
|
/// @param prop_kind Expected propagation kind.
|
||
|
primitive_desc_base(dnnl_primitive_desc_t pd,
|
||
|
dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind)
|
||
|
: primitive_desc_base(pd, prim_kind, prop_kind, prop_kind) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
||
|
/// primitive descriptor after verifying that it is what the caller
|
||
|
/// expects.
|
||
|
///
|
||
|
/// @note
|
||
|
/// Primitive descriptor base constructed this way does not support
|
||
|
/// next_impl() (will throw).
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor to clone.
|
||
|
/// @param prim_kind Expected primitive kind.
|
||
|
/// @param prop_kind1 Expected propagation kind (option 1).
|
||
|
/// @param prop_kind2 Expected propagation kind (option 2). This value is
|
||
|
/// checked if the check with @p prop_kind1 fails.
|
||
|
primitive_desc_base(dnnl_primitive_desc_t pd,
|
||
|
dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
|
||
|
dnnl::prop_kind prop_kind2) {
|
||
|
// It is OK to pass an empty primitive descriptor
|
||
|
if (pd == nullptr) return;
|
||
|
|
||
|
dnnl_status_t rc;
|
||
|
|
||
|
dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
|
||
|
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
|
||
|
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
|
||
|
|
||
|
// Check that primitive kind matches
|
||
|
dnnl_primitive_kind_t pd_kind;
|
||
|
rc = dnnl_primitive_desc_query(
|
||
|
pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
|
||
|
error::wrap_c_api(
|
||
|
rc, "could not get primitive kind from a primitive descriptor");
|
||
|
if (pd_kind != c_prim_kind)
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
||
|
"primitive descriptor operation kind mismatch");
|
||
|
|
||
|
// Check that propagation kind matches
|
||
|
dnnl_prop_kind_t pd_prop_kind;
|
||
|
rc = dnnl_primitive_desc_query(
|
||
|
pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
|
||
|
|
||
|
// Something went wrong
|
||
|
if (rc != dnnl_success && rc != dnnl_unimplemented)
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
||
|
"could not get propagation kind from the primitive "
|
||
|
"descriptor");
|
||
|
|
||
|
// Everything is fine
|
||
|
if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
|
||
|
|| (rc == dnnl_success
|
||
|
&& (pd_prop_kind == c_prop_kind1
|
||
|
|| pd_prop_kind == c_prop_kind2))) {
|
||
|
reset_with_clone(pd);
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
// We could get the propagation kind but there is a mismatch
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
||
|
"primitive descriptor propagation kind mismatch");
|
||
|
}
|
||
|
|
||
|
/// Shorthand for this class, used by derived primitive descriptors to call
/// the base-class accessors explicitly (e.g. `base::src_desc(0)`).
using base = primitive_desc_base;
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_primitives_common
|
||
|
|
||
|
/// @addtogroup dnnl_api_reorder Reorder
|
||
|
///
|
||
|
/// A primitive to copy data between two memory objects. This primitive is
|
||
|
/// typically used to change the way the data is laid out in memory.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_reorder in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Reorder primitive.
|
||
|
struct reorder : public primitive {
|
||
|
/// Primitive descriptor for a reorder primitive.
|
||
|
struct primitive_desc : public primitive_desc_base {
|
||
|
using primitive_desc_base::primitive_desc_base;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for reorder primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param src_engine Engine on which the source memory object will be
|
||
|
/// located.
|
||
|
/// @param src_md Source memory descriptor.
|
||
|
/// @param dst_engine Engine on which the destination memory object
|
||
|
/// will be located.
|
||
|
/// @param dst_md Destination memory descriptor.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(const engine &src_engine, const memory::desc &src_md,
|
||
|
const engine &dst_engine, const memory::desc &dst_md,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
dnnl_primitive_desc_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_reorder_primitive_desc_create(&result, &src_md.data,
|
||
|
src_engine.get(), &dst_md.data, dst_engine.get(),
|
||
|
attr.get()),
|
||
|
"could not create a primitive descriptor for a reorder "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for reorder primitive.
|
||
|
///
|
||
|
/// @param src Source memory object. It is used to obtain the source
|
||
|
/// memory descriptor and engine.
|
||
|
/// @param dst Destination memory object. It is used to obtain the
|
||
|
/// destination memory descriptor and engine.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(const memory &src, const memory &dst,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
dnnl_primitive_desc_t result;
|
||
|
auto src_md = src.get_desc();
|
||
|
auto dst_md = dst.get_desc();
|
||
|
error::wrap_c_api(
|
||
|
dnnl_reorder_primitive_desc_create(&result, &src_md.data,
|
||
|
src.get_engine().get(), &dst_md.data,
|
||
|
dst.get_engine().get(), attr.get()),
|
||
|
"could not create a primitive descriptor for a reorder "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for reorder primitive from a C
|
||
|
/// API primitive descriptor which must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for reorder primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
|
||
|
|
||
|
/// Returns the engine on which the source memory is allocated.
|
||
|
/// @returns The engine on which the source memory is allocated.
|
||
|
engine get_src_engine() const {
|
||
|
return engine::query(*this, dnnl::query::reorder_src_engine);
|
||
|
}
|
||
|
|
||
|
/// Returns the engine on which the destination memory is allocated.
|
||
|
/// @returns The engine on which the destination memory is allocated.
|
||
|
engine get_dst_engine() const {
|
||
|
return engine::query(*this, dnnl::query::reorder_dst_engine);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
reorder() = default;
|
||
|
|
||
|
/// Constructs a reorder primitive.
|
||
|
/// @param pd Primitive descriptor for reorder primitive.
|
||
|
reorder(const primitive_desc &pd) : primitive(pd.get()) {}
|
||
|
|
||
|
/// Constructs a reorder primitive that would reorder data between memory
|
||
|
/// objects having the same memory descriptors as memory objects @p src and
|
||
|
/// @p dst.
|
||
|
///
|
||
|
/// @param src Source memory object.
|
||
|
/// @param dst Destination memory object.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
reorder(const memory &src, const memory &dst,
|
||
|
const primitive_attr &attr = primitive_attr())
|
||
|
: primitive(primitive_desc(src, dst, attr).get()) {}
|
||
|
|
||
|
using primitive::execute;
|
||
|
|
||
|
/// Executes the reorder primitive.
|
||
|
///
|
||
|
/// @param stream Stream object. The stream must belong to the same engine
|
||
|
/// as the primitive.
|
||
|
/// @param src Source memory object.
|
||
|
/// @param dst Destination memory object.
|
||
|
void execute(const stream &stream, memory &src, memory &dst) const {
|
||
|
primitive::execute(stream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_reorder
|
||
|
|
||
|
/// @addtogroup dnnl_api_concat Concat
|
||
|
///
|
||
|
/// A primitive to concatenate data by arbitrary dimension.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_concat in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
inline std::vector<dnnl_memory_desc_t> convert_to_c(
|
||
|
const std::vector<memory::desc> &mems) {
|
||
|
std::vector<dnnl_memory_desc_t> c_mems;
|
||
|
c_mems.reserve(mems.size());
|
||
|
for (const auto &s : mems)
|
||
|
c_mems.push_back(s.data);
|
||
|
return c_mems;
|
||
|
}
|
||
|
/// @endcond
|
||
|
|
||
|
/// Tensor concatenation (concat) primitive.
|
||
|
struct concat : public primitive {
|
||
|
/// Primitive descriptor for a concat primitive.
|
||
|
struct primitive_desc : public primitive_desc_base {
|
||
|
using primitive_desc_base::primitive_desc_base;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an out-of-place concatenation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src[0]` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src[1]` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
/// - ...
|
||
|
/// - `src[n - 1]` (#dnnl::primitive_desc_base::src_desc(`n - 1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param dst Destination memory descriptor.
|
||
|
/// @param concat_dimension Source tensors will be concatenated over
|
||
|
/// dimension with this index. Note that order of dimensions does
|
||
|
/// not depend on memory format.
|
||
|
/// @param srcs Vector of source memory descriptors.
|
||
|
/// @param engine Engine to perform the operation on.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(const memory::desc &dst, int concat_dimension,
|
||
|
const std::vector<memory::desc> &srcs, const engine &engine,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
auto c_srcs = convert_to_c(srcs);
|
||
|
|
||
|
dnnl_primitive_desc_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_concat_primitive_desc_create(&result, &dst.data,
|
||
|
(int)c_srcs.size(), concat_dimension, c_srcs.data(),
|
||
|
attr.get(), engine.get()),
|
||
|
"could not create a primitive descriptor for a concat "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an out-of-place concatenation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// This version derives the destination memory descriptor
|
||
|
/// automatically.
|
||
|
///
|
||
|
/// @param concat_dimension Source tensors will be concatenated over
|
||
|
/// dimension with this index. Note that order of dimensions does
|
||
|
/// not depend on memory format.
|
||
|
/// @param srcs Vector of source memory descriptors.
|
||
|
/// @param engine Engine to perform the operation on.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(int concat_dimension,
|
||
|
const std::vector<memory::desc> &srcs, const engine &engine,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
auto c_api_srcs = convert_to_c(srcs);
|
||
|
|
||
|
dnnl_primitive_desc_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_concat_primitive_desc_create(&result, nullptr,
|
||
|
(int)c_api_srcs.size(), concat_dimension,
|
||
|
c_api_srcs.data(), attr.get(), engine.get()),
|
||
|
"could not create a primitive descriptor for a concat "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for concat primitive from a C
|
||
|
/// API primitive descriptor which must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for concat primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
|
||
|
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
concat() = default;
|
||
|
|
||
|
/// Constructs a concatenation primitive.
|
||
|
/// @param pd Primitive descriptor for concatenation primitive.
|
||
|
concat(const primitive_desc &pd) : primitive(pd.get()) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_concat
|
||
|
|
||
|
/// @addtogroup dnnl_api_sum Sum
|
||
|
///
|
||
|
/// A primitive to sum multiple tensors.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_sum in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Out-of-place summation (sum) primitive.
|
||
|
struct sum : public primitive {
|
||
|
/// Primitive descriptor for a sum primitive.
|
||
|
struct primitive_desc : public primitive_desc_base {
|
||
|
using primitive_desc_base::primitive_desc_base;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a sum primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src[0]` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src[1]` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
/// - ...
|
||
|
/// - `src[n - 1]` (#dnnl::primitive_desc_base::src_desc(`n - 1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param dst Destination memory descriptor.
|
||
|
/// @param scales Vector of scales to multiply data in each source
|
||
|
/// memory by.
|
||
|
/// @param srcs Vector of source memory descriptors.
|
||
|
/// @param engine Engine to perform the operation on.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(const memory::desc &dst,
|
||
|
const std::vector<float> &scales,
|
||
|
const std::vector<memory::desc> &srcs, const engine &engine,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
validate_container_size(scales,
|
||
|
"counts of scales and sources are not equal",
|
||
|
(int)srcs.size(), (int)srcs.size());
|
||
|
|
||
|
auto c_api_srcs = convert_to_c(srcs);
|
||
|
|
||
|
dnnl_primitive_desc_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_sum_primitive_desc_create(&result, &dst.data,
|
||
|
(int)c_api_srcs.size(), scales.data(),
|
||
|
c_api_srcs.data(), attr.get(), engine.get()),
|
||
|
"could not create a primitive descriptor for a sum "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a sum primitive.
|
||
|
///
|
||
|
/// This version derives the destination memory descriptor
|
||
|
/// automatically.
|
||
|
///
|
||
|
/// @param scales Vector of scales by which to multiply data in each
|
||
|
/// source memory object.
|
||
|
/// @param srcs Vector of source memory descriptors.
|
||
|
/// @param engine Engine on which to perform the operation.
|
||
|
/// @param attr Primitive attributes to use (optional).
|
||
|
primitive_desc(const std::vector<float> &scales,
|
||
|
const std::vector<memory::desc> &srcs, const engine &engine,
|
||
|
const primitive_attr &attr = primitive_attr()) {
|
||
|
validate_container_size(scales,
|
||
|
"counts of scales and sources are not equal",
|
||
|
(int)srcs.size(), (int)srcs.size());
|
||
|
|
||
|
auto c_api_srcs = convert_to_c(srcs);
|
||
|
dnnl_primitive_desc_t result;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_sum_primitive_desc_create(&result, nullptr,
|
||
|
(int)c_api_srcs.size(), scales.data(),
|
||
|
c_api_srcs.data(), attr.get(), engine.get()),
|
||
|
"could not create a primitive descriptor for a sum "
|
||
|
"primitive");
|
||
|
reset(result);
|
||
|
}
|
||
|
|
||
|
/// Constructs a primitive descriptor for sum primitive from a C API
|
||
|
/// primitive descriptor which must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for reorder primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
|
||
|
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
sum() = default;
|
||
|
|
||
|
/// Constructs a sum primitive.
|
||
|
/// @param pd Primitive descriptor for sum primitive.
|
||
|
sum(const primitive_desc &pd) : primitive(pd.get()) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_sum
|
||
|
|
||
|
/// @addtogroup dnnl_api_primitives_common
|
||
|
/// @{
|
||
|
|
||
|
/// A base class for descriptors of all primitives that have an operation
|
||
|
/// descriptor and that support iteration over multiple implementations.
|
||
|
struct primitive_desc : public primitive_desc_base {
|
||
|
using primitive_desc_base::primitive_desc_base;
|
||
|
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor.
|
||
|
///
|
||
|
/// @note
|
||
|
/// If @p allow_empty is true, the constructor does not throw if a
|
||
|
/// primitive descriptor cannot be created. But calling next_impl() in
|
||
|
/// this case will throw.
|
||
|
///
|
||
|
/// @note
|
||
|
/// This is a low-level implementation detail that is typically not
|
||
|
/// needed in application code.
|
||
|
///
|
||
|
/// @param desc Constant C API operation descriptor.
|
||
|
/// @param attr Pointer to primitive attributes. It is safe to pass
|
||
|
/// nullptr to indicate absence of attributes.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd C API primitive descriptor for a forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use for backward propagation or weights gradient.
|
||
|
/// @param allow_empty A flag signifying whether construction is allowed
|
||
|
/// to fail without throwing an exception. In this case an empty
|
||
|
/// object will be produced. This flag is optional and defaults to
|
||
|
/// false.
|
||
|
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr,
|
||
|
const engine &engine, const_dnnl_primitive_desc_t hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: allow_empty_(allow_empty) {
|
||
|
dnnl_primitive_desc_iterator_t iterator = nullptr;
|
||
|
dnnl_status_t status = dnnl_primitive_desc_iterator_create(&iterator,
|
||
|
desc, attr ? attr->get() : nullptr, engine.get(), hint_fwd_pd);
|
||
|
if (!allow_empty)
|
||
|
error::wrap_c_api(
|
||
|
status, "could not create a primitive descriptor iterator");
|
||
|
pd_iterator.reset(iterator);
|
||
|
fetch_impl();
|
||
|
}
|
||
|
|
||
|
/// Advances the primitive iterator to the next implementation.
|
||
|
///
|
||
|
/// @returns @c true on success, and @c false if the last implementation
|
||
|
/// reached, and the primitive descriptor itself is kept unchanged
|
||
|
bool next_impl() {
|
||
|
dnnl_status_t status
|
||
|
= dnnl_primitive_desc_iterator_next(pd_iterator.get());
|
||
|
if (status == dnnl_iterator_ends) return false;
|
||
|
error::wrap_c_api(
|
||
|
status, "could not advance a primitive descriptor iterator");
|
||
|
fetch_impl();
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
bool allow_empty_ = false;
|
||
|
handle<dnnl_primitive_desc_iterator_t> pd_iterator;
|
||
|
void fetch_impl() {
|
||
|
dnnl_primitive_desc_t pd = dnnl_primitive_desc_iterator_fetch(
|
||
|
pd_iterator.get(allow_empty_));
|
||
|
error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
|
||
|
: dnnl_runtime_error,
|
||
|
"could not fetch a primitive descriptor from a primitive "
|
||
|
"descriptor iterator");
|
||
|
reset(pd);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_primitives_common
|
||
|
|
||
|
/// @addtogroup dnnl_api_convolution Convolution
|
||
|
///
|
||
|
/// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
|
||
|
/// forward propagation, backward propagation, and weights gradient with or
|
||
|
/// without bias.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_convolution in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Convolution forward propagation primitive.
|
||
|
struct convolution_forward : public primitive {
|
||
|
/// Descriptor for a convolution forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_convolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a convolution forward propagation
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
||
|
/// descriptor disables the bias term.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &bias_desc, const memory::desc &dst_desc,
|
||
|
const memory::dims &strides, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_convolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, &bias_desc.data, &dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a convolution forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a convolution forward propagation
|
||
|
/// primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_convolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, nullptr, &dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a convolution forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated convolution forward
|
||
|
/// propagation primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
||
|
/// descriptor disables the bias term.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &bias_desc, const memory::desc &dst_desc,
|
||
|
const memory::dims &strides, const memory::dims &dilates,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_dilated_convolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, &bias_desc.data,
|
||
|
&dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated convolution "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated convolution forward
|
||
|
/// propagation primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_dilated_convolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, nullptr,
|
||
|
&dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated convolution "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a convolution forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a convolution forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// Returns the bias memory descriptor.
|
||
|
/// @returns The bias memory descriptor.
|
||
|
/// @returns A zero memory descriptor of the primitive does not have a
|
||
|
/// bias parameter.
|
||
|
memory::desc bias_desc() const { return base::weights_desc(1); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
convolution_forward() = default;
|
||
|
|
||
|
/// Constructs a convolution forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a convolution forward propagation
|
||
|
/// primitive.
|
||
|
convolution_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Convolution backward propagation primitive.
|
||
|
struct convolution_backward_data : public primitive {
|
||
|
|
||
|
/// Descriptor for a convolution backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_convolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a convolution backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_convolution_backward_data_desc_init(&data,
|
||
|
convert_to_c(algorithm), &diff_src_desc.data,
|
||
|
&weights_desc.data, &diff_dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a convolution backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for dilated convolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// Memory descriptors are allowed to be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_convolution_backward_data_desc_init(&data,
|
||
|
convert_to_c(algorithm), &diff_src_desc.data,
|
||
|
&weights_desc.data, &diff_dst_desc.data,
|
||
|
&strides[0], &dilates[0], &padding_l[0],
|
||
|
&padding_r[0]),
|
||
|
"could not create a descriptor for a dilated convolution "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a convolution backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to perform the operation on.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to perform the operation on.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a convolution backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
convolution_backward_data() = default;
|
||
|
|
||
|
/// Constructs a convolution backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a convolution backward propagation
|
||
|
/// primitive.
|
||
|
convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Convolution weights gradient primitive.
|
||
|
struct convolution_backward_weights : public primitive {
|
||
|
/// Descriptor for a convolution weights gradient primitive.
|
||
|
struct desc {
|
||
|
dnnl_convolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a convolution weights gradient primitive
|
||
|
/// with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
||
|
/// memory descriptor disables the bias term.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_convolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_desc.data, &strides[0], &padding_l[0],
|
||
|
&padding_r[0]),
|
||
|
"could not create a descriptor for a convolution weights "
|
||
|
"update primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a convolution weights gradient primitive
|
||
|
/// without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_convolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, nullptr,
|
||
|
&diff_dst_desc.data, &strides[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a convolution weights "
|
||
|
"update primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated convolution weights gradient
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
||
|
/// memory descriptor disables the bias term.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_convolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated convolution "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated convolution weights gradient
|
||
|
/// primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Convolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd, and
|
||
|
/// #dnnl::algorithm::convolution_auto.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_convolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, nullptr,
|
||
|
&diff_dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated convolution "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a convolution weights gradient primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution weights gradient primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a convolution weights gradient primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a convolution weights
|
||
|
/// gradient primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
||
|
dnnl::prop_kind::backward_weights) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
||
|
memory::desc diff_weights_desc() const {
|
||
|
return base::diff_weights_desc(0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// Returns the diff bias memory descriptor.
|
||
|
/// @returns The diff bias memory descriptor.
|
||
|
/// @returns A zero memory descriptor of the primitive does not have a
|
||
|
/// diff bias parameter.
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return base::diff_weights_desc(1);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
convolution_backward_weights() = default;
|
||
|
|
||
|
/// Constructs a convolution weights gradient primitive.
|
||
|
/// @param pd Primitive descriptor for a convolution weights gradient
|
||
|
/// primitive.
|
||
|
convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_convolution
|
||
|
//
|
||
|
/// @addtogroup dnnl_api_deconvolution Deconvolution
|
||
|
///
|
||
|
/// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
|
||
|
/// forward propagation, backward propagation, and weights gradient with or
|
||
|
/// without bias.
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Deconvolution forward propagation primitive.
|
||
|
struct deconvolution_forward : public primitive {
|
||
|
/// Descriptor for a deconvolution forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_deconvolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a deconvolution forward propagation
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Deconvolution algorithm:
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
||
|
/// descriptor disables the bias term.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &bias_desc, const memory::desc &dst_desc,
|
||
|
const memory::dims &strides, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_deconvolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, &bias_desc.data, &dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a deconvolution forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a deconvolution forward propagation
|
||
|
/// primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Deconvolution algorithm:
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_deconvolution_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, nullptr, &dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a deconvolution forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated deconvolution forward
|
||
|
/// propagation primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Deconvolution algorithm:
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
||
|
/// descriptor disables the bias term.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &bias_desc, const memory::desc &dst_desc,
|
||
|
const memory::dims &strides, const memory::dims &dilates,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_dilated_deconvolution_forward_desc_init(
|
||
|
&data, dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, &bias_desc.data,
|
||
|
&dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated deconvolution "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated deconvolution forward
|
||
|
/// propagation primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Deconvolution algorithm:
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_dilated_deconvolution_forward_desc_init(
|
||
|
&data, dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&weights_desc.data, nullptr,
|
||
|
&dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated deconvolution "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a deconvolution forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a deconvolution forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a deconvolution forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
|
||
|
memory::desc bias_desc() const { return base::weights_desc(1); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
deconvolution_forward() = default;
|
||
|
|
||
|
/// Constructs a deconvolution forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a deconvolution forward propagation
|
||
|
/// primitive.
|
||
|
deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Deconvolution backward propagation primitive.
|
||
|
struct deconvolution_backward_data : public primitive {
|
||
|
/// Descriptor for a deconvolution backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_deconvolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a deconvolution backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm
|
||
|
/// (#dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd).
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_deconvolution_backward_data_desc_init(&data,
|
||
|
convert_to_c(algorithm), &diff_src_desc.data,
|
||
|
&weights_desc.data, &diff_dst_desc.data,
|
||
|
&strides[0], &padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a deconvolution "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated deconvolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm
|
||
|
/// (#dnnl::algorithm::convolution_direct,
|
||
|
/// #dnnl::algorithm::convolution_winograd).
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param weights_desc Weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_deconvolution_backward_data_desc_init(&data,
|
||
|
convert_to_c(algorithm), &diff_src_desc.data,
|
||
|
&weights_desc.data, &diff_dst_desc.data,
|
||
|
&strides[0], &dilates[0], &padding_l[0],
|
||
|
&padding_r[0]),
|
||
|
"could not create a descriptor for a dilated deconvolution "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a deconvolution backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a deconvolution backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a deconvolution backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a deconvolution backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
deconvolution_backward_data() = default;
|
||
|
|
||
|
/// Constructs a deconvolution backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a deconvolution backward propagation
|
||
|
/// primitive.
|
||
|
deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Deconvolution weights gradient primitive.
|
||
|
struct deconvolution_backward_weights : public primitive {
|
||
|
/// Descriptor for a deconvolution weights gradient primitive.
|
||
|
struct desc {
|
||
|
dnnl_deconvolution_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a deconvolution weights gradient
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
||
|
/// memory descriptor disables the bias term.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_deconvolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_desc.data, &strides[0], &padding_l[0],
|
||
|
&padding_r[0]),
|
||
|
"could not create a descriptor for a deconvolution weights "
|
||
|
"update primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a deconvolution weights gradient primitive
|
||
|
/// without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_deconvolution_backward_weights_desc_init(
|
||
|
&data, convert_to_c(algorithm),
|
||
|
&src_desc.data, &diff_weights_desc.data,
|
||
|
nullptr, &diff_dst_desc.data, &strides[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a deconvolution weights "
|
||
|
"update primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated deconvolution weights gradient
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
||
|
/// memory descriptor disables the bias term.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_deconvolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated deconvolution "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a dilated deconvolution weights gradient
|
||
|
/// primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param algorithm Deconvolution algorithm. Possible values are
|
||
|
/// #dnnl::algorithm::deconvolution_direct, and
|
||
|
/// #dnnl::algorithm::deconvolution_winograd.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Strides for each spatial dimension.
|
||
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
||
|
/// means no dilation in the corresponding dimension.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &dilates, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(dilates, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_dilated_deconvolution_backward_weights_desc_init(&data,
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&diff_weights_desc.data, nullptr,
|
||
|
&diff_dst_desc.data, &strides[0], &dilates[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a dilated deconvolution "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a deconvolution weights gradient primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution weights
|
||
|
/// update primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a deconvolution weights gradient
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case
|
||
|
/// an empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution weights
|
||
|
/// update primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a deconvolution weights gradient
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a deconvolution weights
|
||
|
/// gradient primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a deconvolution weights
|
||
|
/// gradient primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
||
|
dnnl::prop_kind::backward_weights) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
||
|
memory::desc diff_weights_desc() const {
|
||
|
return base::diff_weights_desc(0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return base::diff_weights_desc(1);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
deconvolution_backward_weights() = default;
|
||
|
|
||
|
/// Constructs a deconvolution weights gradient primitive.
|
||
|
/// @param pd Primitive descriptor for a deconvolution weights gradient
|
||
|
/// primitive.
|
||
|
deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_deconvolution
|
||
|
|
||
|
/// @addtogroup dnnl_api_lrn LRN
|
||
|
///
|
||
|
/// A primitive to perform local response normalization (LRN) across or within
|
||
|
/// channels.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_lrn in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Local response normalization (LRN) forward propagation primitive.
|
||
|
struct lrn_forward : public primitive {
|
||
|
/// Descriptor for an LRN forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_lrn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a LRN forward propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p alg_kind = #dnnl::algorithm::pooling_max and @p
|
||
|
/// prop_kind = #dnnl::prop_kind::forward_training; must be
|
||
|
/// queried for using @ref dnnl::primitive_desc_base::query_md()
|
||
|
/// after a corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm LRN algorithm kind: either
|
||
|
/// #dnnl::algorithm::lrn_across_channels, or
|
||
|
/// #dnnl::algorithm::lrn_within_channel.
|
||
|
/// @param data_desc Source and destination memory descriptors.
|
||
|
/// @param local_size Regularization local size.
|
||
|
/// @param alpha The alpha regularization parameter.
|
||
|
/// @param beta The beta regularization parameter.
|
||
|
/// @param k The k regularization parameter.
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &data_desc, memory::dim local_size,
|
||
|
float alpha, float beta, float k = 1.f) {
|
||
|
error::wrap_c_api(dnnl_lrn_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &data_desc.data,
|
||
|
local_size, alpha, beta, k),
|
||
|
"could not create a descriptor for a lrn forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an LRN forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LRN forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LRN forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN forward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an LRN forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lrn_forward() = default;
|
||
|
|
||
|
/// Constructs an LRN forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LRN forward propagation
|
||
|
/// primitive.
|
||
|
lrn_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Local response normalization (LRN) backward propagation primitive.
|
||
|
struct lrn_backward : public primitive {
|
||
|
/// Descriptor for an LRN backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_lrn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an LRN backward propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if the underlying implementation requires it; must be queried
|
||
|
/// for using @ref dnnl_primitive_desc_query_md() after a
|
||
|
/// corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm LRN algorithm kind: either
|
||
|
/// #dnnl::algorithm::lrn_across_channels, or
|
||
|
/// #dnnl::algorithm::lrn_within_channel.
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param data_desc Source memory descriptor.
|
||
|
/// @param local_size Regularization local size.
|
||
|
/// @param alpha The alpha regularization parameter.
|
||
|
/// @param beta The beta regularization parameter.
|
||
|
/// @param k The k regularization parameter.
|
||
|
desc(algorithm algorithm, const memory::desc &data_desc,
|
||
|
const memory::desc &diff_data_desc, memory::dim local_size,
|
||
|
float alpha, float beta, float k = 1.f) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lrn_backward_desc_init(&data, convert_to_c(algorithm),
|
||
|
&diff_data_desc.data, &data_desc.data, local_size,
|
||
|
alpha, beta, k),
|
||
|
"could not create a descriptor for a lrn backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an LRN backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LRN backward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LRN forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const lrn_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LRN backward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LRN forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const lrn_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LRN backward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an LRN backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lrn_backward() = default;
|
||
|
|
||
|
/// Constructs an LRN backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LRN backward propagation
|
||
|
/// primitive.
|
||
|
lrn_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_lrn
|
||
|
|
||
|
/// @addtogroup dnnl_api_pooling Pooling
|
||
|
///
|
||
|
/// A primitive to perform max or average pooling.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_pooling in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Pooling forward propagation primitive.
|
||
|
struct pooling_forward : public primitive {
|
||
|
/// Descriptor for a pooling forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_pooling_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for pooling forward propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p alg_kind = #dnnl::algorithm::pooling_max and @p
|
||
|
/// prop_kind = #dnnl::prop_kind::forward_training; must be
|
||
|
/// queried for using @ref dnnl::primitive_desc_base::query_md()
|
||
|
/// after a corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Pooling algorithm kind: either
|
||
|
/// #dnnl::algorithm::pooling_max,
|
||
|
/// #dnnl::algorithm::pooling_avg_include_padding,
|
||
|
/// or #dnnl::algorithm::pooling_avg (same as
|
||
|
/// #dnnl::algorithm::pooling_avg_exclude_padding).
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param kernel Vector of kernel spatial dimensions.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &dst_desc,
|
||
|
const memory::dims &strides, const memory::dims &kernel,
|
||
|
const memory::dims &padding_l, const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(kernel, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_pooling_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &src_desc.data,
|
||
|
&dst_desc.data, &strides[0], &kernel[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a pooling forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a pooling forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a pooling forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a pooling forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a pooling forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
pooling_forward() = default;
|
||
|
|
||
|
/// Constructs a pooling forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a pooling forward propagation
|
||
|
/// primitive.
|
||
|
pooling_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Pooling backward propagation primitive.
|
||
|
struct pooling_backward : public primitive {
|
||
|
/// Descriptor for a pooling backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_pooling_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for pooling backward propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p alg_kind = #dnnl::algorithm::pooling_max; must be
|
||
|
/// queried for using @ref dnnl::primitive_desc_base::query_md()
|
||
|
/// after a corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm Pooling algorithm kind: either
|
||
|
/// #dnnl::algorithm::pooling_max,
|
||
|
/// #dnnl::algorithm::pooling_avg_include_padding,
|
||
|
/// or #dnnl::algorithm::pooling_avg (same as
|
||
|
/// #dnnl::algorithm::pooling_avg_exclude_padding).
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
/// @param strides Vector of strides for spatial dimension.
|
||
|
/// @param kernel Vector of kernel spatial dimensions.
|
||
|
/// @param padding_l Vector of padding values for low indices for each
|
||
|
/// spatial dimension (front, top, left).
|
||
|
/// @param padding_r Vector of padding values for high indices for
|
||
|
/// each spatial dimension (back, bottom, right).
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
||
|
const memory::dims &kernel, const memory::dims &padding_l,
|
||
|
const memory::dims &padding_r) {
|
||
|
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
|
||
|
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(
|
||
|
dnnl_pooling_backward_desc_init(&data,
|
||
|
convert_to_c(algorithm), &diff_src_desc.data,
|
||
|
&diff_dst_desc.data, &strides[0], &kernel[0],
|
||
|
&padding_l[0], &padding_r[0]),
|
||
|
"could not create a descriptor for a pooling backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a pooling backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a pooling backward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a pooling forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const pooling_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a pooling backward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a pooling forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const pooling_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a pooling backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a pooling backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
pooling_backward() = default;
|
||
|
|
||
|
/// Constructs a pooling backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a pooling backward propagation
|
||
|
/// primitive.
|
||
|
pooling_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_pooling
|
||
|
|
||
|
/// @addtogroup dnnl_api_eltwise Eltwise
|
||
|
///
|
||
|
/// A primitive to perform elementwise operations such as the
|
||
|
/// rectified linear unit (ReLU).
|
||
|
///
|
||
|
/// Both forward and backward propagation primitives support in-place
|
||
|
/// operation; that is, src and dst can refer to the same memory for forward
|
||
|
/// propagation, and diff_dst and diff_src can refer to the same memory for
|
||
|
/// backward propagation.
|
||
|
///
|
||
|
/// @warning
|
||
|
/// Because the original source data is required for backward propagation,
|
||
|
/// in-place forward propagation is not generally supported in the
|
||
|
/// training mode. However, for algorithms supporting destination as input
|
||
|
/// memory, dst can be used for the backward propagation, which makes it
|
||
|
/// possible to get performance benefit even in the training mode.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_eltwise in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Elementwise unary operation forward propagation primitive.
|
||
|
struct eltwise_forward : public primitive {
|
||
|
/// Descriptor for an elementwise forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_eltwise_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an elementwise forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm Elementwise algorithm kind.
|
||
|
/// @param data_desc Source and destination memory descriptors.
|
||
|
/// @param alpha The alpha parameter for the elementwise operation.
|
||
|
/// Specific meaning depends on the algorithm.
|
||
|
/// @param beta The beta parameter for the elementwise operation.
|
||
|
/// Specific meaning depends on the algorithm.
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &data_desc, float alpha = 0,
|
||
|
float beta = 0) {
|
||
|
error::wrap_c_api(dnnl_eltwise_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(algorithm),
|
||
|
&data_desc.data, alpha, beta),
|
||
|
"could not create a descriptor for an eltwise forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an elementwise forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an eltwise forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an eltwise forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
eltwise_forward() = default;
|
||
|
|
||
|
/// Constructs an eltwise forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an eltwise forward propagation
|
||
|
/// primitive.
|
||
|
eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Elementwise unary operation backward propagation primitive.
|
||
|
struct eltwise_backward : public primitive {
|
||
|
/// Descriptor for an elementwise backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_eltwise_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an elementwise backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm Elementwise algorithm kind.
|
||
|
/// @param diff_data_desc Diff source and destination memory
|
||
|
/// descriptors.
|
||
|
/// @param data_desc Source memory descriptor.
|
||
|
/// @param alpha The alpha parameter for the elementwise operation.
|
||
|
/// Specific meaning depends on the algorithm.
|
||
|
/// @param beta The beta parameter for the elementwise operation.
|
||
|
/// Specific meaning depends on the algorithm.
|
||
|
desc(algorithm algorithm, const memory::desc &diff_data_desc,
|
||
|
const memory::desc &data_desc, float alpha = 0,
|
||
|
float beta = 0) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_eltwise_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(algorithm), &diff_data_desc.data,
|
||
|
&data_desc.data, alpha, beta),
|
||
|
"could not create a descriptor for an eltwise backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for eltwise backward propagation.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an elementwise forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an elementwise forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an eltwise backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an eltwise backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
eltwise_backward() = default;
|
||
|
|
||
|
/// Constructs an eltwise backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an eltwise backward propagation
|
||
|
/// primitive.
|
||
|
eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_eltwise
|
||
|
|
||
|
/// @addtogroup dnnl_api_softmax Softmax
|
||
|
///
|
||
|
/// A primitive to perform softmax.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_softmax in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Softmax forward propagation primitive.
|
||
|
struct softmax_forward : public primitive {
|
||
|
/// Descriptor for a softmax forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_softmax_desc_t data;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
desc() = default;
|
||
|
|
||
|
/// Constructs a descriptor for a softmax forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptor.
|
||
|
/// @param softmax_axis Axis over which softmax is computed.
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc,
|
||
|
int softmax_axis) {
|
||
|
error::wrap_c_api(dnnl_softmax_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
&data_desc.data, softmax_axis),
|
||
|
"could not create a descriptor for a softmax forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a softmax forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a softmax forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a softmax forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a softmax forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
softmax_forward() = default;
|
||
|
|
||
|
/// Constructs a softmax forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a softmax forward propagation
|
||
|
/// primitive.
|
||
|
softmax_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Softmax backward propagation primitive.
|
||
|
struct softmax_backward : public primitive {
|
||
|
/// Descriptor for a softmax backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_softmax_desc_t data;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
desc() = default;
|
||
|
|
||
|
/// Constructs a descriptor for a softmax backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param data_desc Destination memory descriptor.
|
||
|
/// @param softmax_axis Axis over which softmax is computed.
|
||
|
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
|
||
|
int softmax_axis) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
|
||
|
&data_desc.data, softmax_axis),
|
||
|
"could not create a descriptor for a softmax backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a softmax backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a softmax backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a softmax forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const softmax_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a softmax backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a softmax forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const softmax_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a softmax backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a softmax backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
softmax_backward() = default;
|
||
|
|
||
|
/// Constructs a softmax backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a softmax backward propagation
|
||
|
/// primitive.
|
||
|
softmax_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_softmax
|
||
|
|
||
|
/// @addtogroup dnnl_api_logsoftmax LogSoftmax
|
||
|
///
|
||
|
/// A primitive to perform logsoftmax.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_logsoftmax in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Logsoftmax forward propagation primitive.
|
||
|
struct logsoftmax_forward : public primitive {
|
||
|
/// Descriptor for a logsoftmax forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_logsoftmax_desc_t data;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
desc() = default;
|
||
|
|
||
|
/// Constructs a descriptor for a logsoftmax forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptor.
|
||
|
/// @param logsoftmax_axis Axis over which softmax is computed.
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc,
|
||
|
int logsoftmax_axis) {
|
||
|
error::wrap_c_api(dnnl_logsoftmax_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
&data_desc.data, logsoftmax_axis),
|
||
|
"could not create a descriptor for a logsoftmax forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a logsoftmax forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc descriptor for a logsoftmax forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a logsoftmax forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
// Logsoftmax and softmax share the implementation and
|
||
|
// currently report the same primitive kind. Hence this
|
||
|
// must be softmax and not logsoftmax.
|
||
|
dnnl::primitive::kind::softmax,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
logsoftmax_forward() = default;
|
||
|
|
||
|
/// Constructs a logsoftmax forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a logsoftmax forward propagation
|
||
|
/// primitive.
|
||
|
logsoftmax_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Logsoftmax backward propagation primitive.
|
||
|
struct logsoftmax_backward : public primitive {
|
||
|
/// Descriptor for a logsoftmax backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_logsoftmax_desc_t data;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
desc() = default;
|
||
|
|
||
|
/// Constructs a descriptor for a logsoftmax backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptors.
|
||
|
/// @param data_desc Destination memory descriptor.
|
||
|
/// @param logsoftmax_axis Axis over which softmax is computed.
|
||
|
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
|
||
|
int logsoftmax_axis) {
|
||
|
error::wrap_c_api(dnnl_logsoftmax_backward_desc_init(&data,
|
||
|
&diff_data_desc.data, &data_desc.data,
|
||
|
logsoftmax_axis),
|
||
|
"could not create a descriptor for a logsoftmax backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a logsoftmax backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a logsoftmax backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const logsoftmax_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a logsoftmax backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a logsoftmax forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const logsoftmax_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a logsoftmax backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a logsoftmax backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
// Logsoftmax and softmax share the implementation and
|
||
|
// currently report the same primitive kind. Hence this
|
||
|
// must be softmax and not logsoftmax.
|
||
|
dnnl::primitive::kind::softmax,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
logsoftmax_backward() = default;
|
||
|
|
||
|
/// Constructs a logsoftmax backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a logsoftmax backward propagation
|
||
|
/// primitive.
|
||
|
logsoftmax_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_logsoftmax
|
||
|
|
||
|
/// @addtogroup dnnl_api_batch_normalization Batch Normalization
|
||
|
///
|
||
|
/// A primitive to perform batch normalization.
|
||
|
///
|
||
|
/// Both forward and backward propagation primitives support in-place
|
||
|
/// operation; that is, src and dst can refer to the same memory for forward
|
||
|
/// propagation, and diff_dst and diff_src can refer to the same memory for
|
||
|
/// backward propagation.
|
||
|
///
|
||
|
/// The batch normalization primitives' computations can be controlled by
|
||
|
/// specifying different @ref dnnl::normalization_flags values. For example,
|
||
|
/// batch normalization can compute the mean and variance on its own or take
|
||
|
/// them as inputs. It can either perform scaling and shifting using gamma
|
||
|
/// and beta parameters or not. Optionally, it can also perform a fused ReLU,
|
||
|
/// which in the case of training also requires a workspace.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_batch_normalization in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Batch normalization forward propagation primitive.
|
||
|
struct batch_normalization_forward : public primitive {
|
||
|
/// Descriptor for a batch normalization forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_batch_normalization_desc_t data;
|
||
|
|
||
|
/// Constructs a batch normalization descriptor for forward
|
||
|
/// propagation.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
|
||
|
/// in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::fuse_norm_relu bit-flag is set
|
||
|
/// in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training; must be queried
|
||
|
/// for using @ref primitive_desc_base::query_md() after a
|
||
|
/// corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// In-place operation is supported: the dst can refer to the same
|
||
|
/// memory as the src.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptors.
|
||
|
/// @param epsilon Batch normalization epsilon parameter.
|
||
|
/// @param flags Batch normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
|
||
|
normalization_flags flags) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_batch_normalization_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &data_desc.data,
|
||
|
epsilon, convert_to_c(flags)),
|
||
|
"could not create a descriptor for a batch normalization "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a batch normalization forward propagation
|
||
|
/// primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a batch normalization forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a batch normalization forward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization
|
||
|
/// forward propagation primitive from a C API primitive descriptor
|
||
|
/// that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a batch normalization
|
||
|
/// forward propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
dnnl::primitive::kind::batch_normalization,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
|
||
|
/// Returns memory descriptor for mean.
|
||
|
/// @returns Memory descriptor for mean.
|
||
|
memory::desc mean_desc() const { return stat_desc(mean); }
|
||
|
|
||
|
/// Returns memory descriptor for variance.
|
||
|
/// @returns Memory descriptor for variance.
|
||
|
memory::desc variance_desc() const { return stat_desc(var); }
|
||
|
|
||
|
private:
|
||
|
enum {
|
||
|
mean = 1,
|
||
|
var = 2,
|
||
|
};
|
||
|
memory::desc stat_desc(int kind) const {
|
||
|
dnnl_batch_normalization_desc_t *p;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_desc_query(get(),
|
||
|
dnnl::convert_to_c(query::batch_normalization_d), 0,
|
||
|
&p),
|
||
|
"could not retrieve a descriptor from a primitive "
|
||
|
"descriptor for batch normalization forward propagation "
|
||
|
"primitive");
|
||
|
return query_md(p->flags & dnnl_use_global_stats ? query::src_md
|
||
|
: query::dst_md,
|
||
|
kind);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
batch_normalization_forward() = default;
|
||
|
|
||
|
/// Constructs a batch normalization forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a batch normalization forward
|
||
|
/// propagation primitive.
|
||
|
batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Batch normalization backward propagation primitive.
|
||
|
struct batch_normalization_backward : public primitive {
|
||
|
/// Descriptor for a batch normalization backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_batch_normalization_desc_t data;
|
||
|
|
||
|
/// Constructs a batch normalization descriptor for backward
|
||
|
/// propagation.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::fuse_norm_relu bit-flag is set
|
||
|
/// in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_scale_and_shift`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
|
||
|
/// set in @p flags and @p prop_kind = #dnnl::prop_kind::backward
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
|
||
|
/// (diffs for all parameters are computed in this case).
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param data_desc Source memory descriptor.
|
||
|
/// @param epsilon Batch normalization epsilon parameter.
|
||
|
/// @param flags Batch normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
|
||
|
const memory::desc &data_desc, float epsilon,
|
||
|
normalization_flags flags) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_batch_normalization_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
|
||
|
&data_desc.data, epsilon, convert_to_c(flags)),
|
||
|
"could not create a descriptor for a batch normalization "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a batch normalization backward propagation
|
||
|
/// primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a batch normalization backward
|
||
|
/// propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a batch normalization
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const batch_normalization_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a batch normalization backward
|
||
|
/// propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a batch normalization
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const batch_normalization_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a batch normalization
|
||
|
/// backward propagation primitive from a C API primitive descriptor
|
||
|
/// that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a batch normalization
|
||
|
/// backward propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
dnnl::primitive::kind::batch_normalization,
|
||
|
dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
||
|
memory::desc diff_weights_desc() const {
|
||
|
return base::diff_weights_desc(0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
|
||
|
memory::desc mean_desc() const { return query_md(query::src_md, 1); }
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
|
||
|
memory::desc variance_desc() const {
|
||
|
return query_md(query::src_md, 2);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
batch_normalization_backward() = default;
|
||
|
|
||
|
/// Constructs a batch normalization backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a batch normalization backward
|
||
|
/// propagation primitive.
|
||
|
batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_batch_normalization
|
||
|
|
||
|
/// @addtogroup dnnl_api_layer_normalization Layer Normalization
|
||
|
///
|
||
|
/// A primitive to perform layer normalization. Normalization is performed
|
||
|
/// within the last logical dimension of the data tensor.
|
||
|
///
|
||
|
/// Both forward and backward propagation primitives support in-place
|
||
|
/// operation; that is, src and dst can refer to the same memory for forward
|
||
|
/// propagation, and diff_dst and diff_src can refer to the same memory for
|
||
|
/// backward propagation.
|
||
|
///
|
||
|
/// The layer normalization primitives' computations can be controlled by
|
||
|
/// specifying different dnnl::normalization_flags values. For example,
|
||
|
/// layer normalization forward propagation can be configured to either
|
||
|
/// compute the mean and variance or take them as arguments. It can either
|
||
|
/// perform scaling and shifting using gamma and beta parameters or not.
|
||
|
/// Optionally, it can also perform a fused ReLU, which in case of training
|
||
|
/// would also require a workspace.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_layer_normalization in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Layer normalization forward propagation primitive.
|
||
|
struct layer_normalization_forward : public primitive {
|
||
|
/// Descriptor for a layer normalization forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_layer_normalization_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for layer normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
|
||
|
/// in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptor.
|
||
|
/// @param stat_desc Statistics memory descriptors.
|
||
|
/// @param epsilon Layer normalization epsilon parameter.
|
||
|
/// @param flags Layer normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc,
|
||
|
const memory::desc &stat_desc, float epsilon,
|
||
|
normalization_flags flags) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_layer_normalization_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &data_desc.data,
|
||
|
&stat_desc.data, epsilon, convert_to_c(flags)),
|
||
|
"could not create a descriptor for a layer normalization "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for layer normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// set in @p flags
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
|
||
|
/// in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
|
||
|
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
|
||
|
/// not set in @p flags and @p prop_kind =
|
||
|
/// #dnnl::prop_kind::forward_training
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptor.
|
||
|
/// @param epsilon Layer normalization epsilon parameter.
|
||
|
/// @param flags Layer normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
|
||
|
normalization_flags flags) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_layer_normalization_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &data_desc.data,
|
||
|
nullptr, epsilon, convert_to_c(flags)),
|
||
|
"could not create a descriptor for a layer normalization "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a layer normalization forward propagation
|
||
|
/// primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a layer normalization forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a layer normalization forward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization
|
||
|
/// forward propagation primitive from a C API primitive descriptor
|
||
|
/// that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a layer normalization
|
||
|
/// forward propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
dnnl::primitive::kind::layer_normalization,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
|
||
|
memory::desc mean_desc() const { return stat_desc(mean); }
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
|
||
|
memory::desc variance_desc() const { return stat_desc(var); }
|
||
|
|
||
|
private:
|
||
|
enum {
|
||
|
mean = 1,
|
||
|
var = 2,
|
||
|
};
|
||
|
memory::desc stat_desc(int kind) const {
|
||
|
dnnl_layer_normalization_desc_t *p;
|
||
|
error::wrap_c_api(
|
||
|
dnnl_primitive_desc_query(get(),
|
||
|
dnnl::convert_to_c(query::layer_normalization_d), 0,
|
||
|
&p),
|
||
|
"could not retrieve a descriptor from a primitive "
|
||
|
"descriptor for layer normalization forward propagation "
|
||
|
"primitive");
|
||
|
return query_md(p->flags & dnnl_use_global_stats ? query::src_md
|
||
|
: query::dst_md,
|
||
|
kind);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
layer_normalization_forward() = default;
|
||
|
|
||
|
/// Constructs a layer normalization forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a layer normalization forward
|
||
|
/// propagation primitive.
|
||
|
layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Layer normalization backward propagation primitive.
|
||
|
struct layer_normalization_backward : public primitive {
|
||
|
/// Descriptor for a layer normalization backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_layer_normalization_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
|
||
|
/// set in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_scale_and_shift`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)), if
|
||
|
/// #dnnl::normalization_flags::use_scale_shift bit-flag is set
|
||
|
/// in @p flags and @p prop_kind = #dnnl::prop_kind::backward
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
|
||
|
/// (diffs for all parameters are computed in this case).
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param data_desc Source memory descriptor.
|
||
|
/// @param stat_desc Statistics memory descriptors.
|
||
|
/// @param epsilon Layer normalization epsilon parameter.
|
||
|
/// @param flags Layer normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
|
||
|
const memory::desc &data_desc, const memory::desc &stat_desc,
|
||
|
float epsilon, normalization_flags flags) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_layer_normalization_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
|
||
|
&data_desc.data, &stat_desc.data, epsilon,
|
||
|
convert_to_c(flags)),
|
||
|
"could not create a descriptor for a batch normalization "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
|
||
|
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
|
||
|
/// set in @p flags
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_scale_and_shift`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)), if
|
||
|
/// #dnnl::normalization_flags::use_scale_shift bit-flag is set
|
||
|
/// in @p flags and @p prop_kind = #dnnl::prop_kind::backward
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
|
||
|
/// (diffs for all parameters are computed in this case).
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param data_desc Source memory descriptor.
|
||
|
/// @param epsilon Layer normalization epsilon parameter.
|
||
|
/// @param flags Layer normalization flags (@ref
|
||
|
/// dnnl::normalization_flags).
|
||
|
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
|
||
|
const memory::desc &data_desc, float epsilon,
|
||
|
normalization_flags flags) {
|
||
|
error::wrap_c_api(dnnl_layer_normalization_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
&diff_data_desc.data, &data_desc.data,
|
||
|
nullptr, epsilon, convert_to_c(flags)),
|
||
|
"could not create a descriptor for a batch normalization "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a layer normalization backward propagation
|
||
|
/// primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a layer normalization
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const layer_normalization_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a layer normalization
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const layer_normalization_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a layer normalization
|
||
|
/// backward propagation primitive from a C API primitive descriptor
|
||
|
/// that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a layer normalization
|
||
|
/// backward propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd,
|
||
|
dnnl::primitive::kind::layer_normalization,
|
||
|
dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
||
|
memory::desc diff_weights_desc() const {
|
||
|
return base::diff_weights_desc(0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
|
||
|
memory::desc mean_desc() const { return query_md(query::src_md, 1); }
|
||
|
|
||
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
|
||
|
memory::desc variance_desc() const {
|
||
|
return query_md(query::src_md, 2);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
layer_normalization_backward() = default;
|
||
|
|
||
|
/// Constructs a layer normalization backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a layer normalization backward
|
||
|
/// propagation primitive.
|
||
|
layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_layer_normalization
|
||
|
|
||
|
/// @addtogroup dnnl_api_inner_product Inner Product
|
||
|
///
|
||
|
/// A primitive to compute an inner product.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_inner_product in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Inner product forward propagation primitive.
|
||
|
struct inner_product_forward : public primitive {
|
||
|
/// Descriptor for an inner product forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_inner_product_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an inner product forward propagation
|
||
|
/// primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param src_desc Memory descriptor for src.
|
||
|
/// @param weights_desc Memory descriptor for diff weights.
|
||
|
/// @param bias_desc Memory descriptor for diff bias.
|
||
|
/// @param dst_desc Memory descriptor for diff dst.
|
||
|
desc(prop_kind prop_kind, const memory::desc &src_desc,
|
||
|
const memory::desc &weights_desc, const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_desc) {
|
||
|
error::wrap_c_api(dnnl_inner_product_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
&src_desc.data, &weights_desc.data,
|
||
|
&bias_desc.data, &dst_desc.data),
|
||
|
"could not create a descriptor for an inner product "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for an inner product forward propagation
|
||
|
/// primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param src_desc Memory descriptor for src.
|
||
|
/// @param weights_desc Memory descriptor for diff weights.
|
||
|
/// @param dst_desc Memory descriptor for dst.
|
||
|
desc(prop_kind prop_kind, const memory::desc &src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_inner_product_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind), &src_desc.data,
|
||
|
&weights_desc.data, nullptr, &dst_desc.data),
|
||
|
"could not create a descriptor for an inner product "
|
||
|
"forward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an inner product forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product forward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an inner product forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
|
||
|
memory::desc bias_desc() const { return base::weights_desc(1); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
inner_product_forward() = default;
|
||
|
|
||
|
/// Constructs an inner product forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an inner product forward
|
||
|
/// propagation primitive.
|
||
|
inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Inner product backward propagation primitive.
|
||
|
struct inner_product_backward_data : public primitive {
|
||
|
/// Descriptor for an inner product backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_inner_product_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an inner product backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param diff_src_desc Memory descriptor for diff src.
|
||
|
/// @param weights_desc Memory descriptor for weights.
|
||
|
/// @param diff_dst_desc Memory descriptor for diff dst.
|
||
|
desc(const memory::desc &diff_src_desc,
|
||
|
const memory::desc &weights_desc,
|
||
|
const memory::desc &diff_dst_desc) {
|
||
|
error::wrap_c_api(dnnl_inner_product_backward_data_desc_init(&data,
|
||
|
&diff_src_desc.data, &weights_desc.data,
|
||
|
&diff_dst_desc.data),
|
||
|
"could not create a descriptor for an inner product "
|
||
|
"backward propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an inner product backward propagation
|
||
|
/// primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an inner product
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const inner_product_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an inner product
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const inner_product_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an inner product backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
inner_product_backward_data() = default;
|
||
|
|
||
|
/// Constructs an inner product backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an inner product backward
|
||
|
/// propagation primitive.
|
||
|
inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Inner product weights gradient primitive.
|
||
|
struct inner_product_backward_weights : public primitive {
|
||
|
/// Descriptor for an inner product weights gradient primitive.
|
||
|
struct desc {
|
||
|
dnnl_inner_product_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an inner product descriptor weights
|
||
|
/// update primitive with bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param src_desc Memory descriptor for src.
|
||
|
/// @param diff_weights_desc Memory descriptor for diff weights.
|
||
|
/// @param diff_bias_desc Memory descriptor for diff bias.
|
||
|
/// @param diff_dst_desc Memory descriptor for diff dst.
|
||
|
desc(const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_desc) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_inner_product_backward_weights_desc_init(&data,
|
||
|
&src_desc.data, &diff_weights_desc.data,
|
||
|
&diff_bias_desc.data, &diff_dst_desc.data),
|
||
|
"could not create a descriptor for an inner product "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for an inner product descriptor weights
|
||
|
/// update primitive without bias.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param src_desc Memory descriptor for src.
|
||
|
/// @param diff_weights_desc Memory descriptor for diff weights.
|
||
|
/// @param diff_dst_desc Memory descriptor for diff dst.
|
||
|
desc(const memory::desc &src_desc,
|
||
|
const memory::desc &diff_weights_desc,
|
||
|
const memory::desc &diff_dst_desc) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_inner_product_backward_weights_desc_init(&data,
|
||
|
&src_desc.data, &diff_weights_desc.data, nullptr,
|
||
|
&diff_dst_desc.data),
|
||
|
"could not create a descriptor for an inner product "
|
||
|
"weights gradient primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an inner product weights gradient primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product weights
|
||
|
/// update primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product weights gradient
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an inner product
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const inner_product_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product weights
|
||
|
/// update primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an inner product weights gradient
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an inner product
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const inner_product_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an inner product weights
|
||
|
/// update primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an inner product weights
|
||
|
/// gradient primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
|
||
|
dnnl::prop_kind::backward_weights) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
||
|
memory::desc diff_weights_desc() const {
|
||
|
return base::diff_weights_desc(0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return base::diff_weights_desc(1);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
inner_product_backward_weights() = default;
|
||
|
|
||
|
/// Constructs an inner product weights gradient primitive.
|
||
|
/// @param pd Primitive descriptor for an inner product weights gradient
|
||
|
/// primitive.
|
||
|
inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_inner_product
|
||
|
|
||
|
/// @addtogroup dnnl_api_rnn RNN
|
||
|
///
|
||
|
/// A primitive to compute recurrent neural network layers.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_rnn in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Base class for primitive descriptors for RNN primitives.
|
||
|
struct rnn_primitive_desc_base : public primitive_desc {
|
||
|
using primitive_desc::primitive_desc;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
rnn_primitive_desc_base() = default;
|
||
|
|
||
|
/// Constructs an RNN primitive descriptor base from a C API primitive
|
||
|
/// descriptor while checking that it actually describes the expected
|
||
|
/// primitive by comparing propagation and primitive kinds.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor.
|
||
|
/// @param prop_kind Expected propagation kind.
|
||
|
/// @param cell_kind Expected cell kind.
|
||
|
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind prop_kind,
|
||
|
dnnl::algorithm cell_kind)
|
||
|
: rnn_primitive_desc_base(pd, prop_kind, prop_kind, cell_kind) {}
|
||
|
|
||
|
/// Returns source layer memory descriptor.
|
||
|
/// @returns Source layer memory descriptor.
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns source iteration memory descriptor.
|
||
|
/// @returns Source iteration memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// source iteration parameter.
|
||
|
memory::desc src_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns source recurrent cell state memory descriptor.
|
||
|
/// @returns Source recurrent cell state memory descriptor.
|
||
|
memory::desc src_iter_c_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
|
||
|
}
|
||
|
|
||
|
/// Returns weights layer memory descriptor.
|
||
|
/// @returns Weights layer memory descriptor.
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns weights iteration memory descriptor.
|
||
|
/// @returns Weights iteration memory descriptor.
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns weights peephole memory descriptor.
|
||
|
/// @returns Weights peephole memory descriptor.
|
||
|
memory::desc weights_peephole_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
|
||
|
}
|
||
|
|
||
|
/// Returns weights projection memory descriptor.
|
||
|
/// @returns Weights projection memory descriptor.
|
||
|
memory::desc weights_projection_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
|
||
|
}
|
||
|
|
||
|
/// Returns bias memory descriptor.
|
||
|
/// @returns Bias memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// bias parameter.
|
||
|
memory::desc bias_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
|
||
|
}
|
||
|
|
||
|
/// Returns destination layer memory descriptor.
|
||
|
/// @returns Destination layer memory descriptor.
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns destination iteration memory descriptor.
|
||
|
/// @returns Destination iteration memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// destination iteration parameter.
|
||
|
memory::desc dst_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns destination recurrent cell state memory descriptor.
|
||
|
/// @returns Destination recurrent cell state memory descriptor.
|
||
|
memory::desc dst_iter_c_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
|
||
|
}
|
||
|
|
||
|
/// Returns diff source layer memory descriptor.
|
||
|
/// @returns Diff source layer memory descriptor.
|
||
|
memory::desc diff_src_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff source iteration memory descriptor.
|
||
|
/// @returns Diff source iteration memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff source iteration parameter.
|
||
|
memory::desc diff_src_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff source recurrent cell state memory descriptor.
|
||
|
/// @returns Diff source recurrent cell state memory descriptor.
|
||
|
memory::desc diff_src_iter_c_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
|
||
|
}
|
||
|
|
||
|
/// Returns diff weights layer memory descriptor.
|
||
|
/// @returns Diff weights layer memory descriptor.
|
||
|
memory::desc diff_weights_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff weights iteration memory descriptor.
|
||
|
/// @returns Diff weights iteration memory descriptor.
|
||
|
memory::desc diff_weights_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff weights peephole memory descriptor.
|
||
|
/// @returns Diff weights peephole memory descriptor.
|
||
|
memory::desc diff_weights_peephole_desc() const {
|
||
|
return base::query_md(
|
||
|
query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
|
||
|
}
|
||
|
|
||
|
/// Returns diff weights projection memory descriptor.
|
||
|
/// @returns Diff weights projection memory descriptor.
|
||
|
memory::desc diff_weights_projection_desc() const {
|
||
|
return base::query_md(
|
||
|
query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
|
||
|
}
|
||
|
|
||
|
/// Returns diff bias memory descriptor.
|
||
|
/// @returns Diff bias memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff bias parameter.
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
|
||
|
}
|
||
|
|
||
|
/// Returns diff destination layer memory descriptor.
|
||
|
/// @returns Diff destination layer memory descriptor.
|
||
|
memory::desc diff_dst_layer_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff destination iteration memory descriptor.
|
||
|
/// @returns Diff destination iteration memory descriptor.
|
||
|
/// @returns A zero memory descriptor if the primitive does not have a
|
||
|
/// diff destination iteration parameter.
|
||
|
memory::desc diff_dst_iter_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
|
||
|
}
|
||
|
|
||
|
/// Returns diff destination recurrent cell state memory descriptor.
|
||
|
/// @returns Diff destination recurrent cell state memory descriptor.
|
||
|
memory::desc diff_dst_iter_c_desc() const {
|
||
|
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
|
||
|
}
|
||
|
|
||
|
protected:
|
||
|
using rnn_base = rnn_primitive_desc_base;
|
||
|
|
||
|
// (Deliberately not using doxygen comments)
|
||
|
//
|
||
|
// Constructs an RNN primitive descriptor base from a C API primitive
|
||
|
// descriptor while checking that it actually describes the expected
|
||
|
// primitive by comparing propagation and primitive kinds. Caller can
|
||
|
// pass two options propagation kinds. This is typically used to check
|
||
|
// that propagation kind is inference or training forward propagation.
|
||
|
//
|
||
|
// @param pd C API primitive descriptor.
|
||
|
// @param prop_kind1 Expected propagation kind.
|
||
|
// @param prop_kind2 Expected propagation kind.
|
||
|
// @param cell_kind Expected cell kind.
|
||
|
rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
|
||
|
dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
|
||
|
dnnl::algorithm cell_kind) {
|
||
|
dnnl_rnn_desc_t *rnn_d;
|
||
|
dnnl_status_t rc;
|
||
|
rc = dnnl_primitive_desc_query(pd, dnnl_query_rnn_d, 0, &rnn_d);
|
||
|
error::wrap_c_api(rc,
|
||
|
"could not retrieve a descriptor from a primitive descriptor "
|
||
|
"for an RNN primitive");
|
||
|
|
||
|
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
|
||
|
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
|
||
|
dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
|
||
|
|
||
|
bool ok = rnn_d->primitive_kind == dnnl_rnn
|
||
|
&& (rnn_d->prop_kind == c_prop_kind1
|
||
|
|| rnn_d->prop_kind == c_prop_kind2)
|
||
|
&& rnn_d->cell_kind == c_cell_kind;
|
||
|
|
||
|
if (!ok)
|
||
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
||
|
"mismatch between expected and provided descriptors for an "
|
||
|
"RNN primitive");
|
||
|
|
||
|
reset_with_clone(pd);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Vanilla RNN forward propagation primitive.
|
||
|
struct vanilla_rnn_forward : public primitive {
|
||
|
/// Descriptor for a vanilla RNN forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a vanilla RNN forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter_desc may point
|
||
|
/// to a zero memory descriptor. This would then indicate that the RNN
|
||
|
/// forward propagation primitive should not use them and should
|
||
|
/// default to zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a corresponding
|
||
|
/// primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors except @p src_iter_desc can be
|
||
|
/// initialized with an #dnnl::memory::format_tag::any value of @p
|
||
|
/// format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param activation Activation kind. Possible values are
|
||
|
/// #dnnl::algorithm::eltwise_relu,
|
||
|
/// #dnnl::algorithm::eltwise_tanh, or
|
||
|
/// #dnnl::algorithm::eltwise_logistic.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
/// @param alpha Negative slope if activation is
|
||
|
/// #dnnl::algorithm::eltwise_relu.
|
||
|
/// @param beta Unused.
|
||
|
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
|
||
|
float beta = 0.0f) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_vanilla_rnn_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(activation),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags), alpha, beta),
|
||
|
"could not create a descriptor for a vanilla RNN forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a vanilla RNN forward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a vanilla RNN forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a vanilla RNN forward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a vanilla RNN forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference,
|
||
|
dnnl::algorithm::vanilla_rnn) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
vanilla_rnn_forward() = default;
|
||
|
|
||
|
/// Constructs a vanilla RNN forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a vanilla RNN forward
|
||
|
/// propagation primitive.
|
||
|
vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Vanilla RNN backward propagation primitive.
|
||
|
struct vanilla_rnn_backward : public primitive {
|
||
|
/// Vanilla RNN descriptor backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a vanilla RNN backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
|
||
|
/// bias_desc together with @p diff_bias_desc, and @p dst_iter_desc
|
||
|
/// together with @p diff_src_iter_desc, may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the RNN backward
|
||
|
/// propagation primitive should not use the respective data and
|
||
|
/// should use zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `diff_dst_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_src_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
|
||
|
/// - `diff_weights_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_weights_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
/// - `diff_bias`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// @note
|
||
|
/// All the memory descriptors may be initialized with the
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param activation Activation kind. Possible values are
|
||
|
/// #dnnl::algorithm::eltwise_relu,
|
||
|
/// #dnnl::algorithm::eltwise_tanh, or
|
||
|
/// #dnnl::algorithm::eltwise_logistic.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
/// @param alpha Negative slope if activation is
|
||
|
/// #dnnl::algorithm::eltwise_relu.
|
||
|
/// @param beta Unused.
|
||
|
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &diff_src_layer_desc,
|
||
|
const memory::desc &diff_src_iter_desc,
|
||
|
const memory::desc &diff_weights_layer_desc,
|
||
|
const memory::desc &diff_weights_iter_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_layer_desc,
|
||
|
const memory::desc &diff_dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
|
||
|
float beta = 0.0f) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_vanilla_rnn_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(activation),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
|
||
|
&diff_weights_layer_desc.data,
|
||
|
&diff_weights_iter_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags), alpha, beta),
|
||
|
"could not create a descriptor for a vanilla RNN backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a RNN backward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a vanilla RNN backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a vanilla RNN backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, &attr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a vanilla RNN backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
|
||
|
dnnl::algorithm::vanilla_rnn) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
|
||
|
memory::desc diff_src_layer_desc() const {
|
||
|
return rnn_base::diff_src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
|
||
|
memory::desc diff_src_iter_desc() const {
|
||
|
return rnn_base::diff_src_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
|
||
|
memory::desc diff_weights_layer_desc() const {
|
||
|
return rnn_base::diff_weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
|
||
|
memory::desc diff_weights_iter_desc() const {
|
||
|
return rnn_base::diff_weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return rnn_base::diff_bias_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
|
||
|
memory::desc diff_dst_layer_desc() const {
|
||
|
return rnn_base::diff_dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
|
||
|
memory::desc diff_dst_iter_desc() const {
|
||
|
return rnn_base::diff_dst_iter_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
vanilla_rnn_backward() = default;
|
||
|
|
||
|
/// Constructs a vanilla RNN backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a vanilla RNN backward
|
||
|
/// propagation primitive.
|
||
|
vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// LSTM forward propagation primitive.
|
||
|
struct lstm_forward : public primitive {
|
||
|
/// Descriptor for an LSTM forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for an LSTM (with or without peephole and
|
||
|
/// with or without projection) forward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p src_iter_c_desc, @p weights_peephole_desc,
|
||
|
/// @p bias_desc, @p dst_iter_desc, and @p dst_iter_c_desc may point to
|
||
|
/// a zero memory descriptor. This would then indicate that the LSTM
|
||
|
/// forward propagation primitive should not use them and should
|
||
|
/// default to zero values instead.
|
||
|
///
|
||
|
/// The @p weights_projection_desc may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the LSTM doesn't have
|
||
|
/// recurrent projection layer.
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors can be initialized with an
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - src_layer (#dnnl::primitive_desc_base::src_desc (0))
|
||
|
/// - src_iter (#dnnl::primitive_desc_base::src_desc (1)), if used
|
||
|
/// - src_iter_c (#dnnl::primitive_desc_base::src_desc (2)), if used
|
||
|
/// - weights_layer (#dnnl::primitive_desc_base::weights_desc (0))
|
||
|
/// - weights_iter (#dnnl::primitive_desc_base::weights_desc (1))
|
||
|
/// - weights_peephole (#dnnl::primitive_desc_base::weights_desc (2)),
|
||
|
/// if used
|
||
|
/// - weights_projection
|
||
|
/// (#dnnl::primitive_desc_base::weights_desc (index)), if used and
|
||
|
/// index is:
|
||
|
/// - 2, if there is no weights_peephole
|
||
|
/// - 3, otherwise
|
||
|
/// - bias (#dnnl::primitive_desc_base::weights_desc (index)), if used
|
||
|
/// and index is:
|
||
|
/// - 2, if neither weights_peephole nor weights_projection is used
|
||
|
/// - 3, if one of weights_peephole or weights_projection is used
|
||
|
/// - 4, if both weights_peephole and weights_projection are used
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - dst_layer (#dnnl::primitive_desc_base::dst_desc (0))
|
||
|
/// - dst_iter (#dnnl::primitive_desc_base::dst_desc (1)), if used
|
||
|
/// - dst_iter_c (#dnnl::primitive_desc_base::dst_desc (2)), if used
|
||
|
/// - workspace (#dnnl::primitive_desc_base::workspace_desc (0)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a corresponding
|
||
|
/// primitive descriptor is created
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
||
|
/// applied to the cell states (according to the Peephole LSTM
|
||
|
/// formula).
|
||
|
/// @param weights_projection_desc Memory descriptor for the weights
|
||
|
/// applied to the hidden states to get the recurrent projection
|
||
|
/// (according to the Projection LSTM formula).
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &src_iter_c_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &weights_peephole_desc,
|
||
|
const memory::desc &weights_projection_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &dst_iter_c_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lstm_forward_desc_init_v3(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &src_iter_c_desc.data,
|
||
|
&weights_layer_desc.data, &weights_iter_desc.data,
|
||
|
&weights_peephole_desc.data,
|
||
|
&weights_projection_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LSTM forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for an LSTM (with or without peephole)
|
||
|
/// forward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p src_iter_c_desc, @p weights_peephole_desc,
|
||
|
/// @p bias_desc, @p dst_iter_desc, and @p dst_iter_c_desc may point to
|
||
|
/// a zero memory descriptor. This would then indicate that the LSTM
|
||
|
/// forward propagation primitive should not use them and should
|
||
|
/// default to zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `weights_peephole` (#dnnl::primitive_desc_base::weights_desc(`2`)),
|
||
|
/// if used
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used and
|
||
|
/// LSTM is without peephole
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`3`)), if used and
|
||
|
/// LSTM is with peephole
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a corresponding
|
||
|
/// primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors can be initialized with an
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
||
|
/// applied to the cell states (according to the Peephole LSTM
|
||
|
/// formula).
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &src_iter_c_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &weights_peephole_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &dst_iter_c_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lstm_forward_desc_init_v2(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &src_iter_c_desc.data,
|
||
|
&weights_layer_desc.data, &weights_iter_desc.data,
|
||
|
&weights_peephole_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LSTM forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for an LSTM forward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p src_iter_c_desc, @p bias_desc, @p
|
||
|
/// dst_iter_desc, and @p dst_iter_c_desc may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the LSTM forward
|
||
|
/// propagation primitive should not use them and should default to
|
||
|
/// zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a
|
||
|
/// corresponding primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors can be initialized with an
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &src_iter_c_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &dst_iter_c_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lstm_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &src_iter_c_desc.data,
|
||
|
&weights_layer_desc.data, &weights_iter_desc.data,
|
||
|
&bias_desc.data, &dst_layer_desc.data,
|
||
|
&dst_iter_desc.data, &dst_iter_c_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LSTM forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an LSTM forward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LSTM forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LSTM forward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM forward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an LSTM forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference,
|
||
|
dnnl::algorithm::vanilla_lstm) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_c_desc() const {
|
||
|
return rnn_base::src_iter_c_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
|
||
|
memory::desc weights_peephole_desc() const {
|
||
|
return rnn_base::weights_peephole_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
|
||
|
memory::desc weights_projection_desc() const {
|
||
|
return rnn_base::weights_projection_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc dst_iter_c_desc() const {
|
||
|
return rnn_base::dst_iter_c_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lstm_forward() = default;
|
||
|
|
||
|
/// Constructs an LSTM forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LSTM forward propagation
|
||
|
/// primitive.
|
||
|
lstm_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// LSTM backward propagation primitive.
|
||
|
struct lstm_backward : public primitive {
|
||
|
/// Descriptor for an LSTM backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs an LSTM (with or without peephole and with or without
|
||
|
/// projection) descriptor for backward propagation using @p prop_kind,
|
||
|
/// @p direction, and memory descriptors.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_iter_desc, @p
|
||
|
/// src_iter_c_desc together with @p src_iter_c_desc, @p
|
||
|
/// weights_peephole_desc together with @p diff_weights_peephole_desc,
|
||
|
/// @p bias_desc together with @p diff_bias_desc, @p dst_iter_desc
|
||
|
/// together with @p diff_dst_iter_desc, and @p dst_iter_c_desc
|
||
|
/// together with @p diff_dst_iter_c_desc, may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the LSTM backward
|
||
|
/// propagation primitive should not use them and should default to
|
||
|
/// zero values instead.
|
||
|
///
|
||
|
/// The @p weights_projection_desc together with @p
|
||
|
/// diff_weights_projection_desc may point to a zero memory descriptor.
|
||
|
/// This would then indicate that the LSTM doesn't have recurrent
|
||
|
/// projection layer.
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors can be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - src_layer (#dnnl::primitive_desc_base::src_desc (0))
|
||
|
/// - src_iter (#dnnl::primitive_desc_base::src_desc (1)), if used
|
||
|
/// - src_iter_c (#dnnl::primitive_desc_base::src_desc (2)), if used
|
||
|
/// - weights_layer (#dnnl::primitive_desc_base::weights_desc (0))
|
||
|
/// - weights_iter (#dnnl::primitive_desc_base::weights_desc (1))
|
||
|
/// - weights_peephole (#dnnl::primitive_desc_base::weights_desc (2)),
|
||
|
/// if used
|
||
|
/// - weights_projection
|
||
|
/// (#dnnl::primitive_desc_base::weights_desc (index)), if used and
|
||
|
/// index is:
|
||
|
/// - 2, if there is no weights_peephole
|
||
|
/// - 3, otherwise
|
||
|
/// - bias (#dnnl::primitive_desc_base::weights_desc (index)), if used
|
||
|
/// and index is:
|
||
|
/// - 2, if neither weights_peephole nor weights_projection is used
|
||
|
/// - 3, if one of weights_peephole or weights_projection is used
|
||
|
/// - 4, if both weights_peephole and weights_projection are used
|
||
|
/// - dst_layer (#dnnl::primitive_desc_base::dst_desc (0))
|
||
|
/// - dst_iter (#dnnl::primitive_desc_base::dst_desc (1)), if used
|
||
|
/// - dst_iter_c (#dnnl::primitive_desc_base::dst_desc (2)), if used
|
||
|
/// - diff_dst_layer (#dnnl::primitive_desc_base::diff_dst_desc (0))
|
||
|
/// - diff_dst_iter
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc (1)), if used
|
||
|
/// - diff_dst_iter_c
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc (2)), if used
|
||
|
/// - workspace (#dnnl::primitive_desc_base::workspace_desc (0))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - diff_src_layer (#dnnl::primitive_desc_base::diff_src_desc (0))
|
||
|
/// - diff_src_iter
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc (1)), if used
|
||
|
/// - diff_src_iter_c
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc (2)), if used
|
||
|
/// - diff_weights_layer
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc (0))
|
||
|
/// - diff_weights_iter
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc (1))
|
||
|
/// - diff_weights_peephole
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc (2)), if used
|
||
|
/// - diff_weights_projection
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc (index)), if used
|
||
|
/// and index is:
|
||
|
/// - 2, if there is no diff_weights_peephole
|
||
|
/// - 3, otherwise
|
||
|
/// - diff_bias
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc (index)), if used
|
||
|
/// and index is:
|
||
|
/// - 2, if neither diff_weights_peephole nor
|
||
|
/// diff_weights_projection is used
|
||
|
/// - 3, if one of diff_weights_peephole or diff_weights_projection
|
||
|
/// is used
|
||
|
/// - 4, if both diff_weights_peephole and diff_weights_projection
|
||
|
/// are used
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
||
|
/// applied to the cell states (according to the Peephole LSTM
|
||
|
/// formula).
|
||
|
/// @param weights_projection_desc Memory descriptor for the weights
|
||
|
/// applied to the hidden states to get the recurrent projection
|
||
|
/// (according to the Projection LSTM formula).
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
|
||
|
/// input recurrent cell state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_weights_peephole_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the cell states (according to the Peephole
|
||
|
/// LSTM formula).
|
||
|
/// @param diff_weights_projection_desc Memory descriptor for the diff
|
||
|
/// of weights applied to the hidden states to get the recurrent
|
||
|
/// projection (according to the Projection LSTM formula).
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
|
||
|
/// output recurrent cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &src_iter_c_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &weights_peephole_desc,
|
||
|
const memory::desc &weights_projection_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &dst_iter_c_desc,
|
||
|
const memory::desc &diff_src_layer_desc,
|
||
|
const memory::desc &diff_src_iter_desc,
|
||
|
const memory::desc &diff_src_iter_c_desc,
|
||
|
const memory::desc &diff_weights_layer_desc,
|
||
|
const memory::desc &diff_weights_iter_desc,
|
||
|
const memory::desc &diff_weights_peephole_desc,
|
||
|
const memory::desc &diff_weights_projection_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_layer_desc,
|
||
|
const memory::desc &diff_dst_iter_desc,
|
||
|
const memory::desc &diff_dst_iter_c_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lstm_backward_desc_init_v3(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &src_iter_c_desc.data,
|
||
|
&weights_layer_desc.data, &weights_iter_desc.data,
|
||
|
&weights_peephole_desc.data,
|
||
|
&weights_projection_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&dst_iter_c_desc.data, &diff_src_layer_desc.data,
|
||
|
&diff_src_iter_desc.data,
|
||
|
&diff_src_iter_c_desc.data,
|
||
|
&diff_weights_layer_desc.data,
|
||
|
&diff_weights_iter_desc.data,
|
||
|
&diff_weights_peephole_desc.data,
|
||
|
&diff_weights_projection_desc.data,
|
||
|
&diff_bias_desc.data, &diff_dst_layer_desc.data,
|
||
|
&diff_dst_iter_desc.data,
|
||
|
&diff_dst_iter_c_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LSTM backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs an LSTM (with or without peephole) descriptor for
|
||
|
/// backward propagation using @p prop_kind, @p direction, and memory
|
||
|
/// descriptors.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
|
||
|
/// src_iter_c_desc together with @p diff_src_iter_c_desc, @p
|
||
|
/// weights_peephole_desc together with @p diff_weights_peephole_desc,
|
||
|
/// @p bias_desc together with @p diff_bias_desc, @p dst_iter_desc
|
||
|
/// together with @p diff_dst_iter_desc, and @p dst_iter_c_desc
|
||
|
/// together with @p diff_dst_iter_c_desc, may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the LSTM backward
|
||
|
/// propagation primitive should not use them and should default to
|
||
|
/// zero values instead.
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `weights_peephole` (#dnnl::primitive_desc_base::weights_desc(`2`)),
|
||
|
/// if used
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used and
|
||
|
/// LSTM is without peephole
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`3`)), if used and
|
||
|
/// LSTM is with peephole
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
|
||
|
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `diff_dst_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
|
||
|
/// - `diff_dst_iter_c`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`2`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_src_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
|
||
|
/// - `diff_src_iter_c`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`2`)), if used
|
||
|
/// - `diff_weights_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_weights_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
/// - `diff_weights_peephole`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`2`)),
|
||
|
/// if used and LSTM is without peephole
|
||
|
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`3`)),
|
||
|
/// if used and LSTM is with peephole
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
||
|
/// applied to the cell states (according to the Peephole LSTM
|
||
|
/// formula).
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
|
||
|
/// input recurrent cell state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_weights_peephole_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the cell states (according to the Peephole
|
||
|
/// LSTM formula).
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
|
||
|
/// output recurrent cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
        const memory::desc &src_layer_desc,
        const memory::desc &src_iter_desc,
        const memory::desc &src_iter_c_desc,
        const memory::desc &weights_layer_desc,
        const memory::desc &weights_iter_desc,
        const memory::desc &weights_peephole_desc,
        const memory::desc &bias_desc,
        const memory::desc &dst_layer_desc,
        const memory::desc &dst_iter_desc,
        const memory::desc &dst_iter_c_desc,
        const memory::desc &diff_src_layer_desc,
        const memory::desc &diff_src_iter_desc,
        const memory::desc &diff_src_iter_c_desc,
        const memory::desc &diff_weights_layer_desc,
        const memory::desc &diff_weights_iter_desc,
        const memory::desc &diff_weights_peephole_desc,
        const memory::desc &diff_bias_desc,
        const memory::desc &diff_dst_layer_desc,
        const memory::desc &diff_dst_iter_desc,
        const memory::desc &diff_dst_iter_c_desc,
        rnn_flags flags = rnn_flags::undef) {
    // Hand every C descriptor to the v2 (peephole-aware) C initializer in
    // the order it expects: forward tensors first (src, weights incl.
    // peephole, bias, dst), then their diff_* counterparts, then flags.
    const dnnl_status_t status = dnnl_lstm_backward_desc_init_v2(&data,
            dnnl::convert_to_c(prop_kind), dnnl::convert_to_c(direction),
            &src_layer_desc.data, &src_iter_desc.data,
            &src_iter_c_desc.data, &weights_layer_desc.data,
            &weights_iter_desc.data, &weights_peephole_desc.data,
            &bias_desc.data, &dst_layer_desc.data, &dst_iter_desc.data,
            &dst_iter_c_desc.data, &diff_src_layer_desc.data,
            &diff_src_iter_desc.data, &diff_src_iter_c_desc.data,
            &diff_weights_layer_desc.data, &diff_weights_iter_desc.data,
            &diff_weights_peephole_desc.data, &diff_bias_desc.data,
            &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
            &diff_dst_iter_c_desc.data, dnnl::convert_to_c(flags));
    // Translate a non-success status into the C++ API's error reporting.
    error::wrap_c_api(status,
            "could not create a descriptor for an LSTM backward "
            "propagation primitive");
}
|
||
|
|
||
|
/// Constructs an LSTM descriptor for backward propagation using @p
|
||
|
/// prop_kind, @p direction, and memory descriptors.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_iter_desc, @p
|
||
|
/// src_iter_c_desc together with @p src_iter_c_desc, @p bias_desc
|
||
|
/// together with @p diff_bias_desc, @p dst_iter_desc together with @p
|
||
|
/// diff_dst_iter_desc, and @p dst_iter_c_desc together with @p
|
||
|
/// diff_dst_iter_c_desc, may point to a zero memory descriptor. This
|
||
|
/// would then indicate that the LSTM backward propagation primitive
|
||
|
/// should not use them and should default to zero values instead.
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
|
||
|
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `diff_dst_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
|
||
|
/// - `diff_dst_iter_c`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`2`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_src_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
|
||
|
/// - `diff_src_iter_c`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`2`)), if used
|
||
|
/// - `diff_weights_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_weights_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
/// - `diff_bias`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
||
|
/// cell state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
|
||
|
/// input recurrent cell state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
|
||
|
/// output recurrent cell state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &src_iter_c_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &dst_iter_c_desc,
|
||
|
const memory::desc &diff_src_layer_desc,
|
||
|
const memory::desc &diff_src_iter_desc,
|
||
|
const memory::desc &diff_src_iter_c_desc,
|
||
|
const memory::desc &diff_weights_layer_desc,
|
||
|
const memory::desc &diff_weights_iter_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_layer_desc,
|
||
|
const memory::desc &diff_dst_iter_desc,
|
||
|
const memory::desc &diff_dst_iter_c_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lstm_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &src_iter_c_desc.data,
|
||
|
&weights_layer_desc.data, &weights_iter_desc.data,
|
||
|
&bias_desc.data, &dst_layer_desc.data,
|
||
|
&dst_iter_desc.data, &dst_iter_c_desc.data,
|
||
|
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
|
||
|
&diff_src_iter_c_desc.data,
|
||
|
&diff_weights_layer_desc.data,
|
||
|
&diff_weights_iter_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
|
||
|
&diff_dst_iter_c_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LSTM backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for LSTM backward propagation.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for LSTM backward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LSTM
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const lstm_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LSTM backward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LSTM
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const lstm_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, &attr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LSTM backward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for an LSTM backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
|
||
|
dnnl::algorithm::vanilla_lstm) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_c_desc() const {
|
||
|
return rnn_base::src_iter_c_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
|
||
|
memory::desc weights_peephole_desc() const {
|
||
|
return rnn_base::weights_peephole_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
|
||
|
memory::desc weights_projection_desc() const {
|
||
|
return rnn_base::weights_projection_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc dst_iter_c_desc() const {
|
||
|
return rnn_base::dst_iter_c_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
|
||
|
memory::desc diff_src_layer_desc() const {
|
||
|
return rnn_base::diff_src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
|
||
|
memory::desc diff_src_iter_desc() const {
|
||
|
return rnn_base::diff_src_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
|
||
|
memory::desc diff_src_iter_c_desc() const {
|
||
|
return rnn_base::diff_src_iter_c_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
|
||
|
memory::desc diff_weights_layer_desc() const {
|
||
|
return rnn_base::diff_weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
|
||
|
memory::desc diff_weights_iter_desc() const {
|
||
|
return rnn_base::diff_weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
|
||
|
memory::desc diff_weights_peephole_desc() const {
|
||
|
return rnn_base::diff_weights_peephole_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
|
||
|
memory::desc diff_weights_projection_desc() const {
|
||
|
return rnn_base::diff_weights_projection_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return rnn_base::diff_bias_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
|
||
|
memory::desc diff_dst_layer_desc() const {
|
||
|
return rnn_base::diff_dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
|
||
|
memory::desc diff_dst_iter_desc() const {
|
||
|
return rnn_base::diff_dst_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
|
||
|
memory::desc diff_dst_iter_c_desc() const {
|
||
|
return rnn_base::diff_dst_iter_c_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lstm_backward() = default;
|
||
|
|
||
|
/// Constructs an LSTM backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LSTM backward propagation
|
||
|
/// primitive.
|
||
|
lstm_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// GRU forward propagation primitive.
|
||
|
struct gru_forward : public primitive {
|
||
|
/// Descriptor for a GRU forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a GRU forward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter, may point to
|
||
|
/// a zero memory descriptor. This would then indicate that the GRU
|
||
|
/// forward propagation primitive should not use them and should
|
||
|
/// default to zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a corresponding
|
||
|
/// primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors except @p src_iter_desc may be
|
||
|
/// initialized with an #dnnl::memory::format_tag::any value of @p
|
||
|
/// format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_gru_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for a GRU forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor GRU forward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a GRU forward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a GRU forward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU forward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a GRU forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference,
|
||
|
dnnl::algorithm::vanilla_gru) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
gru_forward() = default;
|
||
|
|
||
|
/// Constructs a GRU forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a GRU forward propagation
|
||
|
/// primitive.
|
||
|
gru_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// GRU backward propagation primitive.
|
||
|
struct gru_backward : public primitive {
|
||
|
/// Descriptor for a GRU backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a GRU backward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
|
||
|
/// bias_desc together with @p diff_bias_desc, and @p dst_iter
|
||
|
/// together with @p diff_dst_iter, may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the GRU backward
|
||
|
/// propagation primitive should not use them and should default to
|
||
|
/// zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `diff_dst_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_src_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
|
||
|
/// - `diff_weights_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_weights_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
/// - `diff_bias`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &diff_src_layer_desc,
|
||
|
const memory::desc &diff_src_iter_desc,
|
||
|
const memory::desc &diff_weights_layer_desc,
|
||
|
const memory::desc &diff_weights_iter_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_layer_desc,
|
||
|
const memory::desc &diff_dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_gru_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
|
||
|
&diff_weights_layer_desc.data,
|
||
|
&diff_weights_iter_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for a GRU backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a GRU backward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a GRU backward propagation primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a GRU
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const gru_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a GRU backward propagation primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a GRU
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const gru_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, &attr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a GRU backward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a GRU backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
|
||
|
dnnl::algorithm::vanilla_gru) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
|
||
|
memory::desc diff_src_layer_desc() const {
|
||
|
return rnn_base::diff_src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
|
||
|
memory::desc diff_src_iter_desc() const {
|
||
|
return rnn_base::diff_src_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
|
||
|
memory::desc diff_weights_layer_desc() const {
|
||
|
return rnn_base::diff_weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
|
||
|
memory::desc diff_weights_iter_desc() const {
|
||
|
return rnn_base::diff_weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return rnn_base::diff_bias_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
|
||
|
memory::desc diff_dst_layer_desc() const {
|
||
|
return rnn_base::diff_dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
|
||
|
memory::desc diff_dst_iter_desc() const {
|
||
|
return rnn_base::diff_dst_iter_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
gru_backward() = default;
|
||
|
|
||
|
/// Constructs a GRU backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a GRU backward propagation
|
||
|
/// primitive.
|
||
|
gru_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// LBR GRU forward propagation primitive.
|
||
|
struct lbr_gru_forward : public primitive {
|
||
|
/// Descriptor for an LBR GRU forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for LBR GRU forward propagation primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter, may point to
|
||
|
/// a zero memory descriptor. This would then indicate that the LBR
|
||
|
/// GRU forward propagation primitive should not use them and should
|
||
|
/// default to zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
|
||
|
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
|
||
|
/// must be queried for using @ref
|
||
|
/// dnnl::primitive_desc_base::query_md() after a corresponding
|
||
|
/// primitive descriptor is created
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors except @p src_iter_desc may be
|
||
|
/// initialized with an #dnnl::memory::format_tag::any value of @p
|
||
|
/// format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lbr_gru_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LBR GRU forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an LBR GRU forward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a LBR GRU forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a LBR GRU forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a LBR GRU forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a LBR GRU forward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a LBR GRU forward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a LBR GRU forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference,
|
||
|
dnnl::algorithm::lbr_gru) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lbr_gru_forward() = default;
|
||
|
|
||
|
/// Constructs an LBR GRU forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LBR GRU forward propagation
|
||
|
/// primitive.
|
||
|
lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// LBR GRU backward propagation primitive.
|
||
|
struct lbr_gru_backward : public primitive {
|
||
|
/// Descriptor for a LBR GRU backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_rnn_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for LBR GRU backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
|
||
|
/// bias_desc together with @p diff_bias_desc, and @p dst_iter
|
||
|
/// together with @p diff_dst_iter, may point to a zero memory
|
||
|
/// descriptor. This would then indicate that the LBR GRU backward
|
||
|
/// propagation primitive should not use them and should default to
|
||
|
/// zero values instead.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
|
||
|
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
|
||
|
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
|
||
|
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
/// - `diff_dst_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
|
||
|
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
/// - `diff_src_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
|
||
|
/// - `diff_weights_layer`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
|
||
|
/// - `diff_weights_iter`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
|
||
|
/// - `diff_bias`
|
||
|
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
|
||
|
///
|
||
|
/// @note
|
||
|
/// All memory descriptors may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Must be
|
||
|
/// #dnnl::prop_kind::backward.
|
||
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
||
|
/// more info.
|
||
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
||
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param weights_layer_desc Memory descriptor for the weights
|
||
|
/// applied to the layer input.
|
||
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
||
|
/// to the recurrent input.
|
||
|
/// @param bias_desc Bias memory descriptor.
|
||
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
||
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
||
|
/// hidden state vector.
|
||
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
||
|
/// vector.
|
||
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the layer input.
|
||
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
||
|
/// weights applied to the recurrent input.
|
||
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
||
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
||
|
/// output vector.
|
||
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
||
|
/// recurrent hidden state vector.
|
||
|
/// @param flags Unused.
|
||
|
desc(prop_kind prop_kind, rnn_direction direction,
|
||
|
const memory::desc &src_layer_desc,
|
||
|
const memory::desc &src_iter_desc,
|
||
|
const memory::desc &weights_layer_desc,
|
||
|
const memory::desc &weights_iter_desc,
|
||
|
const memory::desc &bias_desc,
|
||
|
const memory::desc &dst_layer_desc,
|
||
|
const memory::desc &dst_iter_desc,
|
||
|
const memory::desc &diff_src_layer_desc,
|
||
|
const memory::desc &diff_src_iter_desc,
|
||
|
const memory::desc &diff_weights_layer_desc,
|
||
|
const memory::desc &diff_weights_iter_desc,
|
||
|
const memory::desc &diff_bias_desc,
|
||
|
const memory::desc &diff_dst_layer_desc,
|
||
|
const memory::desc &diff_dst_iter_desc,
|
||
|
rnn_flags flags = rnn_flags::undef) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_lbr_gru_backward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
dnnl::convert_to_c(direction), &src_layer_desc.data,
|
||
|
&src_iter_desc.data, &weights_layer_desc.data,
|
||
|
&weights_iter_desc.data, &bias_desc.data,
|
||
|
&dst_layer_desc.data, &dst_iter_desc.data,
|
||
|
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
|
||
|
&diff_weights_layer_desc.data,
|
||
|
&diff_weights_iter_desc.data, &diff_bias_desc.data,
|
||
|
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
|
||
|
dnnl::convert_to_c(flags)),
|
||
|
"could not create a descriptor for an LBR GRU backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an LBR GRU backward propagation primitive.
|
||
|
struct primitive_desc : public rnn_primitive_desc_base {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LBR GRU backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LBR GRU backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LBR GRU
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const lbr_gru_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an LBR GRU backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an LBR GRU backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for an LBR GRU
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const lbr_gru_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: rnn_primitive_desc_base(&desc.data, &attr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a LBR GRU backward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a LBR GRU backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: rnn_primitive_desc_base(
|
||
|
pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
||
|
memory::desc src_layer_desc() const {
|
||
|
return rnn_base::src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
||
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
||
|
memory::desc weights_layer_desc() const {
|
||
|
return rnn_base::weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
||
|
memory::desc weights_iter_desc() const {
|
||
|
return rnn_base::weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
||
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
||
|
memory::desc dst_layer_desc() const {
|
||
|
return rnn_base::dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
||
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
||
|
memory::desc workspace_desc() const {
|
||
|
return rnn_base::workspace_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
|
||
|
memory::desc diff_src_layer_desc() const {
|
||
|
return rnn_base::diff_src_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
|
||
|
memory::desc diff_src_iter_desc() const {
|
||
|
return rnn_base::diff_src_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
|
||
|
memory::desc diff_weights_layer_desc() const {
|
||
|
return rnn_base::diff_weights_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
|
||
|
memory::desc diff_weights_iter_desc() const {
|
||
|
return rnn_base::diff_weights_iter_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
|
||
|
memory::desc diff_bias_desc() const {
|
||
|
return rnn_base::diff_bias_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
|
||
|
memory::desc diff_dst_layer_desc() const {
|
||
|
return rnn_base::diff_dst_layer_desc();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
|
||
|
memory::desc diff_dst_iter_desc() const {
|
||
|
return rnn_base::diff_dst_iter_desc();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
lbr_gru_backward() = default;
|
||
|
|
||
|
/// Constructs an LBR GRU backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for an LBR GRU backward propagation
|
||
|
/// primitive.
|
||
|
lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_rnn
|
||
|
|
||
|
/// @addtogroup dnnl_api_shuffle Shuffle
|
||
|
///
|
||
|
/// A primitive to shuffle tensor data along an axis.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_shuffle in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Shuffle forward propagation primitive.
|
||
|
struct shuffle_forward : public primitive {
|
||
|
/// Descriptor for a shuffle forward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_shuffle_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a shuffle forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param data_desc Source and destination memory descriptor.
|
||
|
/// @param axis The axis along which the data is shuffled.
|
||
|
/// @param group_size Shuffle group size.
|
||
|
desc(prop_kind prop_kind, const memory::desc &data_desc, int axis,
|
||
|
int group_size) {
|
||
|
error::wrap_c_api(dnnl_shuffle_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
&data_desc.data, axis, group_size),
|
||
|
"could not create a descriptor for a shuffle forward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a shuffle forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a shuffle forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a shuffle forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const primitive_attr &attr = primitive_attr(),
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a shuffle forward propagation
|
||
|
/// primitive from a C API primitive descriptor that must have a
|
||
|
/// matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a shuffle forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
shuffle_forward() = default;
|
||
|
|
||
|
/// Constructs a shuffle forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a shuffle forward propagation
|
||
|
/// primitive.
|
||
|
shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Shuffle backward propagation primitive.
|
||
|
struct shuffle_backward : public primitive {
|
||
|
/// Descriptor for a shuffle primitive backward propagation
|
||
|
/// primitive.
|
||
|
struct desc {
|
||
|
dnnl_shuffle_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a shuffle backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param diff_data_desc Diff source and diff destination memory
|
||
|
/// descriptor.
|
||
|
/// @param axis The axis along which the data is shuffled.
|
||
|
/// @param group_size Shuffle group size.
|
||
|
desc(const memory::desc &diff_data_desc, int axis, int group_size) {
|
||
|
error::wrap_c_api(dnnl_shuffle_backward_desc_init(&data,
|
||
|
&diff_data_desc.data, axis, group_size),
|
||
|
"could not create a descriptor for a shuffle backward "
|
||
|
"propagation primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a shuffle backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a shuffle backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a shuffle backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a shuffle
|
||
|
/// forward propagation primitive. It is used as a hint for
|
||
|
/// deciding which memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const shuffle_forward::primitive_desc &hint_fwd_pd,
|
||
|
const primitive_attr &attr = primitive_attr(),
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a shuffle backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a shuffle backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
shuffle_backward() = default;
|
||
|
|
||
|
/// Constructs a shuffle backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a shuffle backward propagation
|
||
|
/// primitive.
|
||
|
shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_shuffle
|
||
|
|
||
|
/// @addtogroup dnnl_api_binary Binary
|
||
|
///
|
||
|
/// A primitive to perform tensor operations over two tensors.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_binary in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Elementwise binary operator primitive.
|
||
|
struct binary : public primitive {
|
||
|
/// Descriptor for an elementwise binary operator primitive.
|
||
|
struct desc {
|
||
|
/// Underlying C operation descriptor.
|
||
|
dnnl_binary_desc_t data;
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
desc() = default;
|
||
|
|
||
|
/// Constructs a descriptor for an elementwise binary operator
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src0` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `src1` (#dnnl::primitive_desc_base::src_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm Elementwise algorithm.
|
||
|
/// @param src0 Memory descriptor for source tensor #0.
|
||
|
/// @param src1 Memory descriptor for source tensor #1.
|
||
|
/// @param dst Memory descriptor for destination tensor.
|
||
|
desc(algorithm algorithm, const memory::desc &src0,
|
||
|
const memory::desc &src1, const memory::desc &dst) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_binary_desc_init(&data, dnnl::convert_to_c(algorithm),
|
||
|
&src0.data, &src1.data, &dst.data),
|
||
|
"could not create a descriptor for a binary operation "
|
||
|
"primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for an elementwise binary operator primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise binary operator
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise binary operator primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for an elementwise binary operator
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for an elementwise binary operator primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a binary primitive from a C
|
||
|
/// API primitive descriptor that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a binary primitve.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
|
||
|
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
|
||
|
|
||
|
/// Returns the memory descriptor for source #0.
|
||
|
memory::desc src0_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// Returns the memory descriptor for source #1.
|
||
|
memory::desc src1_desc() const { return base::src_desc(1); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
binary() = default;
|
||
|
|
||
|
/// Constructs an elementwise binary operation primitive.
|
||
|
/// @param pd Primitive descriptor for an elementwise binary operation
|
||
|
/// primitive.
|
||
|
binary(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_binary
|
||
|
|
||
|
/// @addtogroup dnnl_api_matmul Matrix Multiplication
|
||
|
///
|
||
|
/// A primitive to perform matrix-matrix multiplication. The batched mode
|
||
|
/// is supported with 3D tensors.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_matmul in developer guide
|
||
|
///
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Matrix multiplication (matmul) primitive.
|
||
|
struct matmul : public primitive {
|
||
|
/// Descriptor for a matmul primitive.
|
||
|
struct desc {
|
||
|
dnnl_matmul_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a matmul primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param src_desc Memory descriptor for source (matrix A).
|
||
|
/// @param weights_desc Memory descriptor for weights (matrix B).
|
||
|
/// @param dst_desc Memory descriptor for destination (matrix C).
|
||
|
desc(const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &dst_desc) {
|
||
|
error::wrap_c_api(
|
||
|
dnnl_matmul_desc_init(&data, &src_desc.data,
|
||
|
&weights_desc.data, nullptr, &dst_desc.data),
|
||
|
"could not create a descriptor for a matmul primitive");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a matmul primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
|
||
|
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @param src_desc Memory descriptor for source (matrix A).
|
||
|
/// @param weights_desc Memory descriptor for weights (matrix B).
|
||
|
/// @param dst_desc Memory descriptor for destination (matrix C).
|
||
|
/// @param bias_desc Memory descriptor for bias.
|
||
|
desc(const memory::desc &src_desc, const memory::desc &weights_desc,
|
||
|
const memory::desc &bias_desc, const memory::desc &dst_desc) {
|
||
|
error::wrap_c_api(dnnl_matmul_desc_init(&data, &src_desc.data,
|
||
|
&weights_desc.data, &bias_desc.data,
|
||
|
&dst_desc.data),
|
||
|
"could not create a descriptor for a matmul primitive");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a matmul primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a matmul primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a matmul primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a matmul primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a matmul primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a matmul primitive from a C
|
||
|
/// API primitive descriptor that must have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a matmul primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return query_md(query::src_md, 0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
||
|
memory::desc weights_desc() const {
|
||
|
return query_md(query::weights_md, 0);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
|
||
|
memory::desc bias_desc() const {
|
||
|
return query_md(query::weights_md, 1);
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
matmul() = default;
|
||
|
|
||
|
/// Constructs a matmul primitive.
|
||
|
/// @param pd Primitive descriptor for a matmul primitive.
|
||
|
matmul(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_matmul
|
||
|
|
||
|
/// @addtogroup dnnl_api_resampling Resampling
|
||
|
///
|
||
|
/// A primitive to compute resampling operations on 1D, 2D, or 3D data
/// tensors using nearest-neighbor or linear (bilinear, trilinear)
/// interpolation.
|
||
|
///
|
||
|
/// @sa @ref dev_guide_resampling in developer guide
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// Resampling forward propagation.
|
||
|
struct resampling_forward : public primitive {
|
||
|
/// Descriptor for resampling forward propagation.
|
||
|
struct desc {
|
||
|
dnnl_resampling_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a resampling forward propagation
|
||
|
/// primitive using source and destination memory descriptors.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// Destination memory descriptor may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
//
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm resampling algorithm kind: either
|
||
|
/// #dnnl::algorithm::resampling_nearest, or
|
||
|
/// #dnnl::algorithm::resampling_linear
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const memory::desc &src_desc, const memory::desc &dst_desc) {
|
||
|
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), nullptr,
|
||
|
&src_desc.data, &dst_desc.data),
|
||
|
"could not create a resampling forward descriptor");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a resampling forward propagation
|
||
|
/// primitive using source memory descriptor and factors.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm resampling algorithm kind: either
|
||
|
/// #dnnl::algorithm::resampling_nearest, or
|
||
|
/// #dnnl::algorithm::resampling_linear
|
||
|
/// @param factors Vector of scaling factors for spatial dimension.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const std::vector<float> &factors,
|
||
|
const memory::desc &src_desc) {
|
||
|
memory::validate_dims(factors, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), &factors[0],
|
||
|
&src_desc.data, nullptr),
|
||
|
"could not create a resampling forward descriptor");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for a resampling forward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
|
||
|
///
|
||
|
/// @note
|
||
|
/// Destination memory descriptor may be initialized with
|
||
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
||
|
//
|
||
|
/// @param prop_kind Propagation kind. Possible values are
|
||
|
/// #dnnl::prop_kind::forward_training, and
|
||
|
/// #dnnl::prop_kind::forward_inference.
|
||
|
/// @param algorithm resampling algorithm kind: either
|
||
|
/// #dnnl::algorithm::resampling_nearest, or
|
||
|
/// #dnnl::algorithm::resampling_linear
|
||
|
/// @param factors Vector of scaling factors for spatial dimension.
|
||
|
/// @param src_desc Source memory descriptor.
|
||
|
/// @param dst_desc Destination memory descriptor.
|
||
|
desc(prop_kind prop_kind, algorithm algorithm,
|
||
|
const std::vector<float> &factors, const memory::desc &src_desc,
|
||
|
const memory::desc &dst_desc) {
|
||
|
if (!factors.empty())
|
||
|
memory::validate_dims(factors, src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
|
||
|
dnnl::convert_to_c(prop_kind),
|
||
|
convert_to_c(algorithm), factors.data(),
|
||
|
&src_desc.data, &dst_desc.data),
|
||
|
"could not create a resampling forward descriptor");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for a resampling forward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a resampling forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, nullptr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling forward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a resampling forward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine, bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(
|
||
|
&desc.data, &attr, engine, nullptr, allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling forward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a resampling forward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
|
||
|
dnnl::prop_kind::forward_training,
|
||
|
dnnl::prop_kind::forward_inference) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
||
|
memory::desc src_desc() const { return base::src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
||
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
resampling_forward() = default;
|
||
|
|
||
|
/// Constructs a resampling forward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a resampling forward propagation
|
||
|
/// primitive.
|
||
|
resampling_forward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// Resampling backward propagation primitive.
|
||
|
struct resampling_backward : public primitive {
|
||
|
/// Descriptor for a resampling backward propagation primitive.
|
||
|
struct desc {
|
||
|
dnnl_resampling_desc_t data;
|
||
|
|
||
|
/// Constructs a descriptor for a resampling backward propagation
|
||
|
/// primitive using source and destination memory descriptors.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm resampling algorithm kind: either
|
||
|
/// #dnnl::algorithm::resampling_nearest, or
|
||
|
/// #dnnl::algorithm::resampling_linear
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
desc(algorithm algorithm, const memory::desc &diff_src_desc,
|
||
|
const memory::desc &diff_dst_desc) {
|
||
|
error::wrap_c_api(dnnl_resampling_backward_desc_init(&data,
|
||
|
convert_to_c(algorithm), nullptr,
|
||
|
&diff_src_desc.data, &diff_dst_desc.data),
|
||
|
"could not create a resampling backward data descriptor");
|
||
|
}
|
||
|
|
||
|
/// Constructs a descriptor for resampling backward propagation
|
||
|
/// primitive.
|
||
|
///
|
||
|
/// Inputs:
|
||
|
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
|
||
|
///
|
||
|
/// Outputs:
|
||
|
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
|
||
|
///
|
||
|
/// @param algorithm resampling algorithm kind: either
|
||
|
/// #dnnl::algorithm::resampling_nearest, or
|
||
|
/// #dnnl::algorithm::resampling_linear
|
||
|
/// @param factors Vector of scaling factors for spatial dimension.
|
||
|
/// @param diff_src_desc Diff source memory descriptor.
|
||
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
||
|
desc(algorithm algorithm, const std::vector<float> &factors,
|
||
|
const memory::desc &diff_src_desc,
|
||
|
const memory::desc &diff_dst_desc) {
|
||
|
if (!factors.empty())
|
||
|
memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
|
||
|
error::wrap_c_api(dnnl_resampling_backward_desc_init(&data,
|
||
|
convert_to_c(algorithm), factors.data(),
|
||
|
&diff_src_desc.data, &diff_dst_desc.data),
|
||
|
"could not create a resampling backward data descriptor");
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// Primitive descriptor for resampling backward propagation primitive.
|
||
|
struct primitive_desc : public dnnl::primitive_desc {
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
primitive_desc() = default;
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a resampling backward propagation
|
||
|
/// primitive.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a resampling forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const engine &engine,
|
||
|
const resampling_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, nullptr, engine,
|
||
|
hint_fwd_pd.get(), allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling backward
|
||
|
/// propagation primitive.
|
||
|
///
|
||
|
/// @param desc Descriptor for a resampling backward propagation
|
||
|
/// primitive.
|
||
|
/// @param attr Primitive attributes to use.
|
||
|
/// @param engine Engine to use.
|
||
|
/// @param hint_fwd_pd Primitive descriptor for a resampling forward
|
||
|
/// propagation primitive. It is used as a hint for deciding which
|
||
|
/// memory format to use.
|
||
|
/// @param allow_empty A flag signifying whether construction is
|
||
|
/// allowed to fail without throwing an exception. In this case an
|
||
|
/// empty object will be produced. This flag is optional and
|
||
|
/// defaults to false.
|
||
|
primitive_desc(const desc &desc, const primitive_attr &attr,
|
||
|
const engine &engine,
|
||
|
const resampling_forward::primitive_desc &hint_fwd_pd,
|
||
|
bool allow_empty = false)
|
||
|
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
|
||
|
allow_empty) {}
|
||
|
|
||
|
/// Constructs a primitive descriptor for a resampling backward
|
||
|
/// propagation primitive from a C API primitive descriptor that must
|
||
|
/// have a matching kind.
|
||
|
///
|
||
|
/// @param pd C API primitive descriptor for a resampling backward
|
||
|
/// propagation primitive.
|
||
|
primitive_desc(dnnl_primitive_desc_t pd)
|
||
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
|
||
|
dnnl::prop_kind::backward_data) {}
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
||
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
||
|
|
||
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
||
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
||
|
};
|
||
|
|
||
|
/// Default constructor. Produces an empty object.
|
||
|
resampling_backward() = default;
|
||
|
|
||
|
/// Constructs a resampling backward propagation primitive.
|
||
|
/// @param pd Primitive descriptor for a resampling backward propagation
|
||
|
/// primitive.
|
||
|
resampling_backward(const primitive_desc &pd) : primitive(pd) {}
|
||
|
};
|
||
|
|
||
|
/// @} dnnl_api_resampling
|
||
|
|
||
|
/// @} dnnl_api_primitives
|
||
|
|
||
|
/// @addtogroup dnnl_api_service Service
|
||
|
///
|
||
|
/// A set of functions that aid in oneDNN debugging and profiling.
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @copydoc dnnl_version_t
|
||
|
using version_t = dnnl_version_t;
|
||
|
|
||
|
/// Status values returned by the library functions.
|
||
|
enum class status {
|
||
|
/// @copydoc dnnl_success
|
||
|
success = dnnl_success,
|
||
|
/// @copydoc dnnl_out_of_memory
|
||
|
out_of_memory = dnnl_out_of_memory,
|
||
|
/// @copydoc dnnl_invalid_arguments
|
||
|
invalid_arguments = dnnl_invalid_arguments,
|
||
|
/// @copydoc dnnl_unimplemented
|
||
|
unimplemented = dnnl_unimplemented,
|
||
|
/// @copydoc dnnl_iterator_ends
|
||
|
iterator_ends = dnnl_iterator_ends,
|
||
|
/// @copydoc dnnl_runtime_error
|
||
|
runtime_error = dnnl_runtime_error,
|
||
|
/// @copydoc dnnl_not_required
|
||
|
not_required = dnnl_not_required,
|
||
|
};
|
||
|
|
||
|
/// @copydoc dnnl_set_verbose()
|
||
|
inline status set_verbose(int level) {
|
||
|
return static_cast<status>(dnnl_set_verbose(level));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_version()
|
||
|
inline const version_t *version() {
|
||
|
return dnnl_version();
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_set_jit_dump()
|
||
|
inline status set_jit_dump(int enable) {
|
||
|
return static_cast<status>(dnnl_set_jit_dump(enable));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_set_jit_profiling_flags()
|
||
|
inline status set_jit_profiling_flags(unsigned flags) {
|
||
|
return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_set_jit_profiling_jitdumpdir()
|
||
|
inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
|
||
|
return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_cpu_isa_t
|
||
|
enum class cpu_isa {
|
||
|
/// @copydoc dnnl_cpu_isa_all
|
||
|
all = dnnl_cpu_isa_all,
|
||
|
/// @copydoc dnnl_cpu_isa_sse41
|
||
|
sse41 = dnnl_cpu_isa_sse41,
|
||
|
/// @copydoc dnnl_cpu_isa_avx
|
||
|
avx = dnnl_cpu_isa_avx,
|
||
|
/// @copydoc dnnl_cpu_isa_avx2
|
||
|
avx2 = dnnl_cpu_isa_avx2,
|
||
|
/// @copydoc dnnl_cpu_isa_avx512_mic
|
||
|
avx512_mic = dnnl_cpu_isa_avx512_mic,
|
||
|
/// @copydoc dnnl_cpu_isa_avx512_mic_4ops
|
||
|
avx512_mic_4ops = dnnl_cpu_isa_avx512_mic_4ops,
|
||
|
/// @copydoc dnnl_cpu_isa_avx512_core
|
||
|
avx512_core = dnnl_cpu_isa_avx512_core,
|
||
|
/// @copydoc dnnl_cpu_isa_avx512_core_vnni
|
||
|
avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
|
||
|
/// @copydoc dnnl_cpu_isa_avx512_core_bf16
|
||
|
avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
|
||
|
};
|
||
|
|
||
|
/// @copydoc dnnl_set_max_cpu_isa()
|
||
|
inline status set_max_cpu_isa(cpu_isa isa) {
|
||
|
return static_cast<status>(
|
||
|
dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
|
||
|
}
|
||
|
|
||
|
/// @} dnnl_api_service
|
||
|
|
||
|
/// @addtogroup dnnl_api_blas BLAS functions
|
||
|
///
|
||
|
/// A subset of Basic Linear Algebra Subprograms (BLAS) functions that
/// perform matrix-matrix multiplication.
|
||
|
///
|
||
|
/// @{
|
||
|
|
||
|
/// @copydoc dnnl_sgemm()
|
||
|
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
|
||
|
dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
|
||
|
const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
|
||
|
return static_cast<status>(dnnl_sgemm(
|
||
|
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_gemm_u8s8s32()
|
||
|
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
||
|
dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
|
||
|
dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
||
|
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
|
||
|
return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
|
||
|
K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_gemm_s8s8s32()
|
||
|
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
||
|
dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
|
||
|
dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
||
|
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
|
||
|
return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
|
||
|
K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
|
||
|
}
|
||
|
|
||
|
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
|
||
|
/// @copydoc dnnl_sgemm_tp()
|
||
|
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
|
||
|
dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
|
||
|
const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc,
|
||
|
dnnl::threadpool_iface *tp) {
|
||
|
return static_cast<status>(dnnl_sgemm_tp(
|
||
|
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp));
|
||
|
}
|
||
|
/// @copydoc dnnl_gemm_u8s8s32_tp()
|
||
|
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
||
|
dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
|
||
|
dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
||
|
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
|
||
|
dnnl::threadpool_iface *tp) {
|
||
|
return static_cast<status>(dnnl_gemm_u8s8s32_tp(transa, transb, offsetc, M,
|
||
|
N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
|
||
|
}
|
||
|
|
||
|
/// @copydoc dnnl_gemm_s8s8s32_tp()
|
||
|
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
||
|
dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
|
||
|
dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
||
|
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
|
||
|
dnnl::threadpool_iface *tp) {
|
||
|
return static_cast<status>(dnnl_gemm_s8s8s32_tp(transa, transb, offsetc, M,
|
||
|
N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
|
||
|
}
|
||
|
#endif
|
||
|
|
||
|
/// @} dnnl_api_blas
|
||
|
|
||
|
// implementation section
|
||
|
|
||
|
/// @cond DO_NOT_DOCUMENT_THIS
|
||
|
// Creates the underlying C primitive from a C primitive descriptor and
// takes ownership of it via reset(). A failing creation status is reported
// through error::wrap_c_api before reset() is reached.
inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
    // Initialize the out-parameter so it holds a defined value regardless of
    // what the C API writes on failure.
    dnnl_primitive_t result = nullptr;
    error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
            "could not create a primitive");
    reset(result);
}

// Delegates to the C-descriptor constructor using the wrapped C handle.
inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
|
||
|
|
||
|
inline void primitive::execute(const stream &stream,
|
||
|
const std::unordered_map<int, memory> &args) const {
|
||
|
std::vector<dnnl_exec_arg_t> c_args;
|
||
|
c_args.reserve(args.size());
|
||
|
for (const auto &a : args)
|
||
|
c_args.push_back({a.first, a.second.get(true)});
|
||
|
|
||
|
error::wrap_c_api(dnnl_primitive_execute(get(), stream.get(),
|
||
|
(int)c_args.size(), c_args.data()),
|
||
|
"could not execute a primitive");
|
||
|
}
|
||
|
/// @endcond
|
||
|
|
||
|
#undef DNNL_DEFINE_BITMASK_OPS
|
||
|
|
||
|
} // namespace dnnl
|
||
|
|
||
|
/// @} dnnl_api
|
||
|
|
||
|
#endif
|