/*******************************************************************************
* Copyright 2016-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/// @file
/// C++ API
#ifndef DNNL_HPP
#define DNNL_HPP
#include "dnnl_config.h"
/// @cond DO_NOT_DOCUMENT_THIS
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "dnnl.h"
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "dnnl_threadpool_iface.hpp"
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#include <CL/cl.h>
#endif
/// @endcond
// __cpp_exceptions is referred from
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// Microsoft C++ Compiler does not provide an option to disable exceptions
#ifndef DNNL_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS \
|| (defined(_MSC_VER) && !defined(__clang__))
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
#error "unknown compiler"
#endif
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#else
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
do { \
fputs(msg, stderr); \
DNNL_TRAP(); \
} while (0)
#endif
/// @addtogroup dnnl_api oneDNN API
/// @{
/// oneDNN namespace
namespace dnnl {
/// @addtogroup dnnl_api_utils Utilities
/// Utility types and definitions.
/// @{
/// oneDNN exception class.
///
/// This class captures the status returned by a failed C API function and
/// the error message from the call site.
struct error : public std::exception {
dnnl_status_t status;
const char *message;
/// Constructs an instance of an exception class.
///
/// @param status The error status returned by a C API function.
/// @param message The error message.
error(dnnl_status_t status, const char *message)
: status(status), message(message) {}
/// Returns the explanatory string.
const char *what() const noexcept override { return message; }
/// A convenience function for wrapping calls to C API functions. Checks
/// the return status and throws a dnnl::error in case of failure.
///
/// @param status The error status returned by a C API function.
/// @param message The error message.
static void wrap_c_api(dnnl_status_t status, const char *message) {
if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
}
};
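// Example (illustrative sketch): wrapping a raw C API call so that any
// non-success status is converted into a dnnl::error exception. The same
// pattern is used throughout this header.
//
//     dnnl_engine_t c_engine;
//     error::wrap_c_api(dnnl_engine_create(&c_engine, dnnl_cpu, /*index=*/0),
//             "could not create an engine");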
/// @cond DO_NOT_DOCUMENT_THIS
template <typename T>
void validate_container_size(const T &v, const char *error_message,
int min_size = 1, int max_size = -1) {
const int size = (int)v.size();
if (size < min_size || (max_size >= 0 && size > max_size))
DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
}
/// @endcond
/// A class that provides the destructor for a oneDNN C API handle.
template <typename T>
struct handle_traits {};
/// oneDNN C API handle wrapper class.
///
/// This class is used as the base class for primitive (dnnl::primitive),
/// engine (dnnl::engine), and stream (dnnl::stream) classes, as well as
/// others. An object of the dnnl::handle class can be passed by value.
///
/// A handle can be weak, in which case it follows std::weak_ptr semantics.
/// Otherwise, it follows `std::shared_ptr` semantics.
///
/// @note
/// The implementation stores oneDNN C API handles in a `std::shared_ptr`
/// with deleter set to a dummy function in the weak mode.
///
template <typename T, typename traits = handle_traits<T>>
struct handle {
private:
static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
protected:
bool operator==(const T other) const { return other == data_.get(); }
bool operator!=(const T other) const { return !(*this == other); }
public:
/// Constructs an empty handle object.
///
/// @warning
/// An uninitialized object cannot be used in most library calls and is
/// equivalent to a null pointer. Any attempt to use its methods or to
/// pass it to other library functions will cause an exception to be
/// thrown.
handle() = default;
/// Copy constructor.
handle(const handle<T, traits> &) = default;
/// Assignment operator.
handle<T, traits> &operator=(const handle<T, traits> &) = default;
/// Move constructor.
handle(handle<T, traits> &&) = default;
/// Move assignment operator.
handle<T, traits> &operator=(handle<T, traits> &&) = default;
/// Constructs a handle wrapper object from a C API handle.
///
/// @param t The C API handle to wrap.
/// @param weak A flag specifying whether to construct a weak wrapper;
/// defaults to @c false.
explicit handle(T t, bool weak = false) { reset(t, weak); }
/// Resets the handle wrapper object to wrap a new C API handle.
///
/// @param t The new value of the C API handle.
/// @param weak A flag specifying whether the wrapper should be weak;
/// defaults to @c false.
void reset(T t, bool weak = false) {
data_.reset(t, weak ? &dummy_destructor : traits::destructor);
}
/// Returns the underlying C API handle.
///
/// @param allow_empty A flag signifying whether the method is allowed to
/// return an empty (null) object without throwing an exception.
/// @returns The underlying C API handle.
T get(bool allow_empty = false) const {
T result = data_.get();
if (allow_empty == false && result == nullptr)
DNNL_THROW_ERROR(
dnnl_invalid_arguments, "object is not initialized");
return result;
}
/// Converts a handle to the underlying C API handle type. Does not throw
/// and returns `nullptr` if the object is empty.
///
/// @returns The underlying C API handle.
explicit operator T() const { return get(true); }
/// Checks whether the object is empty.
///
/// @returns Whether the object is empty.
explicit operator bool() const { return get(true) != nullptr; }
/// Equality operator.
///
/// @param other Another handle wrapper.
/// @returns @c true if this and the other handle wrapper manage the same
/// underlying C API handle, and @c false otherwise. Empty handle
/// objects are considered to be equal.
bool operator==(const handle<T, traits> &other) const {
return other.data_.get() == data_.get();
}
/// Inequality operator.
///
/// @param other Another handle wrapper.
/// @returns @c true if this and the other handle wrapper manage different
/// underlying C API handles, and @c false otherwise. Empty handle
/// objects are considered to be equal.
bool operator!=(const handle &other) const { return !(*this == other); }
};
/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_memory_t> {
static dnnl_status_t destructor(dnnl_memory_t p) {
return dnnl_memory_destroy(p);
}
};
template <>
struct handle_traits<dnnl_primitive_desc_t> {
static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
return dnnl_primitive_desc_destroy(p);
}
};
template <>
struct handle_traits<dnnl_primitive_t> {
static dnnl_status_t destructor(dnnl_primitive_t p) {
return dnnl_primitive_destroy(p);
}
};
template <>
struct handle_traits<dnnl_primitive_desc_iterator_t> {
static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
return dnnl_primitive_desc_iterator_destroy(p);
}
};
/// @endcond
/// @} dnnl_api_utils
struct stream;
struct error;
struct memory;
struct primitive_desc;
/// @addtogroup dnnl_api_primitives Primitives
/// Compute primitives
/// @sa @ref dev_guide_basic_concepts
/// @{
/// @addtogroup dnnl_api_primitives_common Common
/// Common operations to create, destroy and inspect primitives
/// @{
/// Base class for all computational primitives.
struct primitive : public handle<dnnl_primitive_t> {
friend struct error;
friend struct stream;
/// Kinds of primitives supported by the library.
enum class kind {
/// Undefined primitive
undef = dnnl_undefined_primitive,
/// A reorder primitive.
reorder = dnnl_reorder,
/// A shuffle primitive.
shuffle = dnnl_shuffle,
/// An out-of-place tensor concatenation primitive.
concat = dnnl_concat,
/// A summation primitive.
sum = dnnl_sum,
/// A convolution primitive.
convolution = dnnl_convolution,
/// A deconvolution primitive.
deconvolution = dnnl_deconvolution,
/// An element-wise primitive.
eltwise = dnnl_eltwise,
/// A softmax primitive.
softmax = dnnl_softmax,
/// A pooling primitive.
pooling = dnnl_pooling,
/// An LRN primitive.
lrn = dnnl_lrn,
/// A batch normalization primitive.
batch_normalization = dnnl_batch_normalization,
/// A layer normalization primitive.
layer_normalization = dnnl_layer_normalization,
/// An inner product primitive.
inner_product = dnnl_inner_product,
/// An RNN primitive.
rnn = dnnl_rnn,
/// A binary primitive.
binary = dnnl_binary,
/// A logsoftmax primitive.
logsoftmax = dnnl_logsoftmax,
/// A matmul (matrix multiplication) primitive.
matmul = dnnl_matmul,
/// A resampling primitive.
resampling = dnnl_resampling,
};
using handle::handle;
/// Default constructor. Constructs an empty object.
primitive() = default;
/// Constructs a primitive from a C API primitive descriptor.
///
/// @param c_pd C API primitive descriptor.
primitive(const_dnnl_primitive_desc_t c_pd);
/// Constructs a primitive from a primitive descriptor.
///
/// @param pd Primitive descriptor.
primitive(const primitive_desc &pd);
/// Returns the C API primitive descriptor of the underlying C API
/// primitive.
///
/// @returns The underlying C API primitive descriptor.
inline const_dnnl_primitive_desc_t get_primitive_desc() const;
/// Returns the kind of the primitive.
///
/// @returns The primitive kind.
inline kind get_kind() const;
/// Executes computations specified by the primitive in a specified stream.
///
/// Arguments are passed via an arguments map containing <index,
/// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
/// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
/// matching the one returned by
/// primitive_desc::query_md(#query::exec_arg_md, index) unless using
/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
///
/// @param stream Stream object. The stream must belong to the same engine
/// as the primitive.
/// @param args Arguments map.
void execute(const stream &stream,
const std::unordered_map<int, memory> &args) const;
};
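// Example (illustrative sketch): executing a primitive with an argument map.
// Here `prim` is assumed to be an already constructed primitive, `s` a
// stream created on the same engine, and `src_mem` / `dst_mem` memory
// objects whose descriptors match those the primitive expects.
//
//     prim.execute(s, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
//     s.wait(); // block until the computation submitted to `s` completes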
/// Converts primitive kind enum value from C++ API to C API type.
///
/// @param kind C++ API primitive kind enum value.
/// @returns Corresponding C API primitive kind enum value.
inline dnnl_primitive_kind_t convert_to_c(primitive::kind kind) {
return static_cast<dnnl_primitive_kind_t>(kind);
}
const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
const_dnnl_primitive_desc_t pd;
error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
"could not get a primitive descriptor from a primitive");
return pd;
}
dnnl::primitive::kind primitive::get_kind() const {
const_dnnl_primitive_desc_t pd = get_primitive_desc();
// TODO (Roma): the code below is only needed because get_primitive_desc
// returns a C type.
dnnl_primitive_kind_t kind;
error::wrap_c_api(dnnl_primitive_desc_query(
pd, dnnl_query_primitive_kind, 0, (void *)&kind),
"could not get a primitive kind from a primitive descriptor");
return static_cast<dnnl::primitive::kind>(kind);
}
/// @} dnnl_api_primitives_common
/// @addtogroup dnnl_api_attributes
///
/// A container for parameters that extend primitives behavior.
///
/// Attributes can also contain Post-ops, which are computations executed
/// after the primitive.
///
/// @sa @ref dev_guide_attributes
/// @sa @ref dev_guide_attributes_post_ops
///
/// @{
/// Scratchpad mode
enum class scratchpad_mode {
/// The library manages the scratchpad allocation according to the policy
/// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
/// [build option](@ref dev_guide_build_options) (default).
///
/// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
/// scratchpad is common to all primitives to reduce the memory footprint.
/// This configuration comes with limited thread-safety properties, namely
/// primitives can be created and executed in parallel but cannot migrate
/// between threads (in other words, each primitive should be executed in
/// the same thread it was created in).
///
/// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
/// private to each primitive. The memory footprint is larger than when
/// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
/// created and run concurrently (the same primitive cannot be run
/// concurrently from two different threads though).
library = dnnl_scratchpad_mode_library,
/// The user manages the scratchpad allocation by querying and providing
/// the scratchpad memory to primitives. This mode is thread-safe as long
/// as the scratchpad buffers are not used concurrently by two primitive
/// executions.
user = dnnl_scratchpad_mode_user,
};
/// Converts a scratchpad mode enum value from C++ API to C API type.
///
/// @param mode C++ API scratchpad mode enum value.
/// @returns Corresponding C API scratchpad mode enum value.
inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
return static_cast<dnnl_scratchpad_mode_t>(mode);
}
/// Propagation kind.
enum class prop_kind {
/// Undefined propagation kind.
undef = dnnl_prop_kind_undef,
/// Forward data propagation (training mode). In this mode, primitives
/// perform computations necessary for subsequent backward propagation.
forward_training = dnnl_forward_training,
/// Forward data propagation (inference mode). In this mode, primitives
/// perform only computations that are necessary for inference and omit
/// computations that are necessary only for backward propagation.
forward_inference = dnnl_forward_inference,
/// Forward data propagation,
/// alias for #dnnl::prop_kind::forward_inference.
forward_scoring = dnnl_forward_scoring,
/// Forward data propagation,
/// alias for #dnnl::prop_kind::forward_training.
forward = dnnl_forward,
/// Backward propagation (with respect to all parameters).
backward = dnnl_backward,
/// Backward data propagation.
backward_data = dnnl_backward_data,
/// Backward weights propagation.
backward_weights = dnnl_backward_weights,
/// Backward bias propagation.
backward_bias = dnnl_backward_bias
};
/// Converts propagation kind enum value from C++ API to C API type.
///
/// @param kind C++ API propagation kind enum value.
/// @returns Corresponding C API propagation kind enum value.
inline dnnl_prop_kind_t convert_to_c(prop_kind kind) {
return static_cast<dnnl_prop_kind_t>(kind);
}
/// Kinds of algorithms.
enum class algorithm {
/// Undefined algorithm
undef = dnnl_alg_kind_undef,
/// Convolution algorithm that is chosen to be either direct or Winograd
/// automatically
convolution_auto = dnnl_convolution_auto,
/// Direct convolution
convolution_direct = dnnl_convolution_direct,
/// Winograd convolution
convolution_winograd = dnnl_convolution_winograd,
/// Direct deconvolution
deconvolution_direct = dnnl_deconvolution_direct,
/// Winograd deconvolution
deconvolution_winograd = dnnl_deconvolution_winograd,
/// Elementwise: rectified linear unit (ReLU)
eltwise_relu = dnnl_eltwise_relu,
/// Elementwise: hyperbolic tangent non-linearity (tanh)
eltwise_tanh = dnnl_eltwise_tanh,
/// Elementwise: exponential linear unit (ELU)
eltwise_elu = dnnl_eltwise_elu,
/// Elementwise: square
eltwise_square = dnnl_eltwise_square,
/// Elementwise: abs
eltwise_abs = dnnl_eltwise_abs,
/// Elementwise: square root
eltwise_sqrt = dnnl_eltwise_sqrt,
/// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
eltwise_swish = dnnl_eltwise_swish,
/// Elementwise: linear
eltwise_linear = dnnl_eltwise_linear,
/// Elementwise: bounded_relu
eltwise_bounded_relu = dnnl_eltwise_bounded_relu,
/// Elementwise: soft_relu
eltwise_soft_relu = dnnl_eltwise_soft_relu,
/// Elementwise: logistic
eltwise_logistic = dnnl_eltwise_logistic,
/// Elementwise: exponent
eltwise_exp = dnnl_eltwise_exp,
/// Elementwise: gelu
/// alias for #dnnl::algorithm::eltwise_gelu_tanh
eltwise_gelu = dnnl_eltwise_gelu,
/// Elementwise: tanh-based gelu
eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
/// Elementwise: erf-based gelu
eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
/// Elementwise: natural logarithm
eltwise_log = dnnl_eltwise_log,
/// Elementwise: clip
eltwise_clip = dnnl_eltwise_clip,
/// Elementwise: pow
eltwise_pow = dnnl_eltwise_pow,
/// Elementwise: rectified linear unit (ReLU) (dst for backward)
eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
/// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
/// Elementwise: exponential linear unit (ELU) (dst for backward)
eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
/// Elementwise: square root (dst for backward)
eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
/// Elementwise: logistic (dst for backward)
eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
/// Elementwise: exponent (dst for backward)
eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
/// Local response normalization (LRN) across multiple channels
lrn_across_channels = dnnl_lrn_across_channels,
/// LRN within a single channel
lrn_within_channel = dnnl_lrn_within_channel,
/// Max pooling
pooling_max = dnnl_pooling_max,
/// Average pooling exclude padding,
/// alias for #dnnl::algorithm::pooling_avg_exclude_padding
pooling_avg = dnnl_pooling_avg,
/// Average pooling include padding
pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
/// Average pooling exclude padding
pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
/// RNN cell
vanilla_rnn = dnnl_vanilla_rnn,
/// LSTM cell
vanilla_lstm = dnnl_vanilla_lstm,
/// GRU cell
vanilla_gru = dnnl_vanilla_gru,
/// GRU cell with linear before reset. Differs from the vanilla GRU
/// in how the new memory gate is calculated:
/// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
/// LBR GRU expects 4 bias tensors on input:
/// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
lbr_gru = dnnl_lbr_gru,
/// Binary add
binary_add = dnnl_binary_add,
/// Binary mul
binary_mul = dnnl_binary_mul,
/// Binary max
binary_max = dnnl_binary_max,
/// Binary min
binary_min = dnnl_binary_min,
/// Nearest Neighbor resampling method
resampling_nearest = dnnl_resampling_nearest,
/// Linear (Bilinear, Trilinear) resampling method
resampling_linear = dnnl_resampling_linear,
};
/// Converts algorithm kind enum value from C++ API to C API type.
/// @param algorithm C++ API algorithm kind enum value.
/// @returns Corresponding C API algorithm kind enum value.
inline dnnl_alg_kind_t convert_to_c(algorithm algorithm) {
return static_cast<dnnl_alg_kind_t>(algorithm);
}
/// @} dnnl_api_attributes
/// @addtogroup dnnl_api_primitives_common
/// @{
/// Flags for normalization primitives.
enum class normalization_flags : unsigned {
/// Use no normalization flags. If specified, the library computes mean and
/// variance on forward propagation for training and inference, outputs them
/// on forward propagation for training, and computes the respective
/// derivatives on backward propagation.
none = dnnl_normalization_flags_none,
/// Use global statistics. If specified, the library uses mean and
/// variance provided by the user as an input on forward propagation and
/// does not compute their derivatives on backward propagation. Otherwise,
/// the library computes mean and variance on forward propagation for
/// training and inference, outputs them on forward propagation for
/// training, and computes the respective derivatives on backward
/// propagation.
use_global_stats = dnnl_use_global_stats,
/// Use scale and shift parameters. If specified, the user is expected to
/// pass scale and shift as inputs on forward propagation. On backward
/// propagation of type #dnnl::prop_kind::backward, the library computes
/// their derivatives. If not specified, the scale and shift parameters
/// are not used by the library in any way.
use_scale_shift = dnnl_use_scaleshift,
/// Fuse normalization with ReLU. On training, normalization will require
/// the workspace to implement backward propagation. On inference, the
/// workspace is not required and behavior is the same as when normalization
/// is fused with ReLU using the post-ops API.
fuse_norm_relu = dnnl_fuse_norm_relu
};
/// Converts normalization flags enum value from C++ API to C API type.
/// @param flags C++ API normalization flags enum value.
/// @returns Corresponding C API normalization flags enum value.
inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
return static_cast<dnnl_normalization_flags_t>(flags);
}
/// @} dnnl_api_primitives_common
/// @addtogroup dnnl_api_rnn
/// @{
/// RNN cell flags.
enum class rnn_flags : unsigned {
/// Undefined RNN flags
undef = dnnl_rnn_flags_undef
};
/// Converts RNN cell flags enum value from C++ API to C API type.
/// @param flags C++ API RNN cell flags enum value.
/// @returns Corresponding C API RNN cell flags enum value.
inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
return static_cast<dnnl_rnn_flags_t>(flags);
}
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
inline enum_name operator|(enum_name lhs, enum_name rhs) { \
return static_cast<enum_name>( \
static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
} \
\
inline enum_name operator&(enum_name lhs, enum_name rhs) { \
return static_cast<enum_name>( \
static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
} \
\
inline enum_name operator^(enum_name lhs, enum_name rhs) { \
return static_cast<enum_name>( \
static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
} \
\
inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
lhs = static_cast<enum_name>( \
static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
return lhs; \
} \
\
inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
lhs = static_cast<enum_name>( \
static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
return lhs; \
} \
\
inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
lhs = static_cast<enum_name>( \
static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
return lhs; \
} \
\
inline enum_name operator~(enum_name rhs) { \
return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
}
DNNL_DEFINE_BITMASK_OPS(normalization_flags)
DNNL_DEFINE_BITMASK_OPS(rnn_flags)
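// Example (illustrative sketch): the operators defined above make it
// possible to combine and test flag values directly.
//
//     normalization_flags flags = normalization_flags::use_scale_shift
//             | normalization_flags::use_global_stats;
//     bool has_stats = (flags & normalization_flags::use_global_stats)
//             != normalization_flags::none;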
/// A direction of RNN primitive execution
enum class rnn_direction {
/// Unidirectional execution of RNN primitive from left to right.
unidirectional_left2right = dnnl_unidirectional_left2right,
/// Unidirectional execution of RNN primitive from right to left.
unidirectional_right2left = dnnl_unidirectional_right2left,
/// Bidirectional execution of RNN primitive with concatenation of the
/// results.
bidirectional_concat = dnnl_bidirectional_concat,
/// Bidirectional execution of RNN primitive with summation of the
/// results.
bidirectional_sum = dnnl_bidirectional_sum,
/// Alias for #dnnl::rnn_direction::unidirectional_left2right
unidirectional = dnnl_unidirectional,
};
/// Converts RNN direction enum value from C++ API to C API type.
/// @param dir C++ API RNN direction enum value.
/// @returns Corresponding C API RNN direction enum value.
inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
return static_cast<dnnl_rnn_direction_t>(dir);
}
/// @} dnnl_api_rnn
/// @addtogroup dnnl_api_primitives_common
/// @{
/// Primitive descriptor query specification.
///
/// In general, queries are not used with the C++ API because most queries are
/// implemented as class members.
///
/// See @ref dnnl_query_t for more information.
enum class query {
/// no query
undef = dnnl_query_undef,
/// execution engine
engine = dnnl_query_engine,
/// primitive kind
primitive_kind = dnnl_query_primitive_kind,
/// number of inputs expected
num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
/// number of outputs expected
num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
/// runtime estimation (seconds), unimplemented
time_estimate_f64 = dnnl_query_time_estimate_f64,
/// memory consumption (bytes)
///
/// extra (scratch) memory, additional to all inputs and outputs memory
///
/// @sa @ref dev_guide_attributes_scratchpad
memory_consumption_s64 = dnnl_query_memory_consumption_s64,
/// scratchpad engine
///
/// engine to be used for creating scratchpad memory
scratchpad_engine = dnnl_query_scratchpad_engine,
/// reorder source engine
reorder_src_engine = dnnl_query_reorder_src_engine,
/// reorder destination engine
reorder_dst_engine = dnnl_query_reorder_dst_engine,
/// implementation name
impl_info_str = dnnl_query_impl_info_str,
/// propagation kind
prop_kind = dnnl_query_prop_kind,
/// operation descriptor
op_d = dnnl_query_op_d,
/// convolution descriptor
convolution_d = dnnl_query_convolution_d,
/// deconvolution descriptor
deconvolution_d = dnnl_query_deconvolution_d,
/// shuffle descriptor
shuffle_d = dnnl_query_shuffle_d,
/// eltwise descriptor
eltwise_d = dnnl_query_eltwise_d,
/// softmax descriptor
softmax_d = dnnl_query_softmax_d,
/// pooling descriptor
pooling_d = dnnl_query_pooling_d,
/// lrn descriptor
lrn_d = dnnl_query_lrn_d,
/// batch normalization descriptor
batch_normalization_d = dnnl_query_batch_normalization_d,
/// layer normalization descriptor
layer_normalization_d = dnnl_query_layer_normalization_d,
/// inner product descriptor
inner_product_d = dnnl_query_inner_product_d,
/// rnn descriptor
rnn_d = dnnl_query_rnn_d,
/// binary descriptor
binary_d = dnnl_query_binary_d,
/// logsoftmax descriptor
logsoftmax_d = dnnl_query_logsoftmax_d,
/// matmul descriptor
matmul_d = dnnl_query_matmul_d,
/// resampling descriptor
resampling_d = dnnl_query_resampling_d,
/// source memory desc
src_md = dnnl_query_src_md,
/// source gradient (diff) memory desc
diff_src_md = dnnl_query_diff_src_md,
/// weights memory desc
weights_md = dnnl_query_weights_md,
/// weights gradient (diff) memory desc
diff_weights_md = dnnl_query_diff_weights_md,
/// destination memory desc
dst_md = dnnl_query_dst_md,
/// destination gradient (diff) memory desc
diff_dst_md = dnnl_query_diff_dst_md,
/// workspace memory desc
workspace_md = dnnl_query_workspace_md,
/// scratchpad memory desc
scratchpad_md = dnnl_query_scratchpad_md,
/// memory desc of an execute argument
exec_arg_md = dnnl_query_exec_arg_md,
};
/// Converts query enum value from C++ API to C API type.
/// @param query C++ API query enum value.
/// @returns Corresponding C API query enum value.
inline dnnl_query_t convert_to_c(query query) {
return static_cast<dnnl_query_t>(query);
}
/// @} dnnl_api_primitives_common
/// @} dnnl_api_primitives
/// @addtogroup dnnl_api_engine Engine
///
/// An abstraction of a computational device: a CPU, a specific GPU
/// card in the system, etc. Most primitives are created to execute
/// computations on one specific engine. The only exceptions are reorder
/// primitives that transfer data between two different engines.
///
/// @sa @ref dev_guide_basic_concepts
///
/// @{
/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_engine_t> {
static dnnl_status_t destructor(dnnl_engine_t p) {
return dnnl_engine_destroy(p);
}
};
/// @endcond
/// An execution engine.
struct engine : public handle<dnnl_engine_t> {
friend struct primitive;
friend struct reorder;
/// Kinds of engines.
enum class kind {
/// An unspecified engine
any = dnnl_any_engine,
/// CPU engine
cpu = dnnl_cpu,
/// GPU engine
gpu = dnnl_gpu,
};
using handle::handle;
/// Constructs an empty engine. An empty engine cannot be used in any
/// operations.
engine() = default;
/// Returns the number of engines of a certain kind.
///
/// @param kind The kind of engines to count.
/// @returns The number of engines of the specified kind.
static size_t get_count(kind kind) {
return dnnl_engine_get_count(convert_to_c(kind));
}
/// Constructs an engine.
///
/// @param kind The kind of engine to construct.
/// @param index The index of the engine. Must be less than the value
/// returned by #get_count() for this particular kind of engine.
engine(kind kind, size_t index) {
dnnl_engine_t engine;
error::wrap_c_api(
dnnl_engine_create(&engine, convert_to_c(kind), index),
"could not create an engine");
reset(engine);
}
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
/// Constructs an engine from OpenCL device and context objects.
///
/// @param kind The kind of engine to construct.
/// @param device The OpenCL device that this engine will encapsulate.
/// @param context The OpenCL context (containing the device) that this
/// engine will use for all operations.
engine(kind kind, cl_device_id device, cl_context context) {
dnnl_engine_t engine;
error::wrap_c_api(dnnl_engine_create_ocl(
&engine, convert_to_c(kind), device, context),
"could not create an engine");
reset(engine);
}
#endif
/// Constructs an engine by querying the primitive descriptor @p pd for
/// the engine that it was created on.
///
/// @param pd The primitive descriptor to query.
engine(const handle<dnnl_primitive_desc_t> &pd) {
dnnl_engine_t c_engine;
error::wrap_c_api(
dnnl_primitive_desc_query(pd.get(),
dnnl::convert_to_c(dnnl::query::engine), 0, &c_engine),
"could not get an engine from a primitive_desc");
reset(c_engine, true);
}
/// Returns the kind of the engine.
/// @returns The kind of the engine.
kind get_kind() const {
dnnl_engine_kind_t kind;
error::wrap_c_api(dnnl_engine_get_kind(get(), &kind),
"could not get kind of an engine");
return static_cast<engine::kind>(kind);
}
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
/// Returns the OpenCL context associated with the engine.
/// @returns OpenCL context.
cl_context get_ocl_context() const {
cl_context context = nullptr;
error::wrap_c_api(dnnl_engine_get_ocl_context(get(), &context),
"could not get an OpenCL context fron an engine");
return context;
}
/// Returns the OpenCL device associated with the engine.
/// @returns OpenCL device.
cl_device_id get_ocl_device() const {
cl_device_id device = nullptr;
error::wrap_c_api(dnnl_engine_get_ocl_device(get(), &device),
"could not get an OpenCL device fron an engine");
return device;
}
#endif
/// Returns the engine of a primitive descriptor.
///
/// @param pd The primitive descriptor to query.
/// @returns A weak handle to the engine that the primitive descriptor was
/// created with.
template <typename primitive_desc>
static engine query(const primitive_desc &pd) {
return query(pd, dnnl::query::engine);
}
private:
static dnnl_engine_kind_t convert_to_c(kind kind) {
return static_cast<dnnl_engine_kind_t>(kind);
}
template <typename primitive_desc>
static engine query(const primitive_desc &pd, dnnl::query what) {
dnnl_engine_t c_engine;
error::wrap_c_api(dnnl_primitive_desc_query(pd.get(),
dnnl::convert_to_c(what), 0, &c_engine),
"could not get an engine from a primitive_desc");
return engine(c_engine, true);
}
};
/// Converts engine kind enum value from C++ API to C API type.
///
/// @param kind C++ API engine kind enum value.
/// @returns Corresponding C API engine kind enum value.
inline dnnl_engine_kind_t convert_to_c(engine::kind kind) {
return static_cast<dnnl_engine_kind_t>(kind);
}
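// Example (illustrative sketch): enumerating CPU engines and constructing
// one if any is available.
//
//     if (engine::get_count(engine::kind::cpu) > 0) {
//         engine eng(engine::kind::cpu, 0);
//         // eng.get_kind() now returns engine::kind::cpu
//     }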
/// @} dnnl_api_engine
/// @addtogroup dnnl_api_stream Stream
///
/// An encapsulation of execution context tied to a particular engine.
///
/// @sa @ref dev_guide_basic_concepts
///
/// @{
/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_stream_t> {
static dnnl_status_t destructor(dnnl_stream_t p) {
return dnnl_stream_destroy(p);
}
};
template <>
struct handle_traits<dnnl_stream_attr_t> {
static dnnl_status_t destructor(dnnl_stream_attr_t p) {
return dnnl_stream_attr_destroy(p);
}
};
/// @endcond
/// A container for stream attributes.
struct stream_attr : public handle<dnnl_stream_attr_t> {
using handle::handle;
/// Constructs default (empty) stream attributes.
stream_attr() = default;
/// Constructs stream attributes for a stream that runs on an engine of a
/// particular kind.
///
/// @param kind Target engine kind.
stream_attr(engine::kind kind) {
dnnl_stream_attr_t attr;
error::wrap_c_api(dnnl_stream_attr_create(&attr, convert_to_c(kind)),
"could not create stream attributes");
reset(attr);
}
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
/// Sets the threadpool attribute. Always throws unless oneDNN is built with
/// threadpool runtime.
///
/// @sa @ref dev_guide_threadpool
///
/// @param threadpool A pointer to an instance of a class that implements
/// the dnnl::threadpool_iface interface.
void set_threadpool(threadpool_iface *threadpool) {
error::wrap_c_api(dnnl_stream_attr_set_threadpool(get(), threadpool),
"could not set stream threadpool attribute");
}
/// Returns the threadpool attribute. Always throws unless oneDNN is built
/// with threadpool runtime.
///
/// @sa @ref dev_guide_threadpool
///
threadpool_iface *get_threadpool() {
threadpool_iface *tp;
error::wrap_c_api(dnnl_stream_attr_get_threadpool(get(), (void **)&tp),
"could not set stream threadpool attribute");
return tp;
}
#endif
};
/// An execution stream.
struct stream : public handle<dnnl_stream_t> {
using handle::handle;
/// Stream flags. Can be combined using the bitwise OR operator.
enum class flags : unsigned {
/// Default order execution. Either in-order or out-of-order depending
/// on the engine runtime.
default_order = dnnl_stream_default_order,
/// In-order execution.
in_order = dnnl_stream_default_order,
/// Out-of-order execution.
out_of_order = dnnl_stream_out_of_order,
/// Default stream configuration.
default_flags = dnnl_stream_default_flags,
};
/// Constructs an empty stream. An empty stream cannot be used in any
/// operations.
stream() = default;
/// Constructs a stream for the specified engine and with behavior
/// controlled by the specified flags.
///
/// @param engine Engine to create the stream on.
/// @param flags Flags controlling stream behavior.
/// @param attr Stream attributes.
stream(const engine &engine, flags flags = flags::default_flags,
const stream_attr &attr = stream_attr()) {
dnnl_stream_t stream;
error::wrap_c_api(dnnl_stream_create_v2(&stream, engine.get(),
static_cast<dnnl_stream_flags_t>(flags),
attr.get(true)),
"could not create a stream");
reset(stream);
}
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
/// Constructs a stream for the specified engine and the OpenCL queue.
///
/// @param engine Engine to create the stream on.
/// @param queue OpenCL queue to use for the stream.
stream(const engine &engine, cl_command_queue queue) {
dnnl_stream_t stream;
error::wrap_c_api(dnnl_stream_create_ocl(&stream, engine.get(), queue),
"could not create a stream");
reset(stream);
}
/// Returns the underlying OpenCL queue object.
/// @returns OpenCL queue.
cl_command_queue get_ocl_command_queue() const {
cl_command_queue queue = nullptr;
error::wrap_c_api(dnnl_stream_get_ocl_command_queue(get(), &queue),
"could not get an OpenCL command queue from a stream");
return queue;
}
#endif
/// Waits for all primitives executing in the stream to finish.
/// @returns The stream itself.
stream &wait() {
error::wrap_c_api(
dnnl_stream_wait(get()), "could not wait on a stream");
return *this;
}
};
DNNL_DEFINE_BITMASK_OPS(stream::flags)
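// Example (illustrative sketch): creating a stream with default flags on an
// existing engine `eng`, submitting work, and synchronizing.
//
//     stream s(eng, stream::flags::default_flags);
//     // ... primitive::execute(s, args) calls go here ...
//     s.wait(); // wait for all submitted primitives to finish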
/// @} dnnl_api_stream
/// @addtogroup dnnl_api_memory Memory
///
/// A container that describes and stores data. Memory objects can contain
/// data of various types and formats. There are two levels of abstraction:
///
/// 1. **Memory descriptor** -- engine-agnostic logical description of data
/// (number of dimensions, dimension sizes, and data type), and,
/// optionally, the information about the physical format of data in
/// memory. If this information is not known yet, a memory descriptor can
/// be created with #dnnl::memory::format_tag::any. This allows
/// compute-intensive primitives to choose the best format for
/// computation. The user is responsible for reordering the data into the
/// chosen format when formats do not match.
///
/// A memory descriptor can be initialized either by specifying dimensions
/// and a memory format tag or strides for each of them, or by
/// manipulating the dnnl_memory_desc_t structure directly.
///
/// @warning
/// The latter approach requires understanding how the physical data
/// representation is mapped to the structure and is discouraged. This
/// topic is discussed in @ref dev_guide_understanding_memory_formats.
///
/// The user can query the amount of memory required by a memory
/// descriptor using the #dnnl::memory::desc::get_size() function. The
/// size of data in general cannot be computed as the product of
/// dimensions multiplied by the size of the data type. So users are
/// required to use this function for better code portability.
///
/// Two memory descriptors can be compared using the equality and
/// inequality operators. The comparison is especially useful when
/// checking whether it is necessary to reorder data from the user's data
/// format to a primitive's format.
///
/// 2. **Memory object** -- an engine-specific object that handles the data
/// and its description (a memory descriptor). For the CPU engine, the
/// data handle is simply a pointer to @c void. The data handle can be
/// queried using #dnnl::memory::get_data_handle() and set using
/// #dnnl::memory::set_data_handle(). A memory object can also be
/// queried for the underlying memory descriptor and for its engine using
/// #dnnl::memory::get_desc() and dnnl::memory::get_engine().
///
/// Along with ordinary memory descriptors with all dimensions being positive,
/// the library supports *zero-volume* memory descriptors with one or more
/// dimensions set to zero. This is used to support the NumPy\* convention.
/// If a zero-volume memory is passed to a primitive, the primitive typically
/// does not perform any computations with this memory. For example:
///
/// - A concatenation primitive would ignore all memory objects with zeroes in
/// the concat dimension / axis.
///
/// - A forward convolution with a source memory object with zero in the
/// minibatch dimension would always produce a destination memory object
/// with a zero in the minibatch dimension and perform no computations.
///
/// - However, a forward convolution with a zero in one of the weights
/// dimensions is ill-defined and is considered to be an error by the
/// library because there is no clear definition of what the output values
/// should be.
///
/// Data handle of a zero-volume memory is never accessed.
///
/// @{
/// Memory object.
///
/// A memory object encapsulates a handle to a memory buffer allocated on a
/// specific engine, tensor dimensions, data type, and memory format, which is
/// the way tensor indices map to offsets in linear memory space. Memory
/// objects are passed to primitives during execution.
struct memory : public handle<dnnl_memory_t> {
/// Integer type for representing dimension sizes and indices.
typedef dnnl_dim_t dim;
/// Vector of dimensions. Implementations are free to force a limit on the
/// vector's length.
typedef std::vector<dim> dims;
/// Helper function that validates that an `std::vector` of dimensions can
/// be safely converted to the C API array ::dnnl_dims_t. Throws if
/// validation fails.
///
/// @param v Vector of dimensions.
/// @param min_size Minimum expected size of the vector.
template <typename T>
static void validate_dims(const std::vector<T> &v, int min_size = 0) {
validate_container_size(
v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
}
/// Data type specification.
enum class data_type {
/// Undefined data type (used for empty memory descriptors).
undef = dnnl_data_type_undef,
/// 16-bit/half-precision floating point.
f16 = dnnl_f16,
/// non-standard 16-bit floating point with 7-bit mantissa.
bf16 = dnnl_bf16,
/// 32-bit/single-precision floating point.
f32 = dnnl_f32,
/// 32-bit signed integer.
s32 = dnnl_s32,
/// 8-bit signed integer.
s8 = dnnl_s8,
/// 8-bit unsigned integer.
u8 = dnnl_u8,
};
/// Memory format kind
enum class format_kind {
/// Undefined memory format kind, used for empty memory descriptors.
undef = dnnl_format_kind_undef,
/// Unspecified format kind.
/// The primitive selects a format automatically.
any = dnnl_format_kind_any,
/// A tensor in a generic format described by the stride and blocking
/// values in each dimension. See @ref dnnl_blocking_desc_t for more
/// information.
blocked = dnnl_blocked,
/// Weights format used in 8bit Winograd convolution.
wino = dnnl_format_kind_wino,
/// Packed weights format used in RNN.
packed = dnnl_format_kind_rnn_packed,
};
/// Memory format tag specification.
///
/// Memory format tags can be further divided into two categories:
///
/// - Domain-agnostic names, i.e. names that do not depend on the tensor
/// usage in the specific primitive. These names use letters from `a`
/// to `l` to denote logical dimensions and form the order in which the
/// dimensions are laid in memory. For example,
/// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
/// second logical dimension (denoted as `b`) is the innermost, i.e.
/// has stride = 1, and the first logical dimension (`a`) is laid out in
/// memory with stride equal to the size of the second dimension. On the
/// other hand, #dnnl::memory::format_tag::ba is the transposed version
/// of the same tensor: the outermost dimension (`a`) becomes the
/// innermost one.
///
/// - Domain-specific names, i.e. names that make sense only in the
/// context of a certain domain, such as CNN. These names are
/// aliases to the corresponding domain-agnostic tags and used mostly
/// for convenience. For example, #dnnl::memory::format_tag::nc
/// is used to denote 2D CNN activations tensor memory format, where
/// the channels dimension is the innermost one and the batch dimension
/// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
/// an alias for #dnnl::memory::format_tag::ab, because for
/// CNN primitives the logical dimensions of activations tensors come
/// in order: batch, channels, spatial. In other words, batch
/// corresponds to the first logical dimension (`a`), and channels
/// correspond to the second one (`b`).
///
/// The following domain-specific notation applies to memory format tags:
/// - @c 'n' denotes the mini-batch dimension
/// - @c 'c' denotes a channels dimension
/// - When there are multiple channel dimensions (for example,
/// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
/// of input and output channels
/// - @c 'g' denotes a groups dimension for convolution weights
/// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
/// respectively
///
/// See @ref dnnl_format_tag_t for a detailed description.
enum class format_tag {
/// Undefined memory format tag
undef = dnnl_format_tag_undef,
/// Placeholder memory format tag. Used to instruct the primitive to
/// select a format automatically.
any = dnnl_format_tag_any,
/// plain 1D tensor
a = dnnl_a,
/// plain 2D tensor
ab = dnnl_ab,
/// permuted 2D tensor
ba = dnnl_ba,
/// plain 3D tensor
abc = dnnl_abc,
/// permuted 3D tensor
acb = dnnl_acb,
/// permuted 3D tensor
bac = dnnl_bac,
/// permuted 3D tensor
bca = dnnl_bca,
/// permuted 3D tensor
cba = dnnl_cba,
/// plain 4D tensor
abcd = dnnl_abcd,
/// permuted 4D tensor
abdc = dnnl_abdc,
/// permuted 4D tensor
acdb = dnnl_acdb,
/// permuted 4D tensor
bacd = dnnl_bacd,
/// permuted 4D tensor
bcda = dnnl_bcda,
/// permuted 4D tensor
cdba = dnnl_cdba,
/// permuted 4D tensor
dcab = dnnl_dcab,
/// plain 5D tensor
abcde = dnnl_abcde,
/// permuted 5D tensor
abdec = dnnl_abdec,
/// permuted 5D tensor
acbde = dnnl_acbde,
/// permuted 5D tensor
acdeb = dnnl_acdeb,
/// permuted 5D tensor
bcdea = dnnl_bcdea,
/// permuted 5D tensor
cdeba = dnnl_cdeba,
/// permuted 5D tensor
decab = dnnl_decab,
/// plain 6D tensor
abcdef = dnnl_abcdef,
/// permuted 6D tensor
acbdef = dnnl_acbdef,
/// permuted 6D tensor
defcab = dnnl_defcab,
/// 1D tensor; an alias for #dnnl::memory::format_tag::a
x = a,
/// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
nc = ab,
/// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
cn = ba,
/// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
tn = ab,
/// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
nt = ba,
/// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
ncw = abc,
/// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
nwc = acb,
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
nchw = abcd,
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
nhwc = acdb,
/// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
chwn = bcda,
/// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
ncdhw = abcde,
/// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
ndhwc = acdeb,
/// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
oi = ab,
/// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
io = ba,
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
oiw = abc,
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
owi = acb,
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
wio = cba,
/// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
iwo = bca,
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
oihw = abcd,
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
hwio = cdba,
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
ohwi = acdb,
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
ihwo = bcda,
/// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
iohw = bacd,
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
oidhw = abcde,
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
dhwio = cdeba,
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
odhwi = acdeb,
/// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
idhwo = bcdea,
/// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
goiw = abcd,
/// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
wigo = dcab,
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
goihw = abcde,
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
hwigo = decab,
/// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
giohw = acbde,
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
goidhw = abcdef,
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
giodhw = acbdef,
/// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
dhwigo = defcab,
/// 3D RNN data tensor in the format (seq_length, batch, input channels).
tnc = abc,
/// 3D RNN data tensor in the format (batch, seq_length, input channels).
ntc = bac,
/// 4D RNN states tensor in the format (num_layers, num_directions,
/// batch, state channels).
ldnc = abcd,
/// 5D RNN weights tensor in the format (num_layers, num_directions,
/// input_channels, num_gates, output_channels).
///
/// - For LSTM cells, the gates order is input, forget, candidate
/// and output gate.
/// - For GRU cells, the gates order is update, reset and output gate.
ldigo = abcde,
/// 5D RNN weights tensor in the format (num_layers, num_directions,
/// num_gates, output_channels, input_channels).
///
/// - For LSTM cells, the gates order is input, forget, candidate
/// and output gate.
/// - For GRU cells, the gates order is update, reset and output gate.
ldgoi = abdec,
/// 4D LSTM projection tensor in the format (num_layers, num_directions,
/// num_channels_in_hidden_state, num_channels_in_recurrent_projection).
ldio = abcd,
/// 4D LSTM projection tensor in the format (num_layers, num_directions,
/// num_channels_in_recurrent_projection, num_channels_in_hidden_state).
ldoi = abdc,
/// 4D RNN bias tensor in the format (num_layers, num_directions,
/// num_gates, output_channels).
///
/// - For LSTM cells, the gates order is input, forget, candidate
/// and output gate.
/// - For GRU cells, the gates order is update, reset and output gate.
ldgo = abcd,
// Opaque blocked formats
Abc16a = dnnl_Abc16a,
ABc16a16b = dnnl_ABc16a16b,
ABc4a4b = dnnl_ABc4a4b,
aBc16b = dnnl_aBc16b,
ABc16b16a = dnnl_ABc16b16a,
Abc4a = dnnl_Abc4a,
aBc4b = dnnl_aBc4b,
ABc4b16a4b = dnnl_ABc4b16a4b,
ABc2b8a4b = dnnl_ABc2b8a4b,
ABc4b4a = dnnl_ABc4b4a,
ABc8a16b2a = dnnl_ABc8a16b2a,
ABc8a8b = dnnl_ABc8a8b,
aBc8b = dnnl_aBc8b,
ABc8b16a2b = dnnl_ABc8b16a2b,
ABc8b8a = dnnl_ABc8b8a,
Abcd16a = dnnl_Abcd16a,
ABcd16a16b = dnnl_ABcd16a16b,
aBcd16b = dnnl_aBcd16b,
ABcd16b16a = dnnl_ABcd16b16a,
aBCd16b16c = dnnl_aBCd16b16c,
aBCd16c16b = dnnl_aBCd16c16b,
Abcd4a = dnnl_Abcd4a,
aBcd4b = dnnl_aBcd4b,
ABcd4b16a4b = dnnl_ABcd4b16a4b,
ABcd2b8a4b = dnnl_ABcd2b8a4b,
ABcd4b4a = dnnl_ABcd4b4a,
ABcd4a4b = dnnl_ABcd4a4b,
aBCd4c16b4c = dnnl_aBCd4c16b4c,
aBCd2c8b4c = dnnl_aBCd2c8b4c,
aBCd4c4b = dnnl_aBCd4c4b,
aBCd4b4c = dnnl_aBCd4b4c,
ABcd8a16b2a = dnnl_ABcd8a16b2a,
ABcd8a8b = dnnl_ABcd8a8b,
/// 4D tensor blocked by 2nd dimension with block size 8
aBcd8b = dnnl_aBcd8b,
ABcd8b16a2b = dnnl_ABcd8b16a2b,
aBCd8b16c2b = dnnl_aBCd8b16c2b,
/// 4D tensor blocked by 1st and 2nd dimension with block size 8
ABcd8b8a = dnnl_ABcd8b8a,
aBCd8b8c = dnnl_aBCd8b8c,
aBCd8c16b2c = dnnl_aBCd8c16b2c,
aBCd8c8b = dnnl_aBCd8c8b,
Abcde16a = dnnl_Abcde16a,
ABcde16a16b = dnnl_ABcde16a16b,
aBcde16b = dnnl_aBcde16b,
ABcde16b16a = dnnl_ABcde16b16a,
aBCde16b16c = dnnl_aBCde16b16c,
aBCde16c16b = dnnl_aBCde16c16b,
aBCde2c8b4c = dnnl_aBCde2c8b4c,
Abcde4a = dnnl_Abcde4a,
aBcde4b = dnnl_aBcde4b,
ABcde4b4a = dnnl_ABcde4b4a,
ABcde4a4b = dnnl_ABcde4a4b,
aBCde4b4c = dnnl_aBCde4b4c,
aBCde4c16b4c = dnnl_aBCde4c16b4c,
aBCde4c4b = dnnl_aBCde4c4b,
Abcde8a = dnnl_Abcde8a,
ABcde8a8b = dnnl_ABcde8a8b,
aBcde8b = dnnl_aBcde8b,
ABcde8b16a2b = dnnl_ABcde8b16a2b,
ABcde4b16a4b = dnnl_ABcde4b16a4b,
ABcde2b8a4b = dnnl_ABcde2b8a4b,
aBCde8b16c2b = dnnl_aBCde8b16c2b,
ABcde8b8a = dnnl_ABcde8b8a,
aBCde8b8c = dnnl_aBCde8b8c,
ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
aBCde8c16b2c = dnnl_aBCde8c16b2c,
aBCde8c8b = dnnl_aBCde8c8b,
aBcdef16b = dnnl_aBcdef16b,
aBCdef16b16c = dnnl_aBCdef16b16c,
aBCdef16c16b = dnnl_aBCdef16c16b,
aBcdef4b = dnnl_aBcdef4b,
aBCdef4c4b = dnnl_aBCdef4c4b,
aBCdef4b4c = dnnl_aBCdef4b4c,
aBCdef8b8c = dnnl_aBCdef8b8c,
aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
aBCdef8c8b = dnnl_aBCdef8c8b,
aBdc16b = dnnl_aBdc16b,
aBdc4b = dnnl_aBdc4b,
aBdc8b = dnnl_aBdc8b,
aBdec16b = dnnl_aBdec16b,
aBdec4b = dnnl_aBdec4b,
aBdec8b = dnnl_aBdec8b,
aBdefc16b = dnnl_aBdefc16b,
aCBdef16c16b = dnnl_aCBdef16c16b,
aCBdef16b16c = dnnl_aCBdef16b16c,
aBdefc4b = dnnl_aBdefc4b,
aBdefc8b = dnnl_aBdefc8b,
Acb16a = dnnl_Acb16a,
Acb4a = dnnl_Acb4a,
Acb8a = dnnl_Acb8a,
aCBd16b16c = dnnl_aCBd16b16c,
aCBd16c16b = dnnl_aCBd16c16b,
aCBde16b16c = dnnl_aCBde16b16c,
aCBde16c16b = dnnl_aCBde16c16b,
Acdb16a = dnnl_Acdb16a,
Acdb4a = dnnl_Acdb4a,
Acdb8a = dnnl_Acdb8a,
Acdeb16a = dnnl_Acdeb16a,
Acdeb4a = dnnl_Acdeb4a,
Acdeb8a = dnnl_Acdeb8a,
BAc16a16b = dnnl_BAc16a16b,
BAc16b16a = dnnl_BAc16b16a,
BAcd16a16b = dnnl_BAcd16a16b,
BAcd16b16a = dnnl_BAcd16b16a,
ABcd32a32b = dnnl_ABcd32a32b,
BAcde16b16a = dnnl_BAcde16b16a,
BAcde16a16b = dnnl_BAcde16a16b,
aBdec32b = dnnl_aBdec32b,
Abcdef16a = dnnl_Abcdef16a,
Acdb32a = dnnl_Acdb32a,
aBCd2b4c2b = dnnl_aBCd2b4c2b,
aBCde2b4c2b = dnnl_aBCde2b4c2b,
aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
aBCd2c4b2c = dnnl_aBCd2c4b2c,
aBCde2c4b2c = dnnl_aBCde2c4b2c,
aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
aBCd4b8c2b = dnnl_aBCd4b8c2b,
aBCde4b8c2b = dnnl_aBCde4b8c2b,
aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
aBCd4c8b2c = dnnl_aBCd4c8b2c,
aBCde4c8b2c = dnnl_aBCde4c8b2c,
aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
format_tag_last = dnnl_format_tag_last,
nCdhw16c = dnnl_nCdhw16c,
nCdhw4c = dnnl_nCdhw4c,
nCdhw8c = dnnl_nCdhw8c,
nChw16c = dnnl_nChw16c,
nChw4c = dnnl_nChw4c,
nChw8c = dnnl_nChw8c,
nCw16c = dnnl_nCw16c,
nCw4c = dnnl_nCw4c,
nCw8c = dnnl_nCw8c,
NCw16n16c = dnnl_NCw16n16c,
NChw16n16c = dnnl_NChw16n16c,
NCdhw16n16c = dnnl_NCdhw16n16c,
NChw32n32c = dnnl_NChw32n32c,
IOhw16i16o = dnnl_IOhw16i16o,
Ohwi32o = dnnl_Ohwi32o,
IOdhw16i16o = dnnl_IOdhw16i16o,
gIOhw16i16o = dnnl_gIOhw16i16o,
gOhwi32o = dnnl_gOhwi32o,
Goidhw16g = dnnl_Goidhw16g,
IOw16o16i = dnnl_IOw16o16i,
OIw16i16o = dnnl_OIw16i16o,
IOw16i16o = dnnl_IOw16i16o,
gIOw16i16o = dnnl_gIOw16i16o,
OIw16o16i = dnnl_OIw16o16i,
Oiw16o = dnnl_Oiw16o,
OIw4i16o4i = dnnl_OIw4i16o4i,
OIw2i8o4i = dnnl_OIw2i8o4i,
OIw4i4o = dnnl_OIw4i4o,
OIw4o4i = dnnl_OIw4o4i,
Oiw4o = dnnl_Oiw4o,
OIw8i16o2i = dnnl_OIw8i16o2i,
OIw8i8o = dnnl_OIw8i8o,
OIw8o16i2o = dnnl_OIw8o16i2o,
OIw8o8i = dnnl_OIw8o8i,
Owi16o = dnnl_Owi16o,
OwI16o2i = dnnl_OwI16o2i,
Owi4o = dnnl_Owi4o,
Owi8o = dnnl_Owi8o,
IOhw16o16i = dnnl_IOhw16o16i,
Ohwi16o = dnnl_Ohwi16o,
OhwI16o2i = dnnl_OhwI16o2i,
Ohwi4o = dnnl_Ohwi4o,
Ohwi8o = dnnl_Ohwi8o,
OIhw16i16o = dnnl_OIhw16i16o,
OIhw16o16i = dnnl_OIhw16o16i,
Oihw16o = dnnl_Oihw16o,
OIhw4i16o4i = dnnl_OIhw4i16o4i,
OIhw4i4o = dnnl_OIhw4i4o,
OIhw4o4i = dnnl_OIhw4o4i,
Oihw4o = dnnl_Oihw4o,
OIhw8i16o2i = dnnl_OIhw8i16o2i,
OIhw8i8o = dnnl_OIhw8i8o,
OIhw8o16i2o = dnnl_OIhw8o16i2o,
OIhw8o8i = dnnl_OIhw8o8i,
OIhw2i8o4i = dnnl_OIhw2i8o4i,
IOdhw16o16i = dnnl_IOdhw16o16i,
Odhwi16o = dnnl_Odhwi16o,
OdhwI16o2i = dnnl_OdhwI16o2i,
Odhwi4o = dnnl_Odhwi4o,
Odhwi8o = dnnl_Odhwi8o,
OIdhw16i16o = dnnl_OIdhw16i16o,
OIdhw16o16i = dnnl_OIdhw16o16i,
Oidhw16o = dnnl_Oidhw16o,
OIdhw4i4o = dnnl_OIdhw4i4o,
OIdhw4o4i = dnnl_OIdhw4o4i,
Oidhw4o = dnnl_Oidhw4o,
OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
OIdhw8i8o = dnnl_OIdhw8i8o,
OIdhw8o8i = dnnl_OIdhw8o8i,
gIOw16o16i = dnnl_gIOw16o16i,
gOIw16i16o = dnnl_gOIw16i16o,
gOIw16o16i = dnnl_gOIw16o16i,
gOiw16o = dnnl_gOiw16o,
gOIw4i16o4i = dnnl_gOIw4i16o4i,
gOIw2i8o4i = dnnl_gOIw2i8o4i,
gOIw4i4o = dnnl_gOIw4i4o,
gOIw4o4i = dnnl_gOIw4o4i,
gOiw4o = dnnl_gOiw4o,
gOIw8i16o2i = dnnl_gOIw8i16o2i,
gOIw8i8o = dnnl_gOIw8i8o,
gOIw8o16i2o = dnnl_gOIw8o16i2o,
gOIw8o8i = dnnl_gOIw8o8i,
gOwi16o = dnnl_gOwi16o,
gOwI16o2i = dnnl_gOwI16o2i,
gOwi4o = dnnl_gOwi4o,
gOwi8o = dnnl_gOwi8o,
Goiw8g = dnnl_Goiw8g,
Goiw16g = dnnl_Goiw16g,
gIOhw16o16i = dnnl_gIOhw16o16i,
gOhwi16o = dnnl_gOhwi16o,
gOhwI16o2i = dnnl_gOhwI16o2i,
gOhwi4o = dnnl_gOhwi4o,
gOhwi8o = dnnl_gOhwi8o,
Goihw16g = dnnl_Goihw16g,
gOIhw16i16o = dnnl_gOIhw16i16o,
gOIhw16o16i = dnnl_gOIhw16o16i,
gOihw16o = dnnl_gOihw16o,
gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
gOIhw4i4o = dnnl_gOIhw4i4o,
gOIhw4o4i = dnnl_gOIhw4o4i,
gOihw4o = dnnl_gOihw4o,
Goihw8g = dnnl_Goihw8g,
gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
gOIhw8i8o = dnnl_gOIhw8i8o,
gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
gOIhw8o8i = dnnl_gOIhw8o8i,
gIOdhw16i16o = dnnl_gIOdhw16i16o,
gIOdhw16o16i = dnnl_gIOdhw16o16i,
gOdhwi16o = dnnl_gOdhwi16o,
gOdhwI16o2i = dnnl_gOdhwI16o2i,
gOdhwi4o = dnnl_gOdhwi4o,
gOdhwi8o = dnnl_gOdhwi8o,
gOIdhw16i16o = dnnl_gOIdhw16i16o,
gOIdhw16o16i = dnnl_gOIdhw16o16i,
gOidhw16o = dnnl_gOidhw16o,
gOIdhw4i4o = dnnl_gOIdhw4i4o,
gOIdhw4o4i = dnnl_gOIdhw4o4i,
gOidhw4o = dnnl_gOidhw4o,
gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
gOIdhw8i8o = dnnl_gOIdhw8i8o,
gOIdhw8o8i = dnnl_gOIdhw8o8i,
gOIw2i4o2i = dnnl_gOIw2i4o2i,
gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
gOIw2o4i2o = dnnl_gOIw2o4i2o,
gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
gOIw4i8o2i = dnnl_gOIw4i8o2i,
gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
gOIw4o8i2o = dnnl_gOIw4o8i2o,
gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
};
/// A memory descriptor.
struct desc {
friend struct memory;
/// The underlying C API data structure.
dnnl_memory_desc_t data;
/// Constructs a zero (empty) memory descriptor. Such a memory
/// descriptor can be used to indicate absence of an argument.
desc() : data() {}
/// Constructs a memory descriptor.
///
/// @note
/// The logical order of dimensions corresponds to the `abc...`
/// format tag, and the physical meaning of the dimensions depends
/// both on the primitive that would operate on this memory and
/// the operation context.
///
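/// A minimal usage sketch (the sizes below are illustrative only):
/// describing a 32x64x8x8 f32 tensor laid out as NCHW.
/// @code
/// dnnl::memory::dims conv_src_dims = {32, 64, 8, 8}; // {N, C, H, W}
/// dnnl::memory::desc conv_src_md(conv_src_dims,
///         dnnl::memory::data_type::f32,
///         dnnl::memory::format_tag::nchw);
/// @endcode
///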
/// @param dims Tensor dimensions.
/// @param data_type Data precision/type.
/// @param format_tag Memory format tag.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case a
/// zero memory descriptor will be constructed. This flag is
/// optional and defaults to false.
desc(const memory::dims &dims, data_type data_type,
format_tag format_tag, bool allow_empty = false)
: data() {
validate_dims(dims);
dnnl_status_t status = dnnl_memory_desc_init_by_tag(&data,
(int)dims.size(), dims.data(), convert_to_c(data_type),
convert_to_c(format_tag));
if (!allow_empty)
error::wrap_c_api(status,
"could not construct a memory descriptor using a "
"format tag");
}
/// Constructs a memory descriptor by strides.
///
/// @note
/// The logical order of dimensions corresponds to the `abc...`
/// format tag, and the physical meaning of the dimensions depends
/// both on the primitive that would operate on this memory and
/// the operation context.
///
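/// A minimal usage sketch (illustrative values): the strides below
/// describe the same dense layout that the `nchw` format tag would.
/// @code
/// dnnl::memory::dims dims = {32, 64, 8, 8}; // {N, C, H, W}
/// dnnl::memory::dims strides = {64 * 8 * 8, 8 * 8, 8, 1};
/// dnnl::memory::desc md(dims, dnnl::memory::data_type::f32, strides);
/// @endcode
///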
/// @param dims Tensor dimensions.
/// @param data_type Data precision/type.
/// @param strides Strides for each dimension.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case a
/// zero memory descriptor will be constructed. This flag is
/// optional and defaults to false.
desc(const memory::dims &dims, data_type data_type,
const memory::dims &strides, bool allow_empty = false)
: data() {
validate_dims(dims);
if (!strides.empty()) validate_dims(strides, (int)dims.size());
dnnl_status_t status = dnnl_memory_desc_init_by_strides(&data,
(int)dims.size(), dims.data(), convert_to_c(data_type),
strides.empty() ? nullptr : &strides[0]);
if (!allow_empty)
error::wrap_c_api(status,
"could not construct a memory descriptor using "
"strides");
}
/// Constructs a memory descriptor from a C API data structure.
///
/// @param data A C API ::dnnl_memory_desc_t structure.
desc(const dnnl_memory_desc_t &data) : data(data) {}
/// Constructs a memory descriptor for a region inside an area
/// described by this memory descriptor.
///
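/// A minimal sketch: selecting a 2x8 region that starts at row 1 of a
/// 4x8 tensor (names and sizes are illustrative only).
/// @code
/// dnnl::memory::desc src_md({4, 8}, dnnl::memory::data_type::f32,
///         dnnl::memory::format_tag::ab);
/// dnnl::memory::desc sub_md = src_md.submemory_desc({2, 8}, {1, 0});
/// @endcode
///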
/// @param dims Sizes of the region.
/// @param offsets Offsets to the region from the encompassing
/// memory object in each dimension.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case a
/// zero memory descriptor will be returned. This flag is optional
/// and defaults to false.
/// @returns A memory descriptor for the region.
desc submemory_desc(const memory::dims &dims,
const memory::dims &offsets, bool allow_empty = false) const {
validate_dims(dims, data.ndims);
validate_dims(offsets, data.ndims);
dnnl_memory_desc_t sub_md = dnnl_memory_desc_t();
dnnl_status_t status = dnnl_memory_desc_init_submemory(
&sub_md, &data, dims.data(), offsets.data());
if (!allow_empty)
error::wrap_c_api(status, "could not construct a sub-memory");
return desc(sub_md);
}
/// Constructs a memory descriptor by reshaping an existing one. The
/// new memory descriptor inherits the data type. This operation is
/// valid only for memory descriptors that have format_kind set to
/// #dnnl::memory::format_kind::blocked or
/// #dnnl::memory::format_kind::any.
///
/// The operation ensures that the transformation of the physical memory
/// format corresponds to the transformation of the logical dimensions.
/// If such transformation is impossible, the function either throws an
/// exception (default) or returns a zero memory descriptor depending on
/// the `allow_empty` flag.
///
/// The reshape operation can be described as a combination of the
/// following basic operations:
/// 1. Add a dimension of size `1`. This is always possible.
/// 2. Remove a dimension of size `1`. This is possible only if the
/// dimension has no padding (i.e.
/// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
/// 3. Split a dimension into multiple ones. This is possible only if
/// the size of the dimension is exactly equal to the product of the
/// split ones and the dimension does not have padding (i.e.
/// `padded_dims[dim] == dims[dim]`).
/// 4. Joining multiple consecutive dimensions into a single one. As in
/// the cases above, this requires that the dimensions do not have
/// padding and that the memory format is such that in physical
/// memory these dimensions are dense and have the same order as
/// their logical counterparts. This also assumes that these
/// dimensions are not blocked.
/// - Here, dense means:
/// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
/// - And same order means:
/// `i < j <=> stride for dim[i] < stride for dim[j]`.
///
/// @warning
/// Some combinations of physical memory layout and/or offsets or
/// dimensions may result in a failure to make a reshape.
///
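/// A minimal sketch: joining the two innermost dimensions of a dense,
/// plain-layout tensor (the product of dimensions stays the same;
/// values are illustrative only).
/// @code
/// dnnl::memory::desc md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
///         dnnl::memory::format_tag::abcd);
/// dnnl::memory::desc reshaped_md = md.reshape({2, 3, 20});
/// @endcode
///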
/// @param dims New dimensions. The product of dimensions must
/// remain constant.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case a
/// zero memory descriptor will be returned. This flag is optional
/// and defaults to false.
/// @returns A new memory descriptor with new dimensions.
desc reshape(const memory::dims &dims, bool allow_empty = false) const {
if (data.ndims) validate_dims(dims, 1);
dnnl_memory_desc_t out_md = dnnl_memory_desc_t();
dnnl_status_t status = dnnl_memory_desc_reshape(
&out_md, &data, (int)dims.size(), dims.data());
if (!allow_empty)
error::wrap_c_api(
status, "could not reshape a memory descriptor");
return desc(out_md);
}
/// Constructs a memory descriptor by permuting axes in an existing
/// one.
///
/// The physical memory layout representation is adjusted accordingly
/// to maintain the consistency between the logical and physical parts
/// of the memory descriptor.
///
/// The new memory descriptor inherits the data type. This operation is
/// valid only for memory descriptors that have format_kind set to
/// #dnnl::memory::format_kind::blocked or
/// #dnnl::memory::format_kind::any.
///
/// The logical axes will be permuted in the following manner:
/// ```
/// for (i: 0 .. ndims())
/// new_desc.dims()[permutation[i]] = dims()[i];
/// ```
///
/// Example:
/// @code
/// std::vector<int> permutation = {1, 0}; // swap the first and
/// // the second axes
/// dnnl::memory::desc in_md(
/// {2, 3}, data_type, memory::format_tag::ab);
/// dnnl::memory::desc expect_out_md(
/// {3, 2}, data_type, memory::format_tag::ba);
///
/// assert(in_md.permute_axes(permutation) == expect_out_md);
/// @endcode
///
/// @param permutation Axes permutation.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case a
/// zero memory descriptor will be returned. This flag is optional
/// and defaults to false.
/// @returns A new memory descriptor with new dimensions.
desc permute_axes(const std::vector<int> &permutation,
bool allow_empty = false) const {
validate_dims(permutation, data.ndims);
dnnl_memory_desc_t out_md = dnnl_memory_desc_t();
dnnl_status_t status = dnnl_memory_desc_permute_axes(
&out_md, &data, permutation.data());
if (!allow_empty)
error::wrap_c_api(status,
"could not permute axes of a memory descriptor");
return desc(out_md);
}
/// Returns dimensions of the memory descriptor.
///
/// Potentially expensive due to the data copy involved.
/// @returns A copy of the dimensions vector.
memory::dims dims() const {
return memory::dims(data.dims, data.dims + data.ndims);
}
/// Returns the data type of the memory descriptor.
/// @returns The data type.
memory::data_type data_type() const {
return static_cast<memory::data_type>(data.data_type);
}
/// Returns size of the memory descriptor in bytes.
/// @returns The number of bytes required to allocate a memory buffer
/// for the memory object described by this memory descriptor
/// including the padding area.
size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
/// Checks whether the memory descriptor is zero (empty).
/// @returns @c true if the memory descriptor describes an empty
/// memory and @c false otherwise.
bool is_zero() const { return data.ndims == 0; }
/// An equality operator.
/// @param other Another memory descriptor.
/// @returns Whether this and the other memory descriptors have
/// the same format tag, dimensions, strides, blocking, etc.
bool operator==(const desc &other) const {
return dnnl_memory_desc_equal(&data, &other.data) != 0;
}
/// An inequality operator.
/// @param other Another memory descriptor.
/// @returns Whether this and the other memory descriptors describe
/// different memory.
bool operator!=(const desc &other) const { return !operator==(other); }
};
/// Default constructor.
///
/// Constructs an empty memory object, which can be used to indicate
/// absence of a parameter.
memory() = default;
/// Constructs a memory object.
///
/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
/// object will have the underlying buffer set. In this case, the buffer
/// will be initialized as if #dnnl::memory::set_data_handle() had been
/// called.
///
/// @sa memory::set_data_handle()
///
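/// A minimal sketch (assuming `md` and `eng` are an already constructed
/// memory descriptor and CPU engine):
/// @code
/// // memory backed by a user-allocated buffer
/// std::vector<float> user_buf(md.get_size() / sizeof(float));
/// dnnl::memory user_mem(md, eng, user_buf.data());
///
/// // memory with a library-allocated buffer
/// dnnl::memory lib_mem(md, eng, DNNL_MEMORY_ALLOCATE);
/// @endcode
///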
/// @param md Memory descriptor.
/// @param engine Engine to store the data on.
/// @param handle Handle of the memory buffer to use as an underlying
/// storage.
/// - A pointer to the user-allocated buffer. In this case the library
/// doesn't own the buffer.
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
/// allocate the buffer for the memory object. In this case the
/// library owns the buffer.
/// - DNNL_MEMORY_NONE to create dnnl_memory without an underlying
/// buffer.
memory(const desc &md, const engine &engine, void *handle) {
dnnl_memory_t result;
error::wrap_c_api(
dnnl_memory_create(&result, &md.data, engine.get(), handle),
"could not create a memory object");
reset(result);
}
/// Constructs a memory object.
///
/// The underlying storage for the memory will be allocated by the library.
///
/// @param md Memory descriptor.
/// @param engine Engine to store the data on.
memory(const desc &md, const engine &engine)
: memory(md, engine, DNNL_MEMORY_ALLOCATE) {}
/// Returns the associated memory descriptor.
desc get_desc() const {
const dnnl_memory_desc_t *cdesc;
error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
"could not get a memory descriptor from a memory object");
return desc(*cdesc);
}
/// Returns the associated engine.
engine get_engine() const {
dnnl_engine_t c_engine;
error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
"could not get an engine from a memory object");
return engine(c_engine, true);
}
/// Returns the underlying memory buffer.
///
/// On the CPU engine this is a pointer to the allocated memory.
void *get_data_handle() const {
void *handle;
error::wrap_c_api(dnnl_memory_get_data_handle(get(), &handle),
"could not get a native handle from a memory object");
return handle;
}
/// Sets data handle.
///
/// This function may write zero values to the memory specified by the @p
/// handle if the memory object has a zero padding area. This may be time
/// consuming and happens each time this function is called. The
/// operation is always blocking and the stream parameter is a hint.
///
/// @note
/// The zero padding is required by memory objects created with
/// blocked memory format tags like #dnnl_aBcd8b when any of the
/// dimensions is not a multiple of the corresponding block size. For
/// "plain" formats like #dnnl::memory::format_tag::nchw or
/// #dnnl::memory::format_tag::nhwc zero padding area needs to be set
/// up explicitly when creating the corresponding memory descriptors.
/// See @ref dev_guide_understanding_memory_formats for more details.
///
/// @note
/// Even when the memory object is used to hold values that stay
/// constant during the execution of the program (pre-packed weights
/// during inference, for example), the function will still write
/// zeroes to the padding area if it exists. Hence, the @p handle
/// parameter cannot and does not have a const qualifier.
///
/// @param handle Data handle. For the CPU engine, the data handle
/// is a pointer to the actual data. For OpenCL it is a cl_mem.
/// @param stream Stream to use to execute padding in.
void set_data_handle(void *handle, const stream &stream) const {
error::wrap_c_api(
dnnl_memory_set_data_handle_v2(get(), handle, stream.get(true)),
"could not set native handle of a memory object");
}
/// Sets data handle.
///
/// See documentation for
/// #dnnl::memory::set_data_handle(void *, const stream &) const
/// for more information.
///
/// @param handle Data handle. For the CPU engine, the data handle
/// is a pointer to the actual data. For OpenCL it is a cl_mem.
void set_data_handle(void *handle) const {
error::wrap_c_api(
dnnl_memory_set_data_handle_v2(get(), handle, nullptr),
"could not set native handle of a memory object");
}
/// Maps a memory object and returns a host-side pointer to a memory
/// buffer with a copy of its contents.
///
/// Mapping enables read/write directly from/to the memory contents for
/// engines that do not support direct memory access.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until it is unmapped via memory::unmap_data() call.
///
/// @note
/// Any primitives working with the memory should be completed before
/// the memory is mapped. Use stream::wait() to synchronize the
/// corresponding execution stream.
///
/// @note
/// The map_data and unmap_data functions are provided mainly for
/// debug and testing purposes and their performance may be suboptimal.
///
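/// A minimal sketch (assuming `mem` is a memory object and all
/// primitives writing to it have already completed):
/// @code
/// float *ptr = mem.map_data<float>();
/// ptr[0] = 1.0f; // read or write through the host-side pointer
/// mem.unmap_data(ptr);
/// @endcode
///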
/// @tparam T Data type to return a pointer to.
/// @returns Pointer to the mapped memory.
template <typename T = void>
T *map_data() const {
void *mapped_ptr;
error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
"could not map memory object data");
return static_cast<T *>(mapped_ptr);
}
/// Unmaps a memory object and writes back any changes made to the
/// previously mapped memory buffer. The pointer to the mapped buffer
/// must be obtained via the dnnl::memory::map_data() call.
///
/// @note
/// The map_data and unmap_data functions are provided mainly for
/// debug and testing purposes and their performance may be suboptimal.
///
/// @param mapped_ptr A pointer previously returned by map_data().
void unmap_data(void *mapped_ptr) const {
error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
"could not unmap memory object data");
}
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
/// Returns the OpenCL memory object associated with the memory.
cl_mem get_ocl_mem_object() const {
cl_mem mem_object;
error::wrap_c_api(dnnl_memory_get_ocl_mem_object(get(), &mem_object),
"could not get OpenCL buffer object from a memory object");
return mem_object;
}
/// Sets the OpenCL memory object @p mem_object associated with the memory.
///
/// For behavioral details see memory::set_data_handle().
///
/// @param mem_object OpenCL cl_mem object to use as the underlying
/// storage. It must have at least get_desc().get_size() bytes
/// allocated.
void set_ocl_mem_object(cl_mem mem_object) {
error::wrap_c_api(dnnl_memory_set_ocl_mem_object(get(), mem_object),
"could not set OpenCL buffer object from a memory object");
}
#endif
static dnnl_data_type_t convert_to_c(data_type data_type) {
return static_cast<dnnl_data_type_t>(data_type);
}
static dnnl_format_tag_t convert_to_c(format_tag format) {
return static_cast<dnnl_format_tag_t>(format);
}
};
inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
return a == memory::convert_to_c(b);
}
inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
return !(a == b);
}
inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
return b == a;
}
inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
return !(a == b);
}
inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
return a == memory::convert_to_c(b);
}
inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
return !(a == b);
}
inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
return b == a;
}
inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
return !(a == b);
}
/// @} dnnl_api_memory
/// @addtogroup dnnl_api_primitives
/// @{
/// @addtogroup dnnl_api_attributes Attributes
///
/// A container for parameters that extend primitives behavior.
///
/// @{
/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_post_ops_t> {
static dnnl_status_t destructor(dnnl_post_ops_t p) {
return dnnl_post_ops_destroy(p);
}
};
/// @endcond
/// Post-ops.
///
/// Post-ops are computations executed after the main primitive computations
/// and are attached to the primitive via primitive attributes.
///
/// @sa @ref dev_guide_attributes_post_ops
///
struct post_ops : public handle<dnnl_post_ops_t> {
using handle<dnnl_post_ops_t>::handle;
/// Constructs an empty sequence of post-ops.
post_ops() {
dnnl_post_ops_t result;
error::wrap_c_api(
dnnl_post_ops_create(&result), "could not create post-ops");
reset(result);
}
/// Returns the number of post-ops entries.
int len() const { return dnnl_post_ops_len(get()); }
/// Returns the primitive kind of post-op at entry with a certain index.
/// @param index Index of the post-op to return the kind for.
/// @returns Primitive kind of the post-op at the specified index.
primitive::kind kind(int index) const {
error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
"post-ops index is out of range");
return static_cast<primitive::kind>(
dnnl_post_ops_get_kind(get(), index));
}
/// Appends an accumulation (sum) post-op. Prior to accumulating the
/// result, the previous value would be multiplied by a scaling factor
/// @p scale.
///
/// The kind of this post-op is #dnnl::primitive::kind::sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the
/// previously computed activations. The parameter @p scale may be used
/// for the integer-based computations when the result and previous
/// activations have different logical scaling factors.
///
/// In the simplest case when the accumulation is the only post-op,
/// the computations would be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param scale Scaling factor.
void append_sum(float scale = 1.) {
error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale),
"could not append a sum post-op");
}
/// Returns the parameters of an accumulation (sum) post-op.
///
/// @param index Index of the sum post-op.
/// @param scale Scaling factor of the sum post-op.
void get_params_sum(int index, float &scale) const {
error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale),
"could not get parameters of a sum post-op");
}
/// Appends an elementwise post-op.
///
/// The kind of this post-op is #dnnl::primitive::kind::eltwise.
///
/// In the simplest case when the elementwise is the only post-op, the
/// computations would be:
///
/// dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
///
/// where eltwise_op is configured with the given parameters.
///
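/// A minimal sketch: fusing a ReLU activation after the main computation.
/// @code
/// dnnl::post_ops po;
/// po.append_eltwise(1.f, dnnl::algorithm::eltwise_relu,
///         0.f /* alpha */, 0.f /* beta */);
/// @endcode
///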
/// @param scale Scaling factor.
/// @param algorithm Elementwise algorithm.
/// @param alpha Alpha parameter for the elementwise algorithm.
/// @param beta Beta parameter for the elementwise algorithm.
void append_eltwise(
float scale, algorithm algorithm, float alpha, float beta) {
error::wrap_c_api(dnnl_post_ops_append_eltwise(get(), scale,
convert_to_c(algorithm), alpha, beta),
"could not append an elementwise post-op");
}
/// Returns parameters of an elementwise post-op.
///
/// @param index Index of the post-op.
/// @param scale Output scaling factor.
/// @param algorithm Output elementwise algorithm kind.
/// @param alpha Output alpha parameter for the elementwise algorithm.
/// @param beta Output beta parameter for the elementwise algorithm.
void get_params_eltwise(int index, float &scale, algorithm &algorithm,
float &alpha, float &beta) const {
dnnl_alg_kind_t c_alg;
error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
get(), index, &scale, &c_alg, &alpha, &beta),
"could not get parameters of an elementwise post-op");
algorithm = static_cast<dnnl::algorithm>(c_alg);
}
/// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (a convolution
/// whose weights spatial dimensions are equal to 1, i.e., kh = kw = 1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of primitive outputs remains the same as before. The output
/// spatial size also remains the same as for the original primitive because
/// the stride is 1.
///
/// The Post-op can be defined as:
///
/// dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
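/// A minimal sketch (the data types below are illustrative; a common
/// int8 fusion with a single output scale is assumed):
/// @code
/// dnnl::post_ops po;
/// po.append_dw_k3s1p1(dnnl::memory::data_type::s8,
///         dnnl::memory::data_type::f32, dnnl::memory::data_type::u8,
///         0 /* mask */, {1.f});
/// @endcode
///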
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the
/// @p scales array. The set i-th bit indicates that a dedicated output
/// scaling factor is used for each index along that dimension. The mask
/// value of 0 implies a common scaling factor for the whole output
/// tensor.
/// @param scales Output pointer to a constant array of float scaling
/// factors.
void append_dw_k3s1p1(memory::data_type weights_data_type,
memory::data_type bias_data_type, memory::data_type dst_data_type,
int mask, const std::vector<float> &scales) {
error::wrap_c_api(dnnl_post_ops_append_dw_k3s1p1(get(),
memory::convert_to_c(weights_data_type),
memory::convert_to_c(bias_data_type),
memory::convert_to_c(dst_data_type),
scales.size(), mask, &scales[0]),
"could not append depthwise post-op");
}
/// Returns the parameters of a depthwise post-op with stride 1.
///
/// @param index Index of the elementwise post-op.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the
/// @p scales array. The set i-th bit indicates that a dedicated output
/// scaling factor is used for each index along that dimension. The mask
/// value of 0 implies a common scaling factor for the whole output
/// tensor.
/// @param scales Output pointer to a constant array of float scaling
/// factors.
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
memory::data_type &bias_data_type, memory::data_type &dst_data_type,
int &mask, std::vector<float> &scales) const {
dnnl_data_type_t c_weights_data_type;
dnnl_data_type_t c_bias_data_type;
dnnl_data_type_t c_dst_data_type;
dnnl_dim_t count;
int c_mask;
const float *c_scales;
error::wrap_c_api(dnnl_post_ops_get_params_dw_k3s1p1(get(), index,
&c_weights_data_type, &c_bias_data_type,
&c_dst_data_type, &count, &c_mask, &c_scales),
"could not get parameters of depthwise post-op");
weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
scales.resize(count);
mask = c_mask;
for (dnnl_dim_t c = 0; c < count; ++c)
scales[c] = c_scales[c];
return;
}
/// Appends a depthwise post-op convolution with stride 2.
///
/// This post-op can only be fused with a 2D 1x1 convolution (a convolution
/// whose weights spatial dimensions are equal to 1, i.e., kh = kw = 1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of primitive outputs remains the same as before. The output
/// spatial size can be derived as follows:
///
/// output_height = ceil(output_height_1x1_convolution / stride)
/// output_width = ceil(output_width_1x1_convolution / stride)
///
/// The Post-op can be defined as:
///
/// dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the
/// @p scales array. The set i-th bit indicates that a dedicated output
/// scaling factor is used for each index along that dimension. The mask
/// value of 0 implies a common scaling factor for the whole output
/// tensor.
/// @param scales Output pointer to a constant array of float scaling
/// factors.
void append_dw_k3s2p1(memory::data_type weights_data_type,
memory::data_type bias_data_type, memory::data_type dst_data_type,
int mask, const std::vector<float> &scales) {
error::wrap_c_api(dnnl_post_ops_append_dw_k3s2p1(get(),
memory::convert_to_c(weights_data_type),
memory::convert_to_c(bias_data_type),
memory::convert_to_c(dst_data_type),
scales.size(), mask, &scales[0]),
"could not append depthwise post-op");
}
/// Returns the parameters of a depthwise post-op with stride 2.
///
/// @param index Index of the elementwise post-op.
/// @param weights_data_type Weights data type of depthwise post-op
/// @param bias_data_type Bias data type of depthwise post-op
/// @param dst_data_type Output data type of depthwise post-op
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the
/// @p scales array. The set i-th bit indicates that a dedicated output
/// scaling factor is used for each index along that dimension. The mask
/// value of 0 implies a common scaling factor for the whole output
/// tensor.
/// @param scales Output pointer to a constant array of float scaling
/// factors.
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
memory::data_type &bias_data_type, memory::data_type &dst_data_type,
int &mask, std::vector<float> &scales) const {
dnnl_data_type_t c_weights_data_type;
dnnl_data_type_t c_bias_data_type;
dnnl_data_type_t c_dst_data_type;
dnnl_dim_t count;
int c_mask;
const float *c_scales;
error::wrap_c_api(dnnl_post_ops_get_params_dw_k3s2p1(get(), index,
&c_weights_data_type, &c_bias_data_type,
&c_dst_data_type, &count, &c_mask, &c_scales),
"could not get parameters of depthwise post-op");
weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
scales.resize(count);
mask = c_mask;
for (dnnl_dim_t c = 0; c < count; ++c)
scales[c] = c_scales[c];
return;
}
};
/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_primitive_attr_t> {
static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
return dnnl_primitive_attr_destroy(p);
}
};
/// @endcond
/// Primitive attributes
///
/// @sa @ref dev_guide_attributes
struct primitive_attr : public handle<dnnl_primitive_attr_t> {
using handle<dnnl_primitive_attr_t>::handle;
/// Constructs default (empty) primitive attributes.
primitive_attr() {
dnnl_primitive_attr_t result;
error::wrap_c_api(dnnl_primitive_attr_create(&result),
"could not create primitive attribute");
reset(result);
}
/// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
/// handle. The resulting handle is not weak and the C handle will be
/// destroyed during the destruction of the C++ object.
///
/// @param attr The C API primitive attributes.
primitive_attr(dnnl_primitive_attr_t attr)
: handle<dnnl_primitive_attr_t>(attr) {}
/// Returns the scratchpad mode.
scratchpad_mode get_scratchpad_mode() const {
dnnl_scratchpad_mode_t result;
error::wrap_c_api(
dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
"could not get scratchpad mode primitive attribute");
return scratchpad_mode(result);
}
/// Sets scratchpad mode.
///
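/// A minimal sketch: asking primitives to use a user-provided scratchpad
/// (the scratchpad memory is then passed at execution time via
/// DNNL_ARG_SCRATCHPAD):
/// @code
/// dnnl::primitive_attr attr;
/// attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
/// @endcode
///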
/// @param mode Specified scratchpad mode.
void set_scratchpad_mode(scratchpad_mode mode) {
error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
get(), dnnl::convert_to_c(mode)),
"could not set scratchpad mode primitive attribute");
}
/// Returns output scaling factors correspondence mask and values.
///
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated output
/// scaling factor is used for each index along that dimension. The
/// mask value of 0 implies a common output scaling factor for the
/// whole output tensor.
/// @param scales Vector of output scaling factors.
void get_output_scales(int &mask, std::vector<float> &scales) const {
dnnl_dim_t count;
int c_mask;
const float *c_scales;
error::wrap_c_api(dnnl_primitive_attr_get_output_scales(
get(), &count, &c_mask, &c_scales),
"could not get output scales primitive attribute");
scales.resize(count);
mask = c_mask;
for (dnnl_dim_t c = 0; c < count; ++c)
scales[c] = c_scales[c];
}
/// Sets output scaling factors correspondence mask and values.
///
/// @note
/// The order of dimensions does not depend on how elements are laid
/// out in memory. For example:
/// - for a 2D CNN activations tensor the order is always (n, c)
/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
/// - for a 5D CNN weights tensor the order is always
/// (g, oc, ic, kh, kw)
///
/// Example usage:
/// @code
/// int mb = 32, oc = 32,
/// oh = 14, ow = 14; // convolution output params
/// // unique output scales per output channel
/// vector<float> scales = { ... };
/// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
///
/// // construct a convolution descriptor
/// dnnl::convolution_forward::desc conv_d(/* ... */);
///
/// dnnl::primitive_attr attr;
/// attr.set_output_scales(1 << oc_dim, scales);
///
/// dnnl::convolution_forward::primitive_desc conv_pd(conv_d, attr, engine);
/// @endcode
///
/// @param mask Defines the correspondence between the output tensor
/// dimensions and the @p scales vector. The set i-th bit indicates
/// that a dedicated scaling factor is used for each index along that
/// dimension. Set the mask to 0 to use a common output scaling factor
/// for the whole output tensor.
/// @param scales Constant vector of output scaling factors. If the
/// scaling factors are known at the time of this call, the following
/// equality must hold:
/// \f[scales.size() = \prod\limits_{d \in mask} output.dims[d].\f]
/// Violations can only be detected when the attributes
/// are used to create a primitive descriptor.
/// If the scaling factors are not known at the time of the call,
/// this vector must contain a single #DNNL_RUNTIME_F32_VAL value and
/// the output scaling factors must be passed at execution time as an
/// argument with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
void set_output_scales(int mask, const std::vector<float> &scales) {
error::wrap_c_api(
dnnl_primitive_attr_set_output_scales(
get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
"could not set output scales primitive attribute");
}
/// Returns scaling factors correspondence mask and values for a given
/// memory argument.
///
/// @param arg Parameter argument index as passed to the
/// primitive::execute() call.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated scaling
/// factor is used for each index along that dimension. Set the mask to
/// 0 to use a common scaling factor for the whole output tensor.
/// @param scales Output vector of scaling factors.
void get_scales(int arg, int &mask, std::vector<float> &scales) const {
dnnl_dim_t count;
int c_mask;
const float *c_scales;
error::wrap_c_api(dnnl_primitive_attr_get_scales(
get(), arg, &count, &c_mask, &c_scales),
"could not get scales primitive attributes");
scales.resize(count);
mask = c_mask;
for (dnnl_dim_t c = 0; c < count; ++c)
scales[c] = c_scales[c];
}
/// Sets scaling factors for primitive operations for a given memory
/// argument.
///
/// @sa dnnl_primitive_attr_set_scales
/// @sa dnnl::primitive_attr::set_output_scales
///
/// @param arg Parameter argument index as passed to the
/// primitive::execute() call.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales
/// vector. The set i-th bit indicates that a dedicated scaling factor
/// is used for each index along that dimension. Set the mask to 0 to
/// use a common scaling factor for the whole output tensor.
/// @param scales Constant vector of scaling factors. The following equality
/// must hold:
/// \f[scales.size() = \prod\limits_{d \in mask} argument.dims[d].\f]
void set_scales(int arg, int mask, const std::vector<float> &scales) {
error::wrap_c_api(
dnnl_primitive_attr_set_scales(get(), arg,
(dnnl_dim_t)scales.size(), mask, scales.data()),
"could not set scales primitive attribute");
}
/// Returns zero points correspondence mask and values.
///
/// @param arg Parameter argument index as passed to the
/// primitive::execute() call.
/// @param mask Zero points correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// zero_points vector. The set i-th bit indicates that a dedicated
/// zero point is used for each index along that dimension. Set the
/// mask to 0 to use a common zero point for the whole output tensor.
/// @param zero_points Output vector of zero points.
void get_zero_points(
int arg, int &mask, std::vector<int32_t> &zero_points) const {
dnnl_dim_t count;
int c_mask;
const int32_t *c_zero_points;
error::wrap_c_api(dnnl_primitive_attr_get_zero_points(
get(), arg, &count, &c_mask, &c_zero_points),
"could not get zero points primitive attribute");
zero_points.resize(count);
mask = c_mask;
for (dnnl_dim_t c = 0; c < count; ++c)
zero_points[c] = c_zero_points[c];
}
/// Sets zero points for primitive operations for a given memory argument.
///
/// @sa dnnl_primitive_attr_set_zero_points
/// @sa dnnl::primitive_attr::set_output_scales
///
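/// A minimal sketch: using a single source zero point that will only be
/// known at execution time.
/// @code
/// dnnl::primitive_attr attr;
/// attr.set_zero_points(DNNL_ARG_SRC, 0 /* mask */,
///         {DNNL_RUNTIME_S32_VAL});
/// @endcode
///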
/// @param arg Parameter argument index as passed to the
/// primitive::execute() call.
/// @param mask Zero point correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p
/// zero_points vector. The set i-th bit indicates that a dedicated
/// zero point is used for each index along that dimension. Set the
/// mask to 0 to use a common zero point for the whole output tensor.
/// @param zero_points Constant vector of zero points. If the zero points
/// are known at the time of this call, the following equality must
/// hold:
/// \f[zero\_points.size() = \prod\limits_{d \in mask} argument.dims[d].\f]
/// If the zero points are not known at the time of the call, this
/// vector must contain a single #DNNL_RUNTIME_S32_VAL value and the
/// zero points must be passed at execution time as an argument with
/// index #DNNL_ARG_ATTR_ZERO_POINTS.
void set_zero_points(
int arg, int mask, const std::vector<int32_t> &zero_points) {
error::wrap_c_api(dnnl_primitive_attr_set_zero_points(get(), arg,
(dnnl_dim_t)zero_points.size(), mask,
zero_points.data()),
"could not set zero points primitive attribute");
}
/// Returns post-ops previously set via set_post_ops().
///
/// @returns Post-ops.
const post_ops get_post_ops() const {
post_ops result;
const_dnnl_post_ops_t c_result;
error::wrap_c_api(dnnl_primitive_attr_get_post_ops(get(), &c_result),
"could not get post-ops primitive attribute");
result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
return result;
}
/// Sets post-ops.
///
/// @note
/// There is no way to check whether the post-ops would be supported
/// by the target primitive. Any error will be reported
/// by the respective primitive descriptor constructor.
///
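/// A minimal sketch: a sum accumulation followed by a ReLU.
/// @code
/// dnnl::post_ops po;
/// po.append_sum(1.f);
/// po.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
///
/// dnnl::primitive_attr attr;
/// attr.set_post_ops(po);
/// @endcode
///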
/// @param ops Post-ops object to copy post-ops from.
void set_post_ops(const post_ops ops) {
error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
"could not set post-ops primitive attribute");
}
/// Sets quantization scale and shift parameters for RNN data tensors.
///
/// For performance reasons, the low-precision configuration of the RNN
/// primitives expect input activations to have the unsigned 8-bit integer
/// data type. The scale and shift parameters are used to quantize
/// floating-point data to unsigned integer and must be passed to the RNN
/// primitive using attributes.
///
/// The quantization formula is `scale * (data + shift)`.
///
/// @note
/// Quantization scale and shift are common for src_layer, src_iter,
/// dst_iter, and dst_layer.
///
/// Example usage:
/// @code
/// // RNN parameters
/// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
/// // Activations quantization parameters
/// float scale = ..., shift = ...;
///
/// primitive_attr attr;
///
/// // Set scale and shift for int8 quantization of activation
/// attr.set_rnn_data_qparams(scale, shift);
///
/// // Create and configure rnn op_desc
/// vanilla_rnn_forward::desc rnn_d(...);
/// vanilla_rnn_forward::primitive_desc rnn_pd(rnn_d, attr, engine);
/// @endcode
///
/// @param scale The value to scale the data by.
/// @param shift The value to shift the data by.
void set_rnn_data_qparams(float scale, float shift) {
error::wrap_c_api(
dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
"could not set RNN data quantization parameters primitive "
"attribute");
}
/// Sets quantization scaling factors for RNN weights tensors. The
/// low-precision configuration of the RNN primitives expect input weights
/// to use the signed 8-bit integer data type. The scaling factors are
/// used to quantize floating-point data to signed integer and must be
/// passed to RNN primitives using attributes.
///
/// @note
/// The dimension order is always native and does not depend on the
/// actual layout used. For example, five-dimensional weights always
/// have (l, d, i, g, o) logical dimension ordering.
///
/// @note
/// Quantization scales are common for weights_layer and
/// weights_iteration
///
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated scaling
/// factor should be used each index along that dimension. Set the
/// mask to 0 to use a common scaling factor for the whole output
/// tensor.
/// @param scales Constant vector of output scaling factors. The following
/// equality must hold:
/// \f[scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f]
/// Violations can only be detected when the attributes are used to
/// create a primitive descriptor.
void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
(int)scales.size(), mask, scales.data()),
"could not set RNN weights quantization parameters primitive "
"attribute");
}
};
/// @} dnnl_api_attributes
/// @addtogroup dnnl_api_primitives_common
/// @{
/// Base class for all primitive descriptors.
struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
using handle<dnnl_primitive_desc_t>::handle;
/// Default constructor. Produces an empty object.
primitive_desc_base() = default;
/// Returns the engine of the primitive descriptor.
/// @returns The engine of the primitive descriptor.
engine get_engine() const { return engine::query(*this); }
/// Returns implementation name.
/// @returns The implementation name.
const char *impl_info_str() const {
const char *res;
error::wrap_c_api(dnnl_primitive_desc_query(
get(), dnnl_query_impl_info_str, 0, &res),
"could not retrieve implementation info string from a "
"primitive descriptor");
return res;
}
/// Returns a memory::dim value (same as int64_t).
/// @param what The value to query.
/// @returns The result of the query.
memory::dim query_s64(query what) const {
memory::dim res;
dnnl_status_t status = dnnl_primitive_desc_query(
get(), dnnl::convert_to_c(what), 0, &res);
return status == dnnl_success ? res : 0;
}
/// Returns a memory descriptor.
///
/// @note
/// See also the convenience methods
/// dnnl::primitive_desc_base::src_desc(),
/// dnnl::primitive_desc_base::dst_desc(), and others.
///
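/// A minimal sketch (assuming `pd` is an already created primitive
/// descriptor): querying the destination descriptor and allocating a
/// matching memory object.
/// @code
/// dnnl::memory::desc dst_md = pd.query_md(dnnl::query::dst_md, 0);
/// // equivalent shortcut: dst_md = pd.dst_desc();
/// dnnl::memory dst_mem(dst_md, pd.get_engine());
/// @endcode
///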
/// @param what The kind of parameter to query; can be
/// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
/// @param idx Index of the parameter. For example, convolution bias can
/// be queried with what = #dnnl::query::weights_md and idx = 1.
/// @returns The requested memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// parameter of the specified kind or index.
memory::desc query_md(query what, int idx = 0) const {
std::vector<query> valid_q {query::src_md, query::diff_src_md,
query::weights_md, query::diff_weights_md, query::dst_md,
query::diff_dst_md, query::workspace_md, query::scratchpad_md,
query::exec_arg_md};
if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
[=](query q) { return what == q; }))
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"memory descriptor query is invalid");
const dnnl_memory_desc_t *cdesc = dnnl_primitive_desc_query_md(
get(), dnnl::convert_to_c(what), idx);
return cdesc ? memory::desc(*cdesc) : memory::desc();
}
/// Returns a source memory descriptor.
/// @param idx Source index.
/// @returns Source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// source parameter with index @p idx.
memory::desc src_desc(int idx) const {
return query_md(query::src_md, idx);
}
/// Returns a destination memory descriptor.
/// @param idx Destination index.
/// @returns Destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// destination parameter with index @p idx.
memory::desc dst_desc(int idx) const {
return query_md(query::dst_md, idx);
}
/// Returns a weights memory descriptor.
/// @param idx Weights index.
/// @returns Weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// weights parameter with index @p idx.
memory::desc weights_desc(int idx) const {
return query_md(query::weights_md, idx);
}
/// Returns a diff source memory descriptor.
/// @param idx Diff source index.
/// @returns Diff source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff source parameter with index @p idx.
memory::desc diff_src_desc(int idx) const {
return query_md(query::diff_src_md, idx);
}
/// Returns a diff destination memory descriptor.
/// @param idx Diff destination index.
/// @returns Diff destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff destination parameter with index @p idx.
memory::desc diff_dst_desc(int idx) const {
return query_md(query::diff_dst_md, idx);
}
/// Returns a diff weights memory descriptor.
/// @param idx Diff weights index.
/// @returns Diff weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff weights parameter with index @p idx.
memory::desc diff_weights_desc(int idx) const {
return query_md(query::diff_weights_md, idx);
}
// Separate versions without the index argument for documentation
// purposes.
/// Returns a source memory descriptor.
/// @returns Source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// source parameter.
memory::desc src_desc() const { return src_desc(0); }
/// Returns a destination memory descriptor.
/// @returns Destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// destination parameter.
memory::desc dst_desc() const { return dst_desc(0); }
/// Returns a weights memory descriptor.
/// @returns Weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// weights parameter.
memory::desc weights_desc() const { return weights_desc(0); }
/// Returns a diff source memory descriptor.
/// @returns Diff source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff source parameter.
memory::desc diff_src_desc() const { return diff_src_desc(0); }
/// Returns a diff destination memory descriptor.
/// @returns Diff destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff destination parameter.
memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
/// Returns a diff weights memory descriptor.
/// @returns Diff weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff weights parameter.
memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
/// Returns the workspace memory descriptor.
/// @returns Workspace memory descriptor.
/// @returns A zero memory descriptor if the primitive does not require
/// a workspace parameter.
memory::desc workspace_desc() const {
return query_md(query::workspace_md, 0);
}
/// Returns the scratchpad memory descriptor.
/// @returns The scratchpad memory descriptor.
/// @returns A zero memory descriptor if the primitive does not require
/// a scratchpad parameter.
/// @sa @ref dev_guide_attributes_scratchpad
memory::desc scratchpad_desc() const {
return query_md(query::scratchpad_md, 0);
}
/// Returns the engine on which the scratchpad memory is located.
/// @returns The engine on which the scratchpad memory is located.
engine scratchpad_engine() const {
dnnl_engine_t c_engine;
error::wrap_c_api(dnnl_primitive_desc_query(get(),
dnnl::convert_to_c(query::scratchpad_engine),
0, &c_engine),
"could not retrieve scratchpad engine from a primitive "
"descriptor");
return engine(c_engine, true);
}
/// Returns the primitive attributes.
/// @returns The primitive attributes.
primitive_attr get_primitive_attr() const {
const_dnnl_primitive_attr_t const_c_attr;
error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
"could not get attributes from a primitive descriptor");
dnnl_primitive_attr_t c_attr;
error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
"could not clone primitive attributes");
return primitive_attr(c_attr);
}
/// Returns the kind of the primitive descriptor.
/// @returns The kind of the primitive descriptor.
dnnl::primitive::kind get_kind() const {
dnnl_primitive_kind_t kind;
error::wrap_c_api(dnnl_primitive_desc_query(get(),
dnnl_query_primitive_kind, 0, (void *)&kind),
"could not get primitive kind from a primitive descriptor");
return static_cast<dnnl::primitive::kind>(kind);
}
protected:
/// Resets the value of the handle to a clone of a C API primitive
/// descriptor.
/// @param pd A C API primitive descriptor to clone.
void reset_with_clone(const_dnnl_primitive_desc_t pd) {
dnnl_primitive_desc_t new_pd;
error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
"could not clone a primitive descriptor");
reset(new_pd);
}
/// Constructs a primitive descriptor base object from a clone of a C API
/// primitive descriptor after verifying that it is what the caller
/// expects.
///
/// @note
/// The @p prim_kind should map to a primitive that does not have
/// different values of propagation kind (e.g. #dnnl::binary).
/// @note
/// Primitive descriptor base constructed this way does not support
/// next_impl() (will throw).
///
/// @param pd C API primitive descriptor to clone.
/// @param prim_kind Expected primitive kind.
primitive_desc_base(
dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
: primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
/// Constructs a primitive descriptor base object from a clone of a C API
/// primitive descriptor after verifying that it is what the caller
/// expects.
///
/// @note
/// Primitive descriptor base constructed this way does not support
/// next_impl() (will throw).
///
/// @param pd C API primitive descriptor to clone.
/// @param prim_kind Expected primitive kind.
/// @param prop_kind Expected propagation kind.
primitive_desc_base(dnnl_primitive_desc_t pd,
dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind)
: primitive_desc_base(pd, prim_kind, prop_kind, prop_kind) {}
/// Constructs a primitive descriptor base object from a clone of a C API
/// primitive descriptor after verifying that it is what the caller
/// expects.
///
/// @note
/// Primitive descriptor base constructed this way does not support
/// next_impl() (will throw).
///
/// @param pd C API primitive descriptor to clone.
/// @param prim_kind Expected primitive kind.
/// @param prop_kind1 Expected propagation kind (option 1).
/// @param prop_kind2 Expected propagation kind (option 2). This value is
/// checked if the check with @p prop_kind1 fails.
primitive_desc_base(dnnl_primitive_desc_t pd,
dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
dnnl::prop_kind prop_kind2) {
// It is OK to pass an empty primitive descriptor
if (pd == nullptr) return;
dnnl_status_t rc;
dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
// Check that primitive kind matches
dnnl_primitive_kind_t pd_kind;
rc = dnnl_primitive_desc_query(
pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
error::wrap_c_api(
rc, "could not get primitive kind from a primitive descriptor");
if (pd_kind != c_prim_kind)
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"primitive descriptor operation kind mismatch");
// Check that propagation kind matches
dnnl_prop_kind_t pd_prop_kind;
rc = dnnl_primitive_desc_query(
pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
// Something went wrong
if (rc != dnnl_success && rc != dnnl_unimplemented)
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"could not get propagation kind from the primitive "
"descriptor");
// Everything is fine
if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
|| (rc == dnnl_success
&& (pd_prop_kind == c_prop_kind1
|| pd_prop_kind == c_prop_kind2))) {
reset_with_clone(pd);
return;
}
// We could get the propagation kind but there is a mismatch
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"primitive descriptor propagation kind mismatch");
}
using base = primitive_desc_base;
};
/// @} dnnl_api_primitives_common
/// @addtogroup dnnl_api_reorder Reorder
///
/// A primitive to copy data between two memory objects. This primitive is
/// typically used to change the way the data is laid out in memory.
///
/// @sa @ref dev_guide_reorder in developer guide
///
/// @{
/// Reorder primitive.
struct reorder : public primitive {
/// Primitive descriptor for a reorder primitive.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for reorder primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param src_engine Engine on which the source memory object will be
/// located.
/// @param src_md Source memory descriptor.
/// @param dst_engine Engine on which the destination memory object
/// will be located.
/// @param dst_md Destination memory descriptor.
/// @param attr Primitive attributes to use (optional).
primitive_desc(const engine &src_engine, const memory::desc &src_md,
const engine &dst_engine, const memory::desc &dst_md,
const primitive_attr &attr = primitive_attr()) {
dnnl_primitive_desc_t result;
error::wrap_c_api(
dnnl_reorder_primitive_desc_create(&result, &src_md.data,
src_engine.get(), &dst_md.data, dst_engine.get(),
attr.get()),
"could not create a primitive descriptor for a reorder "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for reorder primitive.
///
/// @param src Source memory object. It is used to obtain the source
/// memory descriptor and engine.
/// @param dst Destination memory object. It is used to obtain the
/// destination memory descriptor and engine.
/// @param attr Primitive attributes to use (optional).
primitive_desc(const memory &src, const memory &dst,
const primitive_attr &attr = primitive_attr()) {
dnnl_primitive_desc_t result;
auto src_md = src.get_desc();
auto dst_md = dst.get_desc();
error::wrap_c_api(
dnnl_reorder_primitive_desc_create(&result, &src_md.data,
src.get_engine().get(), &dst_md.data,
dst.get_engine().get(), attr.get()),
"could not create a primitive descriptor for a reorder "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for reorder primitive from a C
/// API primitive descriptor which must have a matching kind.
///
/// @param pd C API primitive descriptor for reorder primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
/// Returns the engine on which the source memory is allocated.
/// @returns The engine on which the source memory is allocated.
engine get_src_engine() const {
return engine::query(*this, dnnl::query::reorder_src_engine);
}
/// Returns the engine on which the destination memory is allocated.
/// @returns The engine on which the destination memory is allocated.
engine get_dst_engine() const {
return engine::query(*this, dnnl::query::reorder_dst_engine);
}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
reorder() = default;
/// Constructs a reorder primitive.
/// @param pd Primitive descriptor for reorder primitive.
reorder(const primitive_desc &pd) : primitive(pd.get()) {}
/// Constructs a reorder primitive that would reorder data between memory
/// objects having the same memory descriptors as memory objects @p src and
/// @p dst.
///
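/// A minimal sketch (assuming `src_mem` and `dst_mem` are memory objects
/// with possibly different layouts and `strm` is a stream on their
/// engine):
/// @code
/// dnnl::reorder r(src_mem, dst_mem);
/// r.execute(strm, src_mem, dst_mem);
/// strm.wait();
/// @endcode
///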
/// @param src Source memory object.
/// @param dst Destination memory object.
/// @param attr Primitive attributes to use (optional).
reorder(const memory &src, const memory &dst,
const primitive_attr &attr = primitive_attr())
: primitive(primitive_desc(src, dst, attr).get()) {}
using primitive::execute;
/// Executes the reorder primitive.
///
/// @param stream Stream object. The stream must belong to the same engine
/// as the primitive.
/// @param src Source memory object.
/// @param dst Destination memory object.
void execute(const stream &stream, memory &src, memory &dst) const {
primitive::execute(stream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
}
};
/// @} dnnl_api_reorder
/// @addtogroup dnnl_api_concat Concat
///
/// A primitive to concatenate data by arbitrary dimension.
///
/// @sa @ref dev_guide_concat in developer guide
///
/// @{
/// @cond DO_NOT_DOCUMENT_THIS
inline std::vector<dnnl_memory_desc_t> convert_to_c(
const std::vector<memory::desc> &mems) {
std::vector<dnnl_memory_desc_t> c_mems;
c_mems.reserve(mems.size());
for (const auto &s : mems)
c_mems.push_back(s.data);
return c_mems;
}
/// @endcond
/// Tensor concatenation (concat) primitive.
struct concat : public primitive {
/// Primitive descriptor for a concat primitive.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// Inputs:
/// - `src[0]` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src[1]` (#dnnl::primitive_desc_base::src_desc(`1`))
/// - ...
/// - `src[n - 1]` (#dnnl::primitive_desc_base::src_desc(`n - 1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param dst Destination memory descriptor.
/// @param concat_dimension Source tensors will be concatenated over
/// dimension with this index. Note that order of dimensions does
/// not depend on memory format.
/// @param srcs Vector of source memory descriptors.
/// @param engine Engine to perform the operation on.
/// @param attr Primitive attributes to use (optional).
primitive_desc(const memory::desc &dst, int concat_dimension,
const std::vector<memory::desc> &srcs, const engine &engine,
const primitive_attr &attr = primitive_attr()) {
auto c_srcs = convert_to_c(srcs);
dnnl_primitive_desc_t result;
error::wrap_c_api(
dnnl_concat_primitive_desc_create(&result, &dst.data,
(int)c_srcs.size(), concat_dimension, c_srcs.data(),
attr.get(), engine.get()),
"could not create a primitive descriptor for a concat "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// This version derives the destination memory descriptor
/// automatically.
///
/// @param concat_dimension Source tensors will be concatenated over
/// dimension with this index. Note that order of dimensions does
/// not depend on memory format.
/// @param srcs Vector of source memory descriptors.
/// @param engine Engine to perform the operation on.
/// @param attr Primitive attributes to use (optional).
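        ///
        /// A minimal usage sketch (assuming an engine @c eng, a stream
        /// @c strm, and source memory objects @c src0 and @c src1 with
        /// descriptors @c src0_md and @c src1_md already exist):
        /// @code
        ///     // Concatenate two tensors over dimension 1 (e.g. channels);
        ///     // the destination descriptor is derived automatically.
        ///     dnnl::concat::primitive_desc concat_pd(
        ///             1, {src0_md, src1_md}, eng);
        ///     dnnl::memory dst(concat_pd.dst_desc(), eng);
        ///     dnnl::concat(concat_pd).execute(strm,
        ///             {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
        ///                     {DNNL_ARG_MULTIPLE_SRC + 1, src1},
        ///                     {DNNL_ARG_DST, dst}});
        /// @endcode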
primitive_desc(int concat_dimension,
const std::vector<memory::desc> &srcs, const engine &engine,
const primitive_attr &attr = primitive_attr()) {
auto c_api_srcs = convert_to_c(srcs);
dnnl_primitive_desc_t result;
error::wrap_c_api(
dnnl_concat_primitive_desc_create(&result, nullptr,
(int)c_api_srcs.size(), concat_dimension,
c_api_srcs.data(), attr.get(), engine.get()),
"could not create a primitive descriptor for a concat "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for concat primitive from a C
/// API primitive descriptor which must have a matching kind.
///
/// @param pd C API primitive descriptor for concat primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
concat() = default;
/// Constructs a concatenation primitive.
/// @param pd Primitive descriptor for concatenation primitive.
concat(const primitive_desc &pd) : primitive(pd.get()) {}
};
/// @} dnnl_api_concat
/// @addtogroup dnnl_api_sum Sum
///
/// A primitive to sum multiple tensors.
///
/// @sa @ref dev_guide_sum in developer guide
///
/// @{
/// Out-of-place summation (sum) primitive.
struct sum : public primitive {
/// Primitive descriptor for a sum primitive.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a sum primitive.
///
/// Inputs:
/// - `src[0]` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src[1]` (#dnnl::primitive_desc_base::src_desc(`1`))
/// - ...
/// - `src[n - 1]` (#dnnl::primitive_desc_base::src_desc(`n - 1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param dst Destination memory descriptor.
/// @param scales Vector of scales to multiply data in each source
/// memory by.
/// @param srcs Vector of source memory descriptors.
/// @param engine Engine to perform the operation on.
/// @param attr Primitive attributes to use (optional).
primitive_desc(const memory::desc &dst,
const std::vector<float> &scales,
const std::vector<memory::desc> &srcs, const engine &engine,
const primitive_attr &attr = primitive_attr()) {
validate_container_size(scales,
"counts of scales and sources are not equal",
(int)srcs.size(), (int)srcs.size());
auto c_api_srcs = convert_to_c(srcs);
dnnl_primitive_desc_t result;
error::wrap_c_api(
dnnl_sum_primitive_desc_create(&result, &dst.data,
(int)c_api_srcs.size(), scales.data(),
c_api_srcs.data(), attr.get(), engine.get()),
"could not create a primitive descriptor for a sum "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for a sum primitive.
///
/// This version derives the destination memory descriptor
/// automatically.
///
/// @param scales Vector of scales by which to multiply data in each
/// source memory object.
/// @param srcs Vector of source memory descriptors.
/// @param engine Engine on which to perform the operation.
/// @param attr Primitive attributes to use (optional).
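        ///
        /// A minimal usage sketch (assuming an engine @c eng, a stream
        /// @c strm, and source memory objects @c src0 and @c src1 with
        /// descriptors @c src0_md and @c src1_md already exist):
        /// @code
        ///     // Compute dst = 1.0 * src0 + 0.5 * src1 with an automatically
        ///     // derived destination descriptor.
        ///     dnnl::sum::primitive_desc sum_pd(
        ///             {1.0f, 0.5f}, {src0_md, src1_md}, eng);
        ///     dnnl::memory dst(sum_pd.dst_desc(), eng);
        ///     dnnl::sum(sum_pd).execute(strm,
        ///             {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
        ///                     {DNNL_ARG_MULTIPLE_SRC + 1, src1},
        ///                     {DNNL_ARG_DST, dst}});
        /// @endcode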
primitive_desc(const std::vector<float> &scales,
const std::vector<memory::desc> &srcs, const engine &engine,
const primitive_attr &attr = primitive_attr()) {
validate_container_size(scales,
"counts of scales and sources are not equal",
(int)srcs.size(), (int)srcs.size());
auto c_api_srcs = convert_to_c(srcs);
dnnl_primitive_desc_t result;
error::wrap_c_api(
dnnl_sum_primitive_desc_create(&result, nullptr,
(int)c_api_srcs.size(), scales.data(),
c_api_srcs.data(), attr.get(), engine.get()),
"could not create a primitive descriptor for a sum "
"primitive");
reset(result);
}
/// Constructs a primitive descriptor for sum primitive from a C API
/// primitive descriptor which must have a matching kind.
///
        /// @param pd C API primitive descriptor for sum primitive.

primitive_desc(dnnl_primitive_desc_t pd)
: primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
sum() = default;
/// Constructs a sum primitive.
/// @param pd Primitive descriptor for sum primitive.
sum(const primitive_desc &pd) : primitive(pd.get()) {}
};
/// @} dnnl_api_sum
/// @addtogroup dnnl_api_primitives_common
/// @{
/// A base class for descriptors of all primitives that have an operation
/// descriptor and that support iteration over multiple implementations.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;
primitive_desc() = default;
/// Constructs a primitive descriptor.
///
/// @note
/// If @p allow_empty is true, the constructor does not throw if a
/// primitive descriptor cannot be created. But calling next_impl() in
/// this case will throw.
///
/// @note
/// This is a low-level implementation detail that is typically not
/// needed in application code.
///
/// @param desc Constant C API operation descriptor.
/// @param attr Pointer to primitive attributes. It is safe to pass
/// nullptr to indicate absence of attributes.
/// @param engine Engine to use.
/// @param hint_fwd_pd C API primitive descriptor for a forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use for backward propagation or weights gradient.
/// @param allow_empty A flag signifying whether construction is allowed
/// to fail without throwing an exception. In this case an empty
/// object will be produced. This flag is optional and defaults to
/// false.
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr,
const engine &engine, const_dnnl_primitive_desc_t hint_fwd_pd,
bool allow_empty = false)
: allow_empty_(allow_empty) {
dnnl_primitive_desc_iterator_t iterator = nullptr;
dnnl_status_t status = dnnl_primitive_desc_iterator_create(&iterator,
desc, attr ? attr->get() : nullptr, engine.get(), hint_fwd_pd);
if (!allow_empty)
error::wrap_c_api(
status, "could not create a primitive descriptor iterator");
pd_iterator.reset(iterator);
fetch_impl();
}
/// Advances the primitive iterator to the next implementation.
///
    /// @returns @c true on success, and @c false if the last implementation
    ///     has already been reached, in which case the primitive descriptor
    ///     itself is kept unchanged.
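    ///
    /// A minimal sketch of walking the implementation list (assuming a
    /// primitive descriptor @c pd of a derived class, for example
    /// dnnl::convolution_forward::primitive_desc, has already been created;
    /// the info string comes from dnnl::primitive_desc_base::impl_info_str()):
    /// @code
    ///     // Print the name of each available implementation, starting from
    ///     // the most preferred one.
    ///     do {
    ///         std::cout << pd.impl_info_str() << std::endl;
    ///     } while (pd.next_impl());
    /// @endcode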
bool next_impl() {
dnnl_status_t status
= dnnl_primitive_desc_iterator_next(pd_iterator.get());
if (status == dnnl_iterator_ends) return false;
error::wrap_c_api(
status, "could not advance a primitive descriptor iterator");
fetch_impl();
return true;
}
private:
bool allow_empty_ = false;
handle<dnnl_primitive_desc_iterator_t> pd_iterator;
void fetch_impl() {
dnnl_primitive_desc_t pd = dnnl_primitive_desc_iterator_fetch(
pd_iterator.get(allow_empty_));
error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
: dnnl_runtime_error,
"could not fetch a primitive descriptor from a primitive "
"descriptor iterator");
reset(pd);
}
};
/// @} dnnl_api_primitives_common
/// @addtogroup dnnl_api_convolution Convolution
///
/// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
/// forward propagation, backward propagation, and weights gradient with or
/// without bias.
///
/// @sa @ref dev_guide_convolution in developer guide
///
/// @{
/// Convolution forward propagation primitive.
struct convolution_forward : public primitive {
/// Descriptor for a convolution forward propagation primitive.
struct desc {
dnnl_convolution_desc_t data;
/// Constructs a descriptor for a convolution forward propagation
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing zero memory
/// descriptor disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
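        ///
        /// A minimal sketch of creating a forward convolution (the engine
        /// @c eng, stream @c strm, and the memory objects passed to
        /// execute() are assumed to exist; the dimensions are purely
        /// illustrative):
        /// @code
        ///     using dt = dnnl::memory::data_type;
        ///     using tag = dnnl::memory::format_tag;
        ///     // N=1, IC=3, IH=IW=227; OC=96, KH=KW=11; stride 4, no padding
        ///     // gives OH=OW=(227-11)/4+1=55.
        ///     dnnl::memory::desc src_md({1, 3, 227, 227}, dt::f32, tag::any);
        ///     dnnl::memory::desc wei_md({96, 3, 11, 11}, dt::f32, tag::any);
        ///     dnnl::memory::desc bia_md({96}, dt::f32, tag::any);
        ///     dnnl::memory::desc dst_md({1, 96, 55, 55}, dt::f32, tag::any);
        ///     dnnl::convolution_forward::desc conv_d(
        ///             dnnl::prop_kind::forward_inference,
        ///             dnnl::algorithm::convolution_direct, src_md, wei_md,
        ///             bia_md, dst_md, {4, 4}, {0, 0}, {0, 0});
        ///     dnnl::convolution_forward::primitive_desc conv_pd(conv_d, eng);
        ///     dnnl::convolution_forward(conv_pd).execute(strm,
        ///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
        ///                     {DNNL_ARG_BIAS, bia_mem},
        ///                     {DNNL_ARG_DST, dst_mem}});
        /// @endcode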
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &bias_desc, const memory::desc &dst_desc,
const memory::dims &strides, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_convolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, &bias_desc.data, &dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a convolution forward "
"propagation primitive");
}
/// Constructs a descriptor for a convolution forward propagation
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_convolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, nullptr, &dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a convolution forward "
"propagation primitive");
}
/// Constructs a descriptor for a dilated convolution forward
/// propagation primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing zero memory
/// descriptor disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &bias_desc, const memory::desc &dst_desc,
const memory::dims &strides, const memory::dims &dilates,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_dilated_convolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, &bias_desc.data,
&dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated convolution "
"forward propagation primitive");
}
/// Constructs a descriptor for a dilated convolution forward
/// propagation primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_dilated_convolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, nullptr,
&dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated convolution "
"forward propagation primitive");
}
};
/// Primitive descriptor for a convolution forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a convolution forward
/// propagation primitive.
///
/// @param desc Descriptor for a convolution forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a convolution forward
/// propagation primitive.
///
/// @param desc Descriptor for a convolution forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
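        ///
        /// A minimal sketch of attaching primitive attributes, here a fused
        /// ReLU post-op (assuming @c conv_d is an already created
        /// dnnl::convolution_forward::desc and @c eng is an engine):
        /// @code
        ///     dnnl::post_ops ops;
        ///     ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
        ///     dnnl::primitive_attr attr;
        ///     attr.set_post_ops(ops);
        ///     dnnl::convolution_forward::primitive_desc conv_pd(
        ///             conv_d, attr, eng);
        /// @endcode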
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a convolution forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a convolution forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// Returns the bias memory descriptor.
/// @returns The bias memory descriptor.
        /// @returns A zero memory descriptor if the primitive does not have a
        ///     bias parameter.
memory::desc bias_desc() const { return base::weights_desc(1); }
};
/// Default constructor. Produces an empty object.
convolution_forward() = default;
/// Constructs a convolution forward propagation primitive.
/// @param pd Primitive descriptor for a convolution forward propagation
/// primitive.
convolution_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// Convolution backward propagation primitive.
struct convolution_backward_data : public primitive {
/// Descriptor for a convolution backward propagation primitive.
struct desc {
dnnl_convolution_desc_t data;
/// Constructs a descriptor for a convolution backward propagation
/// primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
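        ///
        /// A minimal sketch (assuming the engine @c eng, stream @c strm, a
        /// forward primitive descriptor @c conv_fwd_pd used as a hint, the
        /// memory descriptors, and the memory objects passed to execute()
        /// already exist):
        /// @code
        ///     dnnl::convolution_backward_data::desc bwd_d(
        ///             dnnl::algorithm::convolution_direct, diff_src_md,
        ///             wei_md, diff_dst_md, {4, 4}, {0, 0}, {0, 0});
        ///     dnnl::convolution_backward_data::primitive_desc bwd_pd(
        ///             bwd_d, eng, conv_fwd_pd);
        ///     dnnl::convolution_backward_data(bwd_pd).execute(strm,
        ///             {{DNNL_ARG_DIFF_DST, diff_dst_mem},
        ///                     {DNNL_ARG_WEIGHTS, wei_mem},
        ///                     {DNNL_ARG_DIFF_SRC, diff_src_mem}});
        /// @endcode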
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_convolution_backward_data_desc_init(&data,
convert_to_c(algorithm), &diff_src_desc.data,
&weights_desc.data, &diff_dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a convolution backward "
"propagation primitive");
}
/// Constructs a descriptor for dilated convolution backward
/// propagation primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @note
/// Memory descriptors are allowed to be initialized with
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_convolution_backward_data_desc_init(&data,
convert_to_c(algorithm), &diff_src_desc.data,
&weights_desc.data, &diff_dst_desc.data,
&strides[0], &dilates[0], &padding_l[0],
&padding_r[0]),
"could not create a descriptor for a dilated convolution "
"backward propagation primitive");
}
};
/// Primitive descriptor for a convolution backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a convolution backward
/// propagation primitive.
///
/// @param desc Descriptor for a convolution backward propagation
/// primitive.
/// @param engine Engine to perform the operation on.
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const convolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a convolution backward
/// propagation primitive.
///
/// @param desc Descriptor for a convolution backward propagation
/// primitive.
/// @param engine Engine to perform the operation on.
/// @param attr Primitive attributes to use.
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const convolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a convolution backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a convolution backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
convolution_backward_data() = default;
/// Constructs a convolution backward propagation primitive.
/// @param pd Primitive descriptor for a convolution backward propagation
/// primitive.
convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
};
/// Convolution weights gradient primitive.
struct convolution_backward_weights : public primitive {
/// Descriptor for a convolution weights gradient primitive.
struct desc {
dnnl_convolution_desc_t data;
/// Constructs a descriptor for a convolution weights gradient primitive
/// with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
/// memory descriptor disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
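        ///
        /// A minimal sketch (assuming the engine @c eng, stream @c strm, a
        /// forward primitive descriptor @c conv_fwd_pd used as a hint, the
        /// memory descriptors, and the memory objects passed to execute()
        /// already exist):
        /// @code
        ///     dnnl::convolution_backward_weights::desc bwd_w_d(
        ///             dnnl::algorithm::convolution_direct, src_md,
        ///             diff_wei_md, diff_bia_md, diff_dst_md,
        ///             {4, 4}, {0, 0}, {0, 0});
        ///     dnnl::convolution_backward_weights::primitive_desc bwd_w_pd(
        ///             bwd_w_d, eng, conv_fwd_pd);
        ///     dnnl::convolution_backward_weights(bwd_w_pd).execute(strm,
        ///             {{DNNL_ARG_SRC, src_mem},
        ///                     {DNNL_ARG_DIFF_DST, diff_dst_mem},
        ///                     {DNNL_ARG_DIFF_WEIGHTS, diff_wei_mem},
        ///                     {DNNL_ARG_DIFF_BIAS, diff_bia_mem}});
        /// @endcode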
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_convolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, &diff_bias_desc.data,
&diff_dst_desc.data, &strides[0], &padding_l[0],
&padding_r[0]),
"could not create a descriptor for a convolution weights "
"update primitive");
}
/// Constructs a descriptor for a convolution weights gradient primitive
/// without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_convolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, nullptr,
&diff_dst_desc.data, &strides[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a convolution weights "
"update primitive");
}
/// Constructs a descriptor for a dilated convolution weights gradient
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
/// memory descriptor disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_convolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, &diff_bias_desc.data,
&diff_dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated convolution "
"weights gradient primitive");
}
/// Constructs a descriptor for a dilated convolution weights gradient
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Convolution algorithm. Possible values are
/// #dnnl::algorithm::convolution_direct,
/// #dnnl::algorithm::convolution_winograd, and
/// #dnnl::algorithm::convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_convolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, nullptr,
&diff_dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated convolution "
"weights gradient primitive");
}
};
/// Primitive descriptor for a convolution weights gradient primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a convolution weights gradient
/// primitive.
///
/// @param desc Descriptor for a convolution weights gradient primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const convolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a convolution weights gradient
/// primitive.
///
/// @param desc Descriptor for a convolution weights gradient primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a convolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const convolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a convolution weights gradient
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for a convolution weights
/// gradient primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
dnnl::prop_kind::backward_weights) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
memory::desc diff_weights_desc() const {
return base::diff_weights_desc(0);
}
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// Returns the diff bias memory descriptor.
/// @returns The diff bias memory descriptor.
        /// @returns A zero memory descriptor if the primitive does not have a
        ///     diff bias parameter.
memory::desc diff_bias_desc() const {
return base::diff_weights_desc(1);
}
};
/// Default constructor. Produces an empty object.
convolution_backward_weights() = default;
/// Constructs a convolution weights gradient primitive.
/// @param pd Primitive descriptor for a convolution weights gradient
/// primitive.
convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_convolution
//
/// @addtogroup dnnl_api_deconvolution Deconvolution
///
/// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
/// forward propagation, backward propagation, and weights gradient with or
/// without bias.
///
/// @{
/// Deconvolution forward propagation primitive.
struct deconvolution_forward : public primitive {
/// Descriptor for a deconvolution forward propagation primitive.
struct desc {
dnnl_deconvolution_desc_t data;
/// Constructs a descriptor for a deconvolution forward propagation
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Deconvolution algorithm:
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing zero memory
/// descriptor disables the bias term.
/// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
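        ///
        /// A minimal sketch (assuming the engine @c eng, stream @c strm,
        /// memory descriptors @c src_md, @c wei_md, @c bia_md, and @c dst_md
        /// consistent with a stride-2 deconvolution, and the memory objects
        /// passed to execute() already exist):
        /// @code
        ///     dnnl::deconvolution_forward::desc deconv_d(
        ///             dnnl::prop_kind::forward_inference,
        ///             dnnl::algorithm::deconvolution_direct, src_md, wei_md,
        ///             bia_md, dst_md, {2, 2}, {0, 0}, {0, 0});
        ///     dnnl::deconvolution_forward::primitive_desc deconv_pd(
        ///             deconv_d, eng);
        ///     dnnl::deconvolution_forward(deconv_pd).execute(strm,
        ///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
        ///                     {DNNL_ARG_BIAS, bia_mem},
        ///                     {DNNL_ARG_DST, dst_mem}});
        /// @endcode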
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &bias_desc, const memory::desc &dst_desc,
const memory::dims &strides, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_deconvolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, &bias_desc.data, &dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a deconvolution forward "
"propagation primitive");
}
/// Constructs a descriptor for a deconvolution forward propagation
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Deconvolution algorithm:
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_deconvolution_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, nullptr, &dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a deconvolution forward "
"propagation primitive");
}
/// Constructs a descriptor for a dilated deconvolution forward
/// propagation primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Deconvolution algorithm:
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing zero memory
/// descriptor disables the bias term.
/// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &bias_desc, const memory::desc &dst_desc,
const memory::dims &strides, const memory::dims &dilates,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_dilated_deconvolution_forward_desc_init(
&data, dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, &bias_desc.data,
&dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated deconvolution "
"forward propagation primitive");
}
/// Constructs a descriptor for a dilated deconvolution forward
/// propagation primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Deconvolution algorithm:
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_dilated_deconvolution_forward_desc_init(
&data, dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&weights_desc.data, nullptr,
&dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated deconvolution "
"forward propagation primitive");
}
};
/// Primitive descriptor for a deconvolution forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a deconvolution forward
/// propagation primitive.
///
/// @param desc Descriptor for a deconvolution forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution forward
/// propagation primitive.
///
/// @param desc Descriptor for a deconvolution forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a deconvolution forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
memory::desc bias_desc() const { return base::weights_desc(1); }
};
/// Default constructor. Produces an empty object.
deconvolution_forward() = default;
/// Constructs a deconvolution forward propagation primitive.
/// @param pd Primitive descriptor for a deconvolution forward propagation
/// primitive.
deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// Deconvolution backward propagation primitive.
struct deconvolution_backward_data : public primitive {
/// Descriptor for a deconvolution backward propagation primitive.
struct desc {
dnnl_deconvolution_desc_t data;
/// Constructs a descriptor for a deconvolution backward propagation
/// primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
        /// @param algorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
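        ///
        /// A minimal sketch (assuming the engine @c eng, a forward primitive
        /// descriptor @c deconv_fwd_pd used as a hint, and the memory
        /// descriptors already exist):
        /// @code
        ///     dnnl::deconvolution_backward_data::desc bwd_d(
        ///             dnnl::algorithm::deconvolution_direct, diff_src_md,
        ///             wei_md, diff_dst_md, {2, 2}, {0, 0}, {0, 0});
        ///     dnnl::deconvolution_backward_data::primitive_desc bwd_pd(
        ///             bwd_d, eng, deconv_fwd_pd);
        /// @endcode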
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_deconvolution_backward_data_desc_init(&data,
convert_to_c(algorithm), &diff_src_desc.data,
&weights_desc.data, &diff_dst_desc.data,
&strides[0], &padding_l[0], &padding_r[0]),
"could not create a descriptor for a deconvolution "
"backward propagation primitive");
}
/// Constructs a descriptor for a dilated deconvolution backward
/// propagation primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
        /// @param algorithm Deconvolution algorithm. Possible values are
        ///     #dnnl::algorithm::deconvolution_direct, and
        ///     #dnnl::algorithm::deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_deconvolution_backward_data_desc_init(&data,
convert_to_c(algorithm), &diff_src_desc.data,
&weights_desc.data, &diff_dst_desc.data,
&strides[0], &dilates[0], &padding_l[0],
&padding_r[0]),
"could not create a descriptor for a dilated deconvolution "
"backward propagation primitive");
}
};
/// Primitive descriptor for a deconvolution backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a deconvolution backward
/// propagation primitive.
///
        /// @param desc Descriptor for a deconvolution backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const deconvolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution backward
/// propagation primitive.
///
        /// @param desc Descriptor for a deconvolution backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const deconvolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a deconvolution backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
deconvolution_backward_data() = default;
/// Constructs a deconvolution backward propagation primitive.
/// @param pd Primitive descriptor for a deconvolution backward propagation
/// primitive.
deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
};
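// Usage sketch (illustrative only, not from the library documentation): one
// way to create and execute a deconvolution backward data primitive through
// the dilated descriptor constructor documented above. It assumes
// `using namespace dnnl;`, an engine `eng`, a stream `strm`, and a previously
// created deconvolution_forward::primitive_desc `fwd_pd` for the same shapes;
// the tensor shapes themselves are made up for the example.
//
//     memory::desc diff_src_md({2, 16, 32, 32}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     memory::desc weights_md({32, 16, 3, 3}, memory::data_type::f32,
//             memory::format_tag::oihw);
//     memory::desc diff_dst_md({2, 32, 34, 34}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     auto bwd_d = deconvolution_backward_data::desc(
//             algorithm::deconvolution_direct, diff_src_md, weights_md,
//             diff_dst_md, /*strides=*/{1, 1}, /*dilates=*/{0, 0},
//             /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
//     auto bwd_pd = deconvolution_backward_data::primitive_desc(
//             bwd_d, eng, fwd_pd);
//     memory diff_src(bwd_pd.diff_src_desc(), eng);
//     memory weights(bwd_pd.weights_desc(), eng);
//     memory diff_dst(bwd_pd.diff_dst_desc(), eng);
//     deconvolution_backward_data(bwd_pd).execute(strm,
//             {{DNNL_ARG_DIFF_SRC, diff_src}, {DNNL_ARG_WEIGHTS, weights},
//                     {DNNL_ARG_DIFF_DST, diff_dst}});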
/// Deconvolution weights gradient primitive.
struct deconvolution_backward_weights : public primitive {
/// Descriptor for a deconvolution weights gradient primitive.
struct desc {
dnnl_deconvolution_desc_t data;
/// Constructs a descriptor for a deconvolution weights gradient
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
/// memory descriptor disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_deconvolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, &diff_bias_desc.data,
&diff_dst_desc.data, &strides[0], &padding_l[0],
&padding_r[0]),
"could not create a descriptor for a deconvolution weights "
"update primitive");
}
/// Constructs a descriptor for a deconvolution weights gradient primitive
/// without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_deconvolution_backward_weights_desc_init(
&data, convert_to_c(algorithm),
&src_desc.data, &diff_weights_desc.data,
nullptr, &diff_dst_desc.data, &strides[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a deconvolution weights "
"update primitive");
}
/// Constructs a descriptor for a dilated deconvolution weights gradient
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
/// memory descriptor disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_deconvolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, &diff_bias_desc.data,
&diff_dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated deconvolution "
"weights gradient primitive");
}
/// Constructs a descriptor for a dilated deconvolution weights gradient
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param algorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Strides for each spatial dimension.
/// @param dilates Dilations for each spatial dimension. A zero value
/// means no dilation in the corresponding dimension.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &dilates, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(dilates, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_dilated_deconvolution_backward_weights_desc_init(&data,
convert_to_c(algorithm), &src_desc.data,
&diff_weights_desc.data, nullptr,
&diff_dst_desc.data, &strides[0], &dilates[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a dilated deconvolution "
"weights gradient primitive");
}
};
/// Primitive descriptor for a deconvolution weights gradient primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a deconvolution weights
/// gradient primitive.
///
/// @param desc Descriptor for a deconvolution weights gradient
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case
/// an empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const deconvolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution weights
/// gradient primitive.
///
/// @param desc Descriptor for a deconvolution weights gradient
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a deconvolution forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const deconvolution_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a deconvolution weights
/// gradient primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a deconvolution weights
/// gradient primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
dnnl::prop_kind::backward_weights) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
memory::desc diff_weights_desc() const {
return base::diff_weights_desc(0);
}
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return base::diff_weights_desc(1);
}
};
/// Default constructor. Produces an empty object.
deconvolution_backward_weights() = default;
/// Constructs a deconvolution weights gradient primitive.
/// @param pd Primitive descriptor for a deconvolution weights gradient
/// primitive.
deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
};
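// Usage sketch (illustrative only, not from the library documentation):
// creating a deconvolution weights gradient primitive with bias via the
// descriptor constructor documented above. It assumes `using namespace dnnl;`,
// an engine `eng`, and a previously created
// deconvolution_forward::primitive_desc `fwd_pd`; shapes are made up for the
// example.
//
//     memory::desc src_md({2, 16, 32, 32}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     memory::desc diff_weights_md({32, 16, 3, 3}, memory::data_type::f32,
//             memory::format_tag::oihw);
//     memory::desc diff_bias_md({32}, memory::data_type::f32,
//             memory::format_tag::x);
//     memory::desc diff_dst_md({2, 32, 34, 34}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     auto bwd_d = deconvolution_backward_weights::desc(
//             algorithm::deconvolution_direct, src_md, diff_weights_md,
//             diff_bias_md, diff_dst_md, /*strides=*/{1, 1},
//             /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
//     auto bwd_pd = deconvolution_backward_weights::primitive_desc(
//             bwd_d, eng, fwd_pd);
//     // The primitive takes src and diff_dst as inputs and produces
//     // diff_weights and diff_bias.
//     auto bwd = deconvolution_backward_weights(bwd_pd);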
/// @} dnnl_api_deconvolution
/// @addtogroup dnnl_api_lrn LRN
///
/// A primitive to perform local response normalization (LRN) across or within
/// channels.
///
/// @sa @ref dev_guide_lrn in developer guide
///
/// @{
/// Local response normalization (LRN) forward propagation primitive.
struct lrn_forward : public primitive {
/// Descriptor for an LRN forward propagation primitive.
struct desc {
dnnl_lrn_desc_t data;
/// Constructs a descriptor for an LRN forward propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if the underlying implementation requires it and @p prop_kind =
/// #dnnl::prop_kind::forward_training; must be queried for using
/// @ref dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm LRN algorithm kind: either
/// #dnnl::algorithm::lrn_across_channels, or
/// #dnnl::algorithm::lrn_within_channel.
/// @param data_desc Source and destination memory descriptors.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &data_desc, memory::dim local_size,
float alpha, float beta, float k = 1.f) {
error::wrap_c_api(dnnl_lrn_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &data_desc.data,
local_size, alpha, beta, k),
"could not create a descriptor for a lrn forward "
"propagation primitive");
}
};
/// Primitive descriptor for an LRN forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LRN forward propagation
/// primitive.
///
/// @param desc Descriptor for an LRN forward propagation primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LRN forward propagation
/// primitive.
///
/// @param desc Descriptor for an LRN forward propagation primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LRN forward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LRN forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
lrn_forward() = default;
/// Constructs an LRN forward propagation primitive.
/// @param pd Primitive descriptor for an LRN forward propagation
/// primitive.
lrn_forward(const primitive_desc &pd) : primitive(pd) {}
};
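// Usage sketch (illustrative only, not from the library documentation): LRN
// forward propagation across channels, including the optional workspace. It
// assumes `using namespace dnnl;`, an engine `eng`, and a stream `strm`; the
// shape and LRN parameters are made up for the example.
//
//     memory::desc data_md({2, 64, 28, 28}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     auto fwd_d = lrn_forward::desc(prop_kind::forward_training,
//             algorithm::lrn_across_channels, data_md, /*local_size=*/5,
//             /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/1.f);
//     auto fwd_pd = lrn_forward::primitive_desc(fwd_d, eng);
//     memory src(fwd_pd.src_desc(), eng), dst(fwd_pd.dst_desc(), eng);
//     // The workspace descriptor is zero-sized if the implementation does
//     // not need a workspace.
//     memory ws(fwd_pd.workspace_desc(), eng);
//     lrn_forward(fwd_pd).execute(strm, {{DNNL_ARG_SRC, src},
//             {DNNL_ARG_DST, dst}, {DNNL_ARG_WORKSPACE, ws}});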
/// Local response normalization (LRN) backward propagation primitive.
struct lrn_backward : public primitive {
/// Descriptor for an LRN backward propagation primitive.
struct desc {
dnnl_lrn_desc_t data;
/// Constructs a descriptor for an LRN backward propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if the underlying implementation requires it; must be queried
/// for using @ref dnnl_primitive_desc_query_md() after a
/// corresponding primitive descriptor is created
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param algorithm LRN algorithm kind: either
/// #dnnl::algorithm::lrn_across_channels, or
/// #dnnl::algorithm::lrn_within_channel.
/// @param data_desc Source memory descriptor.
/// @param diff_data_desc Diff source and diff destination memory
///     descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
desc(algorithm algorithm, const memory::desc &data_desc,
const memory::desc &diff_data_desc, memory::dim local_size,
float alpha, float beta, float k = 1.f) {
error::wrap_c_api(
dnnl_lrn_backward_desc_init(&data, convert_to_c(algorithm),
&diff_data_desc.data, &data_desc.data, local_size,
alpha, beta, k),
"could not create a descriptor for a lrn backward "
"propagation primitive");
}
};
/// Primitive descriptor for an LRN backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LRN backward propagation
/// primitive.
///
/// @param desc Descriptor for an LRN backward propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LRN forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const lrn_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an LRN backward propagation
/// primitive.
///
/// @param desc Descriptor for an LRN backward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LRN forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const lrn_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for an LRN backward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LRN backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
lrn_backward() = default;
/// Constructs an LRN backward propagation primitive.
/// @param pd Primitive descriptor for an LRN backward propagation
/// primitive.
lrn_backward(const primitive_desc &pd) : primitive(pd) {}
};
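// Usage sketch (illustrative only, not from the library documentation): LRN
// backward propagation that reuses the forward primitive descriptor as a hint
// and the workspace produced by the forward pass. It assumes the objects from
// the lrn_forward sketch above (`eng`, `strm`, `data_md`, `fwd_pd`, `src`,
// and `ws`).
//
//     auto bwd_d = lrn_backward::desc(algorithm::lrn_across_channels,
//             data_md, /*diff_data_desc=*/data_md, /*local_size=*/5,
//             /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/1.f);
//     auto bwd_pd = lrn_backward::primitive_desc(bwd_d, eng, fwd_pd);
//     memory diff_src(bwd_pd.diff_src_desc(), eng);
//     memory diff_dst(bwd_pd.diff_dst_desc(), eng);
//     lrn_backward(bwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst},
//                     {DNNL_ARG_DIFF_SRC, diff_src},
//                     {DNNL_ARG_WORKSPACE, ws}});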
/// @} dnnl_api_lrn
/// @addtogroup dnnl_api_pooling Pooling
///
/// A primitive to perform max or average pooling.
///
/// @sa @ref dev_guide_pooling in developer guide
///
/// @{
/// Pooling forward propagation primitive.
struct pooling_forward : public primitive {
/// Descriptor for a pooling forward propagation primitive.
struct desc {
dnnl_pooling_desc_t data;
/// Constructs a descriptor for a pooling forward propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p algorithm = #dnnl::algorithm::pooling_max and @p
/// prop_kind = #dnnl::prop_kind::forward_training; must be
/// queried for using @ref dnnl::primitive_desc_base::query_md()
/// after a corresponding primitive descriptor is created
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Pooling algorithm kind: either
/// #dnnl::algorithm::pooling_max,
/// #dnnl::algorithm::pooling_avg_include_padding,
/// or #dnnl::algorithm::pooling_avg (same as
/// #dnnl::algorithm::pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Vector of strides for spatial dimensions.
/// @param kernel Vector of kernel spatial dimensions.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &dst_desc,
const memory::dims &strides, const memory::dims &kernel,
const memory::dims &padding_l, const memory::dims &padding_r) {
memory::validate_dims(strides, src_desc.data.ndims - 2);
memory::validate_dims(kernel, src_desc.data.ndims - 2);
memory::validate_dims(padding_l, src_desc.data.ndims - 2);
memory::validate_dims(padding_r, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_pooling_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &src_desc.data,
&dst_desc.data, &strides[0], &kernel[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a pooling forward "
"propagation primitive");
}
};
/// Primitive descriptor for a pooling forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a pooling forward
/// propagation primitive.
///
/// @param desc Descriptor for a pooling forward propagation primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a pooling forward
/// propagation primitive.
///
/// @param desc Descriptor for a pooling forward propagation primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a pooling forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a pooling forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
pooling_forward() = default;
/// Constructs a pooling forward propagation primitive.
/// @param pd Primitive descriptor for a pooling forward propagation
/// primitive.
pooling_forward(const primitive_desc &pd) : primitive(pd) {}
};
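// Usage sketch (illustrative only, not from the library documentation): 2x2
// max pooling with stride 2, which produces a workspace when created with
// prop_kind::forward_training. It assumes `using namespace dnnl;`, an engine
// `eng`, and a stream `strm`; shapes are made up for the example.
//
//     memory::desc src_md({2, 16, 32, 32}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     memory::desc dst_md({2, 16, 16, 16}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     auto fwd_d = pooling_forward::desc(prop_kind::forward_training,
//             algorithm::pooling_max, src_md, dst_md, /*strides=*/{2, 2},
//             /*kernel=*/{2, 2}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0});
//     auto fwd_pd = pooling_forward::primitive_desc(fwd_d, eng);
//     memory src(src_md, eng), dst(dst_md, eng);
//     memory ws(fwd_pd.workspace_desc(), eng);
//     pooling_forward(fwd_pd).execute(strm, {{DNNL_ARG_SRC, src},
//             {DNNL_ARG_DST, dst}, {DNNL_ARG_WORKSPACE, ws}});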
/// Pooling backward propagation primitive.
struct pooling_backward : public primitive {
/// Descriptor for a pooling backward propagation primitive.
struct desc {
dnnl_pooling_desc_t data;
/// Constructs a descriptor for a pooling backward propagation primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p algorithm = #dnnl::algorithm::pooling_max; must be
/// queried for using @ref dnnl::primitive_desc_base::query_md()
/// after a corresponding primitive descriptor is created
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param algorithm Pooling algorithm kind: either
/// #dnnl::algorithm::pooling_max,
/// #dnnl::algorithm::pooling_avg_include_padding,
/// or #dnnl::algorithm::pooling_avg (same as
/// #dnnl::algorithm::pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Vector of strides for spatial dimensions.
/// @param kernel Vector of kernel spatial dimensions.
/// @param padding_l Vector of padding values for low indices for each
/// spatial dimension (front, top, left).
/// @param padding_r Vector of padding values for high indices for
/// each spatial dimension (back, bottom, right).
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &diff_dst_desc, const memory::dims &strides,
const memory::dims &kernel, const memory::dims &padding_l,
const memory::dims &padding_r) {
memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
error::wrap_c_api(
dnnl_pooling_backward_desc_init(&data,
convert_to_c(algorithm), &diff_src_desc.data,
&diff_dst_desc.data, &strides[0], &kernel[0],
&padding_l[0], &padding_r[0]),
"could not create a descriptor for a pooling backward "
"propagation primitive");
}
};
/// Primitive descriptor for a pooling backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a pooling backward
/// propagation primitive.
///
/// @param desc Descriptor for a pooling backward propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a pooling forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const pooling_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a pooling backward
/// propagation primitive.
///
/// @param desc Descriptor for a pooling backward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a pooling forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const pooling_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a pooling backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a pooling backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
pooling_backward() = default;
/// Constructs a pooling backward propagation primitive.
/// @param pd Primitive descriptor for a pooling backward propagation
/// primitive.
pooling_backward(const primitive_desc &pd) : primitive(pd) {}
};
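// Usage sketch (illustrative only, not from the library documentation):
// pooling backward propagation. Note that for max pooling the workspace saved
// by the forward pass is an input here. It assumes the objects from the
// pooling_forward sketch above (`eng`, `strm`, `src_md`, `dst_md`, `fwd_pd`,
// and `ws`).
//
//     auto bwd_d = pooling_backward::desc(algorithm::pooling_max,
//             /*diff_src_desc=*/src_md, /*diff_dst_desc=*/dst_md,
//             /*strides=*/{2, 2}, /*kernel=*/{2, 2}, /*padding_l=*/{0, 0},
//             /*padding_r=*/{0, 0});
//     auto bwd_pd = pooling_backward::primitive_desc(bwd_d, eng, fwd_pd);
//     memory diff_src(bwd_pd.diff_src_desc(), eng);
//     memory diff_dst(bwd_pd.diff_dst_desc(), eng);
//     pooling_backward(bwd_pd).execute(strm,
//             {{DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_WORKSPACE, ws},
//                     {DNNL_ARG_DIFF_SRC, diff_src}});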
/// @} dnnl_api_pooling
/// @addtogroup dnnl_api_eltwise Eltwise
///
/// A primitive to perform elementwise operations such as the
/// rectified linear unit (ReLU).
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// @warning
///     Because the original source data is required for backward propagation,
///     in-place forward propagation is generally not supported in training
///     mode. However, for algorithms that support using the destination as
///     the input memory, dst can be used for backward propagation, which
///     makes it possible to get a performance benefit even in training mode.
///
/// @sa @ref dev_guide_eltwise in developer guide
///
/// @{
/// Elementwise unary operation forward propagation primitive.
struct eltwise_forward : public primitive {
/// Descriptor for an elementwise forward propagation primitive.
struct desc {
dnnl_eltwise_desc_t data;
/// Constructs a descriptor for an elementwise forward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Elementwise algorithm kind.
/// @param data_desc Source and destination memory descriptors.
/// @param alpha The alpha parameter for the elementwise operation.
/// Specific meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation.
/// Specific meaning depends on the algorithm.
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &data_desc, float alpha = 0,
float beta = 0) {
error::wrap_c_api(dnnl_eltwise_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(algorithm),
&data_desc.data, alpha, beta),
"could not create a descriptor for an eltwise forward "
"propagation primitive");
}
};
/// Primitive descriptor for an elementwise forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an elementwise forward
/// propagation primitive.
///
/// @param desc Descriptor for an elementwise forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an elementwise forward
/// propagation primitive.
///
/// @param desc Descriptor for an elementwise forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an eltwise forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for an eltwise forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
eltwise_forward() = default;
/// Constructs an eltwise forward propagation primitive.
/// @param pd Primitive descriptor for an eltwise forward propagation
/// primitive.
eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
};
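// Usage sketch (illustrative only, not from the library documentation):
// in-place ReLU for inference, passing the same memory object as both src and
// dst. It assumes `using namespace dnnl;`, an engine `eng`, and a stream
// `strm`; the shape is made up for the example.
//
//     memory::desc data_md({2, 16, 32, 32}, memory::data_type::f32,
//             memory::format_tag::nchw);
//     auto fwd_d = eltwise_forward::desc(prop_kind::forward_inference,
//             algorithm::eltwise_relu, data_md, /*alpha=*/0.f, /*beta=*/0.f);
//     auto fwd_pd = eltwise_forward::primitive_desc(fwd_d, eng);
//     memory data(data_md, eng);
//     // In-place: the same memory object serves as both SRC and DST.
//     eltwise_forward(fwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, data}, {DNNL_ARG_DST, data}});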
/// Elementwise unary operation backward propagation primitive.
struct eltwise_backward : public primitive {
/// Descriptor for an elementwise backward propagation primitive.
struct desc {
dnnl_eltwise_desc_t data;
/// Constructs a descriptor for an elementwise backward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param algorithm Elementwise algorithm kind.
/// @param diff_data_desc Diff source and destination memory
/// descriptors.
/// @param data_desc Source memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation.
/// Specific meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation.
/// Specific meaning depends on the algorithm.
desc(algorithm algorithm, const memory::desc &diff_data_desc,
const memory::desc &data_desc, float alpha = 0,
float beta = 0) {
error::wrap_c_api(
dnnl_eltwise_backward_desc_init(&data,
dnnl::convert_to_c(algorithm), &diff_data_desc.data,
&data_desc.data, alpha, beta),
"could not create a descriptor for an eltwise backward "
"propagation primitive");
}
};
/// Primitive descriptor for eltwise backward propagation.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an elementwise backward
/// propagation primitive.
///
/// @param desc Descriptor for an elementwise backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an elementwise forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const eltwise_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an elementwise backward
/// propagation primitive.
///
/// @param desc Descriptor for an elementwise backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an elementwise forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const eltwise_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for an eltwise backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for an eltwise backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
eltwise_backward() = default;
/// Constructs an eltwise backward propagation primitive.
/// @param pd Primitive descriptor for an eltwise backward propagation
/// primitive.
eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
};
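// Usage sketch (illustrative only, not from the library documentation): ReLU
// backward propagation. Because the original src is an input here, the
// forward pass should not have been run in place when training (see the
// warning above). It assumes `using namespace dnnl;`, `eng`, `strm`, a data
// memory descriptor `data_md`, and a forward primitive descriptor `fwd_pd`
// created with prop_kind::forward_training for the same `data_md`.
//
//     auto bwd_d = eltwise_backward::desc(algorithm::eltwise_relu,
//             /*diff_data_desc=*/data_md, /*data_desc=*/data_md,
//             /*alpha=*/0.f, /*beta=*/0.f);
//     auto bwd_pd = eltwise_backward::primitive_desc(bwd_d, eng, fwd_pd);
//     memory src(bwd_pd.src_desc(), eng);
//     memory diff_dst(bwd_pd.diff_dst_desc(), eng);
//     memory diff_src(bwd_pd.diff_src_desc(), eng);
//     eltwise_backward(bwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst},
//                     {DNNL_ARG_DIFF_SRC, diff_src}});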
/// @} dnnl_api_eltwise
/// @addtogroup dnnl_api_softmax Softmax
///
/// A primitive to perform softmax.
///
/// @sa @ref dev_guide_softmax in developer guide
///
/// @{
/// Softmax forward propagation primitive.
struct softmax_forward : public primitive {
/// Descriptor for a softmax forward propagation primitive.
struct desc {
dnnl_softmax_desc_t data;
/// Default constructor. Produces an empty object.
desc() = default;
/// Constructs a descriptor for a softmax forward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
desc(prop_kind prop_kind, const memory::desc &data_desc,
int softmax_axis) {
error::wrap_c_api(dnnl_softmax_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
&data_desc.data, softmax_axis),
"could not create a descriptor for a softmax forward "
"propagation primitive");
}
};
/// Primitive descriptor for a softmax forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a softmax forward
/// propagation primitive.
///
/// @param desc Descriptor for a softmax forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a softmax forward
/// propagation primitive.
///
/// @param desc Descriptor for a softmax forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a softmax forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a softmax forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
softmax_forward() = default;
/// Constructs a softmax forward propagation primitive.
/// @param pd Primitive descriptor for a softmax forward propagation
/// primitive.
softmax_forward(const primitive_desc &pd) : primitive(pd) {}
};
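// Usage sketch (illustrative only, not from the library documentation):
// softmax over the class axis of a 2D (batch x classes) tensor. It assumes
// `using namespace dnnl;`, an engine `eng`, and a stream `strm`; the shape is
// made up for the example.
//
//     memory::desc data_md({128, 1000}, memory::data_type::f32,
//             memory::format_tag::nc);
//     auto fwd_d = softmax_forward::desc(prop_kind::forward_inference,
//             data_md, /*softmax_axis=*/1);
//     auto fwd_pd = softmax_forward::primitive_desc(fwd_d, eng);
//     memory src(data_md, eng), dst(data_md, eng);
//     softmax_forward(fwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});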
/// Softmax backward propagation primitive.
struct softmax_backward : public primitive {
/// Descriptor for a softmax backward propagation primitive.
struct desc {
dnnl_softmax_desc_t data;
/// Default constructor. Produces an empty object.
desc() = default;
/// Constructs a descriptor for a softmax backward propagation
/// primitive.
///
/// Inputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param diff_data_desc Diff source and diff destination memory
/// descriptor.
/// @param data_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
int softmax_axis) {
error::wrap_c_api(
dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
&data_desc.data, softmax_axis),
"could not create a descriptor for a softmax backward "
"propagation primitive");
}
};
/// Primitive descriptor for a softmax backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a softmax backward
/// propagation primitive.
///
/// @param desc Descriptor for a softmax backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a softmax forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const softmax_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a softmax backward
/// propagation primitive.
///
/// @param desc Descriptor for a softmax backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a softmax forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const softmax_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a softmax backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a softmax backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
softmax_backward() = default;
/// Constructs a softmax backward propagation primitive.
/// @param pd Primitive descriptor for a softmax backward propagation
/// primitive.
softmax_backward(const primitive_desc &pd) : primitive(pd) {}
};
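// Usage sketch (illustrative only, not from the library documentation):
// softmax backward propagation. Note that the input is dst (not src) together
// with diff_dst. It assumes `using namespace dnnl;`, `eng`, `strm`, a data
// memory descriptor `data_md`, and a forward primitive descriptor `fwd_pd`
// created with prop_kind::forward_training for the same shape.
//
//     auto bwd_d = softmax_backward::desc(/*diff_data_desc=*/data_md,
//             /*data_desc=*/data_md, /*softmax_axis=*/1);
//     auto bwd_pd = softmax_backward::primitive_desc(bwd_d, eng, fwd_pd);
//     memory dst(bwd_pd.dst_desc(), eng);
//     memory diff_dst(bwd_pd.diff_dst_desc(), eng);
//     memory diff_src(bwd_pd.diff_src_desc(), eng);
//     softmax_backward(bwd_pd).execute(strm,
//             {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst},
//                     {DNNL_ARG_DIFF_SRC, diff_src}});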
/// @} dnnl_api_softmax
/// @addtogroup dnnl_api_logsoftmax LogSoftmax
///
/// A primitive to perform logsoftmax.
///
/// @sa @ref dev_guide_logsoftmax in developer guide
///
/// @{
/// Logsoftmax forward propagation primitive.
struct logsoftmax_forward : public primitive {
/// Descriptor for a logsoftmax forward propagation primitive.
struct desc {
dnnl_logsoftmax_desc_t data;
/// Default constructor. Produces an empty object.
desc() = default;
/// Constructs a descriptor for a logsoftmax forward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
desc(prop_kind prop_kind, const memory::desc &data_desc,
int logsoftmax_axis) {
error::wrap_c_api(dnnl_logsoftmax_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
&data_desc.data, logsoftmax_axis),
"could not create a descriptor for a logsoftmax forward "
"propagation primitive");
}
};
/// Primitive descriptor for a logsoftmax forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a logsoftmax forward
/// propagation primitive.
///
/// @param desc Descriptor for a logsoftmax forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a logsoftmax forward
/// propagation primitive.
///
/// @param desc Descriptor for a logsoftmax forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a logsoftmax forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a logsoftmax forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
// Logsoftmax and softmax share the implementation and
// currently report the same primitive kind. Hence this
// must be softmax and not logsoftmax.
dnnl::primitive::kind::softmax,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
logsoftmax_forward() = default;
/// Constructs a logsoftmax forward propagation primitive.
/// @param pd Primitive descriptor for a logsoftmax forward propagation
/// primitive.
logsoftmax_forward(const primitive_desc &pd) : primitive(pd) {}
};
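// Usage sketch (illustrative only, not from the library documentation):
// logsoftmax forward propagation; the interface mirrors softmax_forward. It
// assumes `using namespace dnnl;`, an engine `eng`, and a stream `strm`; the
// shape is made up for the example.
//
//     memory::desc data_md({128, 1000}, memory::data_type::f32,
//             memory::format_tag::nc);
//     auto fwd_d = logsoftmax_forward::desc(prop_kind::forward_training,
//             data_md, /*logsoftmax_axis=*/1);
//     auto fwd_pd = logsoftmax_forward::primitive_desc(fwd_d, eng);
//     memory src(data_md, eng), dst(data_md, eng);
//     logsoftmax_forward(fwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});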
/// Logsoftmax backward propagation primitive.
struct logsoftmax_backward : public primitive {
/// Descriptor for a logsoftmax backward propagation primitive.
struct desc {
dnnl_logsoftmax_desc_t data;
/// Default constructor. Produces an empty object.
desc() = default;
/// Constructs a descriptor for a logsoftmax backward propagation
/// primitive.
///
/// Inputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param diff_data_desc Diff source and diff destination memory
/// descriptors.
/// @param data_desc Destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
int logsoftmax_axis) {
error::wrap_c_api(dnnl_logsoftmax_backward_desc_init(&data,
&diff_data_desc.data, &data_desc.data,
logsoftmax_axis),
"could not create a descriptor for a logsoftmax backward "
"propagation primitive");
}
};
/// Primitive descriptor for a logsoftmax backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a logsoftmax backward
/// propagation primitive.
///
/// @param desc Descriptor for a logsoftmax backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a logsoftmax forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const logsoftmax_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a logsoftmax backward
/// propagation primitive.
///
/// @param desc Descriptor for a logsoftmax backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a logsoftmax forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const logsoftmax_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a logsoftmax backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a logsoftmax backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
// Logsoftmax and softmax share the implementation and
// currently report the same primitive kind. Hence this
// must be softmax and not logsoftmax.
dnnl::primitive::kind::softmax,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
logsoftmax_backward() = default;
/// Constructs a logsoftmax backward propagation primitive.
/// @param pd Primitive descriptor for a logsoftmax backward propagation
/// primitive.
logsoftmax_backward(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_logsoftmax
/// @addtogroup dnnl_api_batch_normalization Batch Normalization
///
/// A primitive to perform batch normalization.
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// The computations performed by the batch normalization primitives can be
/// controlled by specifying different @ref dnnl::normalization_flags values.
/// For example, batch normalization can compute the mean and variance on its
/// own or take them as inputs. It can either perform scaling and shifting
/// using the gamma and beta parameters or skip them. Optionally, it can also
/// perform a fused ReLU, which in the case of training also requires a
/// workspace.
///
/// @sa @ref dev_guide_batch_normalization in developer guide
///
/// @{
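// Usage sketch (illustrative only, not from the library documentation):
// combining dnnl::normalization_flags for a forward training batch
// normalization that computes mean and variance itself, applies scale/shift,
// and fuses a ReLU. It assumes `using namespace dnnl;`, an engine `eng`, and
// a data memory descriptor `data_md`; the epsilon value is made up.
//
//     auto flags = normalization_flags::use_scale_shift
//             | normalization_flags::fuse_norm_relu;
//     auto fwd_d = batch_normalization_forward::desc(
//             prop_kind::forward_training, data_md, /*epsilon=*/1e-5f, flags);
//     auto fwd_pd = batch_normalization_forward::primitive_desc(fwd_d, eng);
//     // With these flags and forward_training, mean, variance, and a
//     // workspace are produced as extra outputs, and scale_and_shift is an
//     // extra input (see the descriptor documentation below).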
/// Batch normalization forward propagation primitive.
struct batch_normalization_forward : public primitive {
/// Descriptor for a batch normalization forward propagation primitive.
struct desc {
dnnl_batch_normalization_desc_t data;
/// Constructs a batch normalization descriptor for forward
/// propagation.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
/// in @p flags
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if #dnnl::normalization_flags::fuse_norm_relu bit-flag is set
/// in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training; must be queried
/// for using @ref primitive_desc_base::query_md() after a
/// corresponding primitive descriptor is created
///
/// @note
/// In-place operation is supported: the dst can refer to the same
/// memory as the src.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptors.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
normalization_flags flags) {
error::wrap_c_api(
dnnl_batch_normalization_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &data_desc.data,
epsilon, convert_to_c(flags)),
"could not create a descriptor for a batch normalization "
"forward propagation primitive");
}
};
/// Primitive descriptor for a batch normalization forward propagation
/// primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a batch normalization forward
/// propagation primitive.
///
/// @param desc Descriptor for a batch normalization forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a batch normalization forward
/// propagation primitive.
///
/// @param desc Descriptor for a batch normalization forward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a batch normalization
/// forward propagation primitive from a C API primitive descriptor
/// that must have a matching kind.
///
/// @param pd C API primitive descriptor for a batch normalization
/// forward propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
dnnl::primitive::kind::batch_normalization,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
/// Returns memory descriptor for mean.
/// @returns Memory descriptor for mean.
memory::desc mean_desc() const { return stat_desc(mean); }
/// Returns memory descriptor for variance.
/// @returns Memory descriptor for variance.
memory::desc variance_desc() const { return stat_desc(var); }
private:
enum {
mean = 1,
var = 2,
};
memory::desc stat_desc(int kind) const {
dnnl_batch_normalization_desc_t *p;
error::wrap_c_api(
dnnl_primitive_desc_query(get(),
dnnl::convert_to_c(query::batch_normalization_d), 0,
&p),
"could not retrieve a descriptor from a primitive "
"descriptor for batch normalization forward propagation "
"primitive");
return query_md(p->flags & dnnl_use_global_stats ? query::src_md
: query::dst_md,
kind);
}
};
/// Default constructor. Produces an empty object.
batch_normalization_forward() = default;
/// Constructs a batch normalization forward propagation primitive.
/// @param pd Primitive descriptor for a batch normalization forward
/// propagation primitive.
batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// Batch normalization backward propagation primitive.
struct batch_normalization_backward : public primitive {
/// Descriptor for a batch normalization backward propagation primitive.
struct desc {
dnnl_batch_normalization_desc_t data;
/// Constructs a batch normalization descriptor for backward
/// propagation.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
/// set in @p flags
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if #dnnl::normalization_flags::fuse_norm_relu bit-flag is set
/// in @p flags
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_scale_and_shift`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
/// set in @p flags and @p prop_kind = #dnnl::prop_kind::backward
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
/// (diffs for all parameters are computed in this case).
/// @param diff_data_desc Diff source and diff destination memory
/// descriptor.
/// @param data_desc Source memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
const memory::desc &data_desc, float epsilon,
normalization_flags flags) {
error::wrap_c_api(
dnnl_batch_normalization_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
&data_desc.data, epsilon, convert_to_c(flags)),
"could not create a descriptor for a batch normalization "
"backward propagation primitive");
}
};
/// Primitive descriptor for a batch normalization backward propagation
/// primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a batch normalization backward
/// propagation primitive.
///
/// @param desc Descriptor for a batch normalization backward
/// propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a batch normalization
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const batch_normalization_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a batch normalization backward
/// propagation primitive.
///
/// @param desc Descriptor for a batch normalization backward
/// propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a batch normalization
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const batch_normalization_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a batch normalization
/// backward propagation primitive from a C API primitive descriptor
/// that must have a matching kind.
///
/// @param pd C API primitive descriptor for a batch normalization
/// backward propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
dnnl::primitive::kind::batch_normalization,
dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
memory::desc diff_weights_desc() const {
return base::diff_weights_desc(0);
}
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
memory::desc mean_desc() const { return query_md(query::src_md, 1); }
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
memory::desc variance_desc() const {
return query_md(query::src_md, 2);
}
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
batch_normalization_backward() = default;
/// Constructs a batch normalization backward propagation primitive.
/// @param pd Primitive descriptor for a batch normalization backward
/// propagation primitive.
batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
};
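// A minimal backward usage sketch (illustrative only). It assumes the forward
// primitive descriptor `bnorm_pd`, the descriptor `data_md`, and the memory
// objects from the forward sketch above, plus f32 memory objects
// `diff_src_mem`, `diff_dst_mem`, and `diff_scale_shift_mem` with matching
// shapes; all of these names are placeholders:
//
//     auto bnorm_bwd_desc = batch_normalization_backward::desc(
//             prop_kind::backward, data_md, data_md, 1.e-5f,
//             normalization_flags::use_scale_shift);
//     auto bnorm_bwd_pd = batch_normalization_backward::primitive_desc(
//             bnorm_bwd_desc, eng, bnorm_pd);
//     batch_normalization_backward(bnorm_bwd_pd).execute(strm,
//             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_MEAN, mean_mem},
//                     {DNNL_ARG_VARIANCE, variance_mem},
//                     {DNNL_ARG_DIFF_DST, diff_dst_mem},
//                     {DNNL_ARG_SCALE_SHIFT, scale_shift_mem},
//                     {DNNL_ARG_DIFF_SRC, diff_src_mem},
//                     {DNNL_ARG_DIFF_SCALE_SHIFT, diff_scale_shift_mem}});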
/// @} dnnl_api_batch_normalization
/// @addtogroup dnnl_api_layer_normalization Layer Normalization
///
/// A primitive to perform layer normalization. Normalization is performed
/// within the last logical dimension of the data tensor.
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// The layer normalization primitive computations can be controlled by
/// specifying different @ref dnnl::normalization_flags values. For example,
/// layer normalization forward propagation can be configured to either
/// compute the mean and variance or take them as arguments. It can
/// optionally perform scaling and shifting using the gamma and beta
/// parameters, and it can also perform a fused ReLU, which in the case of
/// training also requires a workspace.
///
/// @sa @ref dev_guide_layer_normalization in developer guide
///
/// @{
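// A minimal forward usage sketch (illustrative only, not part of the API
// reference). It assumes an engine `eng` and a stream `strm`; all other names
// are placeholders. A [T, N, C] tensor is normalized over C, so the
// statistics have shape [T, N]:
//
//     memory::desc data_md({4, 2, 16}, memory::data_type::f32,
//             memory::format_tag::tnc);
//     memory::desc stat_md({4, 2}, memory::data_type::f32,
//             memory::format_tag::nc);
//     auto lnorm_desc = layer_normalization_forward::desc(
//             prop_kind::forward_training, data_md, stat_md, 1.e-5f,
//             normalization_flags::use_scale_shift);
//     auto lnorm_pd
//             = layer_normalization_forward::primitive_desc(lnorm_desc, eng);
//     // The statistics descriptors can be queried back from the primitive
//     // descriptor to allocate the mean and variance buffers:
//     memory mean_mem(lnorm_pd.mean_desc(), eng);
//     memory variance_mem(lnorm_pd.variance_desc(), eng);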
/// Layer normalization forward propagation primitive.
struct layer_normalization_forward : public primitive {
/// Descriptor for a layer normalization forward propagation primitive.
struct desc {
dnnl_layer_normalization_desc_t data;
/// Constructs a descriptor for layer normalization forward
/// propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
/// in @p flags
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param stat_desc Statistics memory descriptors.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &data_desc,
const memory::desc &stat_desc, float epsilon,
normalization_flags flags) {
error::wrap_c_api(
dnnl_layer_normalization_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &data_desc.data,
&stat_desc.data, epsilon, convert_to_c(flags)),
"could not create a descriptor for a layer normalization "
"forward propagation primitive");
}
/// Constructs a descriptor for layer normalization forward
/// propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// set in @p flags
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is set
/// in @p flags
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::dst_desc(`1`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
/// - `variance` (#dnnl::primitive_desc_base::dst_desc(`2`)),
/// if #dnnl::normalization_flags::use_global_stats bit-flag is
/// not set in @p flags and @p prop_kind =
/// #dnnl::prop_kind::forward_training
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &data_desc, float epsilon,
normalization_flags flags) {
error::wrap_c_api(
dnnl_layer_normalization_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &data_desc.data,
nullptr, epsilon, convert_to_c(flags)),
"could not create a descriptor for a layer normalization "
"forward propagation primitive");
}
};
/// Primitive descriptor for a layer normalization forward propagation
/// primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a layer normalization forward
/// propagation primitive.
///
/// @param desc Descriptor for a layer normalization forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a layer normalization forward
/// propagation primitive.
///
/// @param desc Descriptor for a layer normalization forward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a layer normalization
/// forward propagation primitive from a C API primitive descriptor
/// that must have a matching kind.
///
/// @param pd C API primitive descriptor for a layer normalization
/// forward propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
dnnl::primitive::kind::layer_normalization,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
memory::desc mean_desc() const { return stat_desc(mean); }
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
memory::desc variance_desc() const { return stat_desc(var); }
private:
enum {
mean = 1,
var = 2,
};
memory::desc stat_desc(int kind) const {
dnnl_layer_normalization_desc_t *p;
error::wrap_c_api(
dnnl_primitive_desc_query(get(),
dnnl::convert_to_c(query::layer_normalization_d), 0,
&p),
"could not retrieve a descriptor from a primitive "
"descriptor for layer normalization forward propagation "
"primitive");
return query_md(p->flags & dnnl_use_global_stats ? query::src_md
: query::dst_md,
kind);
}
};
/// Default constructor. Produces an empty object.
layer_normalization_forward() = default;
/// Constructs a layer normalization forward propagation primitive.
/// @param pd Primitive descriptor for a layer normalization forward
/// propagation primitive.
layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// Layer normalization backward propagation primitive.
struct layer_normalization_backward : public primitive {
/// Descriptor for a layer normalization backward propagation primitive.
struct desc {
dnnl_layer_normalization_desc_t data;
/// Constructs a descriptor for layer normalization backward
/// propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
/// set in @p flags
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_scale_and_shift`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)), if
/// #dnnl::normalization_flags::use_scale_shift bit-flag is set
/// in @p flags and @p prop_kind = #dnnl::prop_kind::backward
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
/// (diffs for all parameters are computed in this case).
/// @param diff_data_desc Diff source and diff destination memory
/// descriptor.
/// @param data_desc Source memory descriptor.
/// @param stat_desc Statistics memory descriptors.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
const memory::desc &data_desc, const memory::desc &stat_desc,
float epsilon, normalization_flags flags) {
error::wrap_c_api(
dnnl_layer_normalization_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &diff_data_desc.data,
&data_desc.data, &stat_desc.data, epsilon,
convert_to_c(flags)),
"could not create a descriptor for a batch normalization "
"backward propagation primitive");
}
/// Constructs a descriptor for layer normalization backward
/// propagation primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `mean` (#dnnl::primitive_desc_base::src_desc(`1`))
/// - `variance` (#dnnl::primitive_desc_base::src_desc(`2`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `scale_and_shift` (#dnnl::primitive_desc_base::weights_desc(`0`)),
/// if #dnnl::normalization_flags::use_scale_shift bit-flag is
/// set in @p flags
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_scale_and_shift`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`)), if
/// #dnnl::normalization_flags::use_scale_shift bit-flag is set
/// in @p flags and @p prop_kind = #dnnl::prop_kind::backward
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
/// (diffs for all parameters are computed in this case).
/// @param diff_data_desc Diff source and diff destination memory
/// descriptor.
/// @param data_desc Source memory descriptor.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref
/// dnnl::normalization_flags).
desc(prop_kind prop_kind, const memory::desc &diff_data_desc,
const memory::desc &data_desc, float epsilon,
normalization_flags flags) {
error::wrap_c_api(dnnl_layer_normalization_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
&diff_data_desc.data, &data_desc.data,
nullptr, epsilon, convert_to_c(flags)),
"could not create a descriptor for a batch normalization "
"backward propagation primitive");
}
};
/// Primitive descriptor for a layer normalization backward propagation
/// primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a layer normalization backward
/// propagation primitive.
///
/// @param desc Descriptor for a layer normalization backward
/// propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a layer normalization
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const layer_normalization_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a layer normalization backward
/// propagation primitive.
///
/// @param desc Descriptor for a layer normalization backward
/// propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a layer normalization
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const layer_normalization_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a layer normalization
/// backward propagation primitive from a C API primitive descriptor
/// that must have a matching kind.
///
/// @param pd C API primitive descriptor for a layer normalization
/// backward propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd,
dnnl::primitive::kind::layer_normalization,
dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
memory::desc diff_weights_desc() const {
return base::diff_weights_desc(0);
}
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
memory::desc mean_desc() const { return query_md(query::src_md, 1); }
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
memory::desc variance_desc() const {
return query_md(query::src_md, 2);
}
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const { return base::workspace_desc(); }
};
/// Default constructor. Produces an empty object.
layer_normalization_backward() = default;
/// Constructs a layer normalization backward propagation primitive.
/// @param pd Primitive descriptor for a layer normalization backward
/// propagation primitive.
layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_layer_normalization
/// @addtogroup dnnl_api_inner_product Inner Product
///
/// A primitive to compute an inner product.
///
/// @sa @ref dev_guide_inner_product in developer guide
///
/// @{
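// A minimal forward usage sketch (illustrative only, not part of the API
// reference). It assumes an engine `eng`, a stream `strm`, and memory objects
// matching the descriptors below; all names are placeholders.
// src: [N, IC], weights: [OC, IC], bias: [OC], dst: [N, OC]:
//
//     memory::desc src_md({2, 64}, memory::data_type::f32,
//             memory::format_tag::nc);
//     memory::desc wei_md({128, 64}, memory::data_type::f32,
//             memory::format_tag::oi);
//     memory::desc bias_md({128}, memory::data_type::f32,
//             memory::format_tag::x);
//     memory::desc dst_md({2, 128}, memory::data_type::f32,
//             memory::format_tag::nc);
//     auto ip_desc = inner_product_forward::desc(prop_kind::forward_inference,
//             src_md, wei_md, bias_md, dst_md);
//     auto ip_pd = inner_product_forward::primitive_desc(ip_desc, eng);
//     inner_product_forward(ip_pd).execute(strm,
//             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
//                     {DNNL_ARG_BIAS, bias_mem}, {DNNL_ARG_DST, dst_mem}});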
/// Inner product forward propagation primitive.
struct inner_product_forward : public primitive {
/// Descriptor for an inner product forward propagation primitive.
struct desc {
dnnl_inner_product_desc_t data;
/// Constructs a descriptor for an inner product forward propagation
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param src_desc Memory descriptor for src.
/// @param weights_desc Memory descriptor for weights.
/// @param bias_desc Memory descriptor for bias.
/// @param dst_desc Memory descriptor for dst.
desc(prop_kind prop_kind, const memory::desc &src_desc,
const memory::desc &weights_desc, const memory::desc &bias_desc,
const memory::desc &dst_desc) {
error::wrap_c_api(dnnl_inner_product_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
&src_desc.data, &weights_desc.data,
&bias_desc.data, &dst_desc.data),
"could not create a descriptor for an inner product "
"forward propagation primitive");
}
/// Constructs a descriptor for an inner product forward propagation
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param src_desc Memory descriptor for src.
/// @param weights_desc Memory descriptor for weights.
/// @param dst_desc Memory descriptor for dst.
desc(prop_kind prop_kind, const memory::desc &src_desc,
const memory::desc &weights_desc,
const memory::desc &dst_desc) {
error::wrap_c_api(
dnnl_inner_product_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind), &src_desc.data,
&weights_desc.data, nullptr, &dst_desc.data),
"could not create a descriptor for an inner product "
"forward propagation primitive");
}
};
/// Primitive descriptor for an inner product forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an inner product forward
/// propagation primitive.
///
/// @param desc Descriptor for an inner product forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an inner product forward
/// propagation primitive.
///
/// @param desc Descriptor for an inner product forward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an inner product forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for an inner product forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
memory::desc bias_desc() const { return base::weights_desc(1); }
};
/// Default constructor. Produces an empty object.
inner_product_forward() = default;
/// Constructs an inner product forward propagation primitive.
/// @param pd Primitive descriptor for an inner product forward
/// propagation primitive.
inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// Inner product backward propagation primitive.
struct inner_product_backward_data : public primitive {
/// Descriptor for an inner product backward propagation primitive.
struct desc {
dnnl_inner_product_desc_t data;
/// Constructs a descriptor for an inner product backward propagation
/// primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param diff_src_desc Memory descriptor for diff src.
/// @param weights_desc Memory descriptor for weights.
/// @param diff_dst_desc Memory descriptor for diff dst.
desc(const memory::desc &diff_src_desc,
const memory::desc &weights_desc,
const memory::desc &diff_dst_desc) {
error::wrap_c_api(dnnl_inner_product_backward_data_desc_init(&data,
&diff_src_desc.data, &weights_desc.data,
&diff_dst_desc.data),
"could not create a descriptor for an inner product "
"backward propagation primitive");
}
};
/// Primitive descriptor for an inner product backward propagation
/// primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an inner product backward
/// propagation primitive.
///
/// @param desc Descriptor for an inner product backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an inner product
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const inner_product_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an inner product backward
/// propagation primitive.
///
/// @param desc Descriptor for an inner product backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an inner product
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const inner_product_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for an inner product backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for an inner product backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const { return base::weights_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
inner_product_backward_data() = default;
/// Constructs an inner product backward propagation primitive.
/// @param pd Primitive descriptor for an inner product backward
/// propagation primitive.
inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
};
/// Inner product weights gradient primitive.
struct inner_product_backward_weights : public primitive {
/// Descriptor for an inner product weights gradient primitive.
struct desc {
dnnl_inner_product_desc_t data;
/// Constructs a descriptor for an inner product weights update
/// primitive with bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param src_desc Memory descriptor for src.
/// @param diff_weights_desc Memory descriptor for diff weights.
/// @param diff_bias_desc Memory descriptor for diff bias.
/// @param diff_dst_desc Memory descriptor for diff dst.
desc(const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_desc) {
error::wrap_c_api(
dnnl_inner_product_backward_weights_desc_init(&data,
&src_desc.data, &diff_weights_desc.data,
&diff_bias_desc.data, &diff_dst_desc.data),
"could not create a descriptor for an inner product "
"weights gradient primitive");
}
/// Constructs a descriptor for an inner product weights update
/// primitive without bias.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_weights` (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param src_desc Memory descriptor for src.
/// @param diff_weights_desc Memory descriptor for diff weights.
/// @param diff_dst_desc Memory descriptor for diff dst.
desc(const memory::desc &src_desc,
const memory::desc &diff_weights_desc,
const memory::desc &diff_dst_desc) {
error::wrap_c_api(
dnnl_inner_product_backward_weights_desc_init(&data,
&src_desc.data, &diff_weights_desc.data, nullptr,
&diff_dst_desc.data),
"could not create a descriptor for an inner product "
"weights gradient primitive");
}
};
/// Primitive descriptor for an inner product weights gradient primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an inner product weights
/// update primitive.
///
/// @param desc Descriptor for an inner product weights gradient
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an inner product
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const inner_product_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an inner product weights
/// update primitive.
///
/// @param desc Descriptor for an inner product weights gradient
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an inner product
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const inner_product_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for an inner product weights
/// update primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for an inner product weights
/// gradient primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
dnnl::prop_kind::backward_weights) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
memory::desc diff_weights_desc() const {
return base::diff_weights_desc(0);
}
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
/// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return base::diff_weights_desc(1);
}
};
/// Default constructor. Produces an empty object.
inner_product_backward_weights() = default;
/// Constructs an inner product weights gradient primitive.
/// @param pd Primitive descriptor for an inner product weights gradient
/// primitive.
inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
};
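// A minimal weights-gradient usage sketch (illustrative only). It assumes the
// forward primitive descriptor `ip_pd` and the memory descriptors from the
// forward sketch above, plus matching diff memory objects; all names are
// placeholders. The forward descriptors are reused here as diff descriptors
// since the shapes are the same:
//
//     auto ip_bwd_w_desc = inner_product_backward_weights::desc(
//             src_md, wei_md, bias_md, dst_md);
//     auto ip_bwd_w_pd = inner_product_backward_weights::primitive_desc(
//             ip_bwd_w_desc, eng, ip_pd);
//     inner_product_backward_weights(ip_bwd_w_pd).execute(strm,
//             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
//                     {DNNL_ARG_DIFF_WEIGHTS, diff_wei_mem},
//                     {DNNL_ARG_DIFF_BIAS, diff_bias_mem}});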
/// @} dnnl_api_inner_product
/// @addtogroup dnnl_api_rnn RNN
///
/// A primitive to compute recurrent neural network layers.
///
/// @sa @ref dev_guide_rnn in developer guide
///
/// @{
/// Base class for primitive descriptors for RNN primitives.
struct rnn_primitive_desc_base : public primitive_desc {
using primitive_desc::primitive_desc;
/// Default constructor. Produces an empty object.
rnn_primitive_desc_base() = default;
/// Constructs an RNN primitive descriptor base from a C API primitive
/// descriptor while checking that it actually describes the expected
/// primitive by comparing propagation and primitive kinds.
///
/// @param pd C API primitive descriptor.
/// @param prop_kind Expected propagation kind.
/// @param cell_kind Expected cell kind.
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind prop_kind,
dnnl::algorithm cell_kind)
: rnn_primitive_desc_base(pd, prop_kind, prop_kind, cell_kind) {}
/// Returns source layer memory descriptor.
/// @returns Source layer memory descriptor.
memory::desc src_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
}
/// Returns source iteration memory descriptor.
/// @returns Source iteration memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// source iteration parameter.
memory::desc src_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
}
/// Returns source recurrent cell state memory descriptor.
/// @returns Source recurrent cell state memory descriptor.
memory::desc src_iter_c_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
}
/// Returns weights layer memory descriptor.
/// @returns Weights layer memory descriptor.
memory::desc weights_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
}
/// Returns weights iteration memory descriptor.
/// @returns Weights iteration memory descriptor.
memory::desc weights_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
}
/// Returns weights peephole memory descriptor.
/// @returns Weights peephole memory descriptor.
memory::desc weights_peephole_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
}
/// Returns weights projection memory descriptor.
/// @returns Weights projection memory descriptor.
memory::desc weights_projection_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
}
/// Returns bias memory descriptor.
/// @returns Bias memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// bias parameter.
memory::desc bias_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
}
/// Returns destination layer memory descriptor.
/// @returns Destination layer memory descriptor.
memory::desc dst_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
}
/// Returns destination iteration memory descriptor.
/// @returns Destination iteration memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// destination iteration parameter.
memory::desc dst_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
}
/// Returns destination recurrent cell state memory descriptor.
/// @returns Destination recurrent cell state memory descriptor.
memory::desc dst_iter_c_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
}
/// Returns diff source layer memory descriptor.
/// @returns Diff source layer memory descriptor.
memory::desc diff_src_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
}
/// Returns diff source iteration memory descriptor.
/// @returns Diff source iteration memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff source iteration parameter.
memory::desc diff_src_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
}
/// Returns diff source recurrent cell state memory descriptor.
/// @returns Diff source recurrent cell state memory descriptor.
memory::desc diff_src_iter_c_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
}
/// Returns diff weights layer memory descriptor.
/// @returns Diff weights layer memory descriptor.
memory::desc diff_weights_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
}
/// Returns diff weights iteration memory descriptor.
/// @returns Diff weights iteration memory descriptor.
memory::desc diff_weights_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
}
/// Returns diff weights peephole memory descriptor.
/// @returns Diff weights peephole memory descriptor.
memory::desc diff_weights_peephole_desc() const {
return base::query_md(
query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
}
/// Returns diff weights projection memory descriptor.
/// @returns Diff weights projection memory descriptor.
memory::desc diff_weights_projection_desc() const {
return base::query_md(
query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
}
/// Returns diff bias memory descriptor.
/// @returns Diff bias memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff bias parameter.
memory::desc diff_bias_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
}
/// Returns diff destination layer memory descriptor.
/// @returns Diff destination layer memory descriptor.
memory::desc diff_dst_layer_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
}
/// Returns diff destination iteration memory descriptor.
/// @returns Diff destination iteration memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff destination iteration parameter.
memory::desc diff_dst_iter_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
}
/// Returns diff destination recurrent cell state memory descriptor.
/// @returns Diff destination recurrent cell state memory descriptor.
memory::desc diff_dst_iter_c_desc() const {
return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
}
protected:
using rnn_base = rnn_primitive_desc_base;
// (Deliberately not using doxygen comments)
//
// Constructs an RNN primitive descriptor base from a C API primitive
// descriptor while checking that it actually describes the expected
// primitive by comparing propagation and primitive kinds. The caller can
// pass two propagation kind options. This is typically used to check
// that the propagation kind is forward inference or forward training.
//
// @param pd C API primitive descriptor.
// @param prop_kind1 Expected propagation kind.
// @param prop_kind2 Expected propagation kind.
// @param cell_kind Expected cell kind.
rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
dnnl::algorithm cell_kind) {
dnnl_rnn_desc_t *rnn_d;
dnnl_status_t rc;
rc = dnnl_primitive_desc_query(pd, dnnl_query_rnn_d, 0, &rnn_d);
error::wrap_c_api(rc,
"could not retrieve a descriptor from a primitive descriptor "
"for an RNN primitive");
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
bool ok = rnn_d->primitive_kind == dnnl_rnn
&& (rnn_d->prop_kind == c_prop_kind1
|| rnn_d->prop_kind == c_prop_kind2)
&& rnn_d->cell_kind == c_cell_kind;
if (!ok)
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"mismatch between expected and provided descriptors for an "
"RNN primitive");
reset_with_clone(pd);
}
};
/// Vanilla RNN forward propagation primitive.
struct vanilla_rnn_forward : public primitive {
/// Descriptor for a vanilla RNN forward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for a vanilla RNN forward propagation
/// primitive.
///
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter_desc may point
/// to a zero memory descriptor. This would then indicate that the RNN
/// forward propagation primitive should not use them and should
/// default to zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
///
/// Outputs:
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @note
/// All memory descriptors except @p src_iter_desc can be
/// initialized with an #dnnl::memory::format_tag::any value of @p
/// format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param activation Activation kind. Possible values are
/// #dnnl::algorithm::eltwise_relu,
/// #dnnl::algorithm::eltwise_tanh, or
/// #dnnl::algorithm::eltwise_logistic.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param flags Unused.
/// @param alpha Negative slope if activation is
/// #dnnl::algorithm::eltwise_relu.
/// @param beta Unused.
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
float beta = 0.0f) {
error::wrap_c_api(
dnnl_vanilla_rnn_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(activation),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
dnnl::convert_to_c(flags), alpha, beta),
"could not create a descriptor for a vanilla RNN forward "
"propagation primitive");
}
};
/// Primitive descriptor for a vanilla RNN forward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a vanilla RNN forward
/// propagation primitive.
///
/// @param desc Descriptor for a vanilla RNN forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a vanilla RNN forward
/// propagation primitive.
///
/// @param desc Descriptor for a vanilla RNN forward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a vanilla RNN forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a vanilla RNN forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference,
dnnl::algorithm::vanilla_rnn) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
};
/// Default constructor. Produces an empty object.
vanilla_rnn_forward() = default;
/// Constructs a vanilla RNN forward propagation primitive.
/// @param pd Primitive descriptor for a vanilla RNN forward
/// propagation primitive.
vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
};
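// A minimal usage sketch for the vanilla RNN forward primitive (illustrative
// only, not part of the library). It assumes `using namespace dnnl;`, an
// engine `eng`, a stream `strm`, and previously created memory descriptors
// and memory objects for every tensor named below:
//
//     vanilla_rnn_forward::desc d(prop_kind::forward_inference,
//             algorithm::eltwise_tanh,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, weights_layer_md, weights_iter_md, bias_md,
//             dst_layer_md, dst_iter_md);
//     vanilla_rnn_forward::primitive_desc pd(d, eng);
//     vanilla_rnn_forward rnn(pd);
//     rnn.execute(strm,
//             {{DNNL_ARG_SRC_LAYER, src_layer_mem},
//                     {DNNL_ARG_SRC_ITER, src_iter_mem},
//                     {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
//                     {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
//                     {DNNL_ARG_BIAS, bias_mem},
//                     {DNNL_ARG_DST_LAYER, dst_layer_mem},
//                     {DNNL_ARG_DST_ITER, dst_iter_mem}});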
/// Vanilla RNN backward propagation primitive.
struct vanilla_rnn_backward : public primitive {
/// Descriptor for a vanilla RNN backward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for a vanilla RNN backward propagation
/// primitive.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// bias_desc together with @p diff_bias_desc, and @p dst_iter_desc
/// together with @p diff_dst_iter_desc, may point to a zero memory
/// descriptor. This would then indicate that the RNN backward
/// propagation primitive should not use the respective data and
/// should use zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `diff_dst_iter`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
///
/// Outputs:
/// - `diff_src_layer`
/// (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_src_iter`
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
/// - `diff_weights_layer`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_weights_iter`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
/// - `diff_bias`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
///
/// @note
/// All the memory descriptors may be initialized with the
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param activation Activation kind. Possible values are
/// #dnnl::algorithm::eltwise_relu,
/// #dnnl::algorithm::eltwise_tanh, or
/// #dnnl::algorithm::eltwise_logistic.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param flags Unused.
/// @param alpha Negative slope if activation is
/// #dnnl::algorithm::eltwise_relu.
/// @param beta Unused.
desc(prop_kind prop_kind, algorithm activation, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
float beta = 0.0f) {
error::wrap_c_api(
dnnl_vanilla_rnn_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(activation),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data, &diff_bias_desc.data,
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
dnnl::convert_to_c(flags), alpha, beta),
"could not create a descriptor for a vanilla RNN backward "
"propagation primitive");
}
};
/// Primitive descriptor for a vanilla RNN backward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a vanilla RNN backward
/// propagation primitive.
///
/// @param desc Descriptor for a vanilla RNN backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a vanilla RNN backward
/// propagation primitive.
///
/// @param desc Descriptor for a vanilla RNN backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, &attr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a vanilla RNN backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a vanilla RNN backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
dnnl::algorithm::vanilla_rnn) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
memory::desc diff_src_layer_desc() const {
return rnn_base::diff_src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
memory::desc diff_src_iter_desc() const {
return rnn_base::diff_src_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
memory::desc diff_weights_layer_desc() const {
return rnn_base::diff_weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
memory::desc diff_weights_iter_desc() const {
return rnn_base::diff_weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return rnn_base::diff_bias_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
memory::desc diff_dst_layer_desc() const {
return rnn_base::diff_dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
memory::desc diff_dst_iter_desc() const {
return rnn_base::diff_dst_iter_desc();
}
};
/// Default constructor. Produces an empty object.
vanilla_rnn_backward() = default;
/// Constructs a vanilla RNN backward propagation primitive.
/// @param pd Primitive descriptor for a vanilla RNN backward
/// propagation primitive.
vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
};
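// A hedged sketch of the backward-pass setup (illustrative only): the
// backward primitive descriptor is created with the forward primitive
// descriptor as a memory-format hint, and the workspace produced by the
// forward training pass is supplied at execution time. The descriptor
// `bwd_d`, the forward primitive descriptor `fwd_pd`, the engine `eng`, and
// all memory objects are assumed to exist.
//
//     vanilla_rnn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
//     vanilla_rnn_backward rnn_bwd(bwd_pd);
//     // Execution arguments combine the forward tensors, their
//     // DNNL_ARG_DIFF_* counterparts, and {DNNL_ARG_WORKSPACE,
//     // workspace_mem} obtained from the forward pass.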
/// LSTM forward propagation primitive.
struct lstm_forward : public primitive {
/// Descriptor for an LSTM forward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for an LSTM (with or without peephole and
/// with or without projection) forward propagation primitive.
///
/// The @p src_iter_desc, @p src_iter_c_desc, @p weights_peephole_desc,
/// @p bias_desc, @p dst_iter_desc, and @p dst_iter_c_desc may point to
/// a zero memory descriptor. This would then indicate that the LSTM
/// forward propagation primitive should not use them and should
/// default to zero values instead.
///
/// The @p weights_projection_desc may point to a zero memory
/// descriptor. This would then indicate that the LSTM does not have a
/// recurrent projection layer.
///
/// @note
/// All memory descriptors can be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// Inputs:
/// - src_layer (#dnnl::primitive_desc_base::src_desc (0))
/// - src_iter (#dnnl::primitive_desc_base::src_desc (1)), if used
/// - src_iter_c (#dnnl::primitive_desc_base::src_desc (2)), if used
/// - weights_layer (#dnnl::primitive_desc_base::weights_desc (0))
/// - weights_iter (#dnnl::primitive_desc_base::weights_desc (1))
/// - weights_peephole (#dnnl::primitive_desc_base::weights_desc (2)),
/// if used
/// - weights_projection
/// (#dnnl::primitive_desc_base::weights_desc (index)), if used and
/// index is:
/// - 2, if there is no weights_peephole
/// - 3, otherwise
/// - bias (#dnnl::primitive_desc_base::weights_desc (index)), if used
/// and index is:
/// - 2, if neither weights_peephole nor weights_projection is used
/// - 3, if one of weights_peephole or weights_projection is used
/// - 4, if both weights_peephole and weights_projection are used
///
/// Outputs:
/// - dst_layer (#dnnl::primitive_desc_base::dst_desc (0))
/// - dst_iter (#dnnl::primitive_desc_base::dst_desc (1)), if used
/// - dst_iter_c (#dnnl::primitive_desc_base::dst_desc (2)), if used
/// - workspace (#dnnl::primitive_desc_base::workspace_desc (0)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights
/// applied to the cell states (according to the Peephole LSTM
/// formula).
/// @param weights_projection_desc Memory descriptor for the weights
/// applied to the hidden states to get the recurrent projection
/// (according to the Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &weights_peephole_desc,
const memory::desc &weights_projection_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_forward_desc_init_v3(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&weights_peephole_desc.data,
&weights_projection_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM forward "
"propagation primitive");
}
/// Constructs a descriptor for an LSTM (with or without peephole)
/// forward propagation primitive.
///
/// The @p src_iter_desc, @p src_iter_c_desc, @p weights_peephole_desc,
/// @p bias_desc, @p dst_iter_desc, and @p dst_iter_c_desc may point to
/// a zero memory descriptor. This would then indicate that the LSTM
/// forward propagation primitive should not use them and should
/// default to zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `weights_peephole` (#dnnl::primitive_desc_base::weights_desc(`2`)),
/// if used
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used and
/// LSTM is without peephole
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`3`)), if used and
/// LSTM is with peephole
///
/// Outputs:
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @note
/// All memory descriptors can be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights
/// applied to the cell states (according to the Peephole LSTM
/// formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &weights_peephole_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_forward_desc_init_v2(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&weights_peephole_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM forward "
"propagation primitive");
}
/// Constructs a descriptor for an LSTM forward propagation primitive.
///
/// The @p src_iter_desc, @p src_iter_c_desc, @p bias_desc, @p
/// dst_iter_desc, and @p dst_iter_c_desc may point to a zero memory
/// descriptor. This would then indicate that the LSTM forward
/// propagation primitive should not use them and should default to
/// zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
///
/// Outputs:
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a
/// corresponding primitive descriptor is created
///
/// @note
/// All memory descriptors can be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&bias_desc.data, &dst_layer_desc.data,
&dst_iter_desc.data, &dst_iter_c_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM forward "
"propagation primitive");
}
};
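// The three constructors above differ only in which optional weights they
// accept (peephole and projection, peephole only, or neither). As the
// documentation notes, an optional tensor is disabled by passing a zero
// memory descriptor for it; a hedged sketch (illustrative only, names
// assumed) of a plain LSTM expressed through the most general constructor:
//
//     memory::desc empty_md; // default-constructed, i.e. zero descriptor
//     lstm_forward::desc d(prop_kind::forward_training,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, src_iter_c_md, weights_layer_md, weights_iter_md,
//             empty_md /* no peephole */, empty_md /* no projection */,
//             bias_md, dst_layer_md, dst_iter_md, dst_iter_c_md);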
/// Primitive descriptor for an LSTM forward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LSTM forward propagation
/// primitive.
///
/// @param desc Descriptor for an LSTM forward propagation primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LSTM forward propagation
/// primitive.
///
/// @param desc Descriptor for an LSTM forward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LSTM forward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LSTM forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference,
dnnl::algorithm::vanilla_lstm) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
memory::desc src_iter_c_desc() const {
return rnn_base::src_iter_c_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
memory::desc weights_peephole_desc() const {
return rnn_base::weights_peephole_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
memory::desc weights_projection_desc() const {
return rnn_base::weights_projection_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
memory::desc dst_iter_c_desc() const {
return rnn_base::dst_iter_c_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
};
/// Default constructor. Produces an empty object.
lstm_forward() = default;
/// Constructs an LSTM forward propagation primitive.
/// @param pd Primitive descriptor for an LSTM forward propagation
/// primitive.
lstm_forward(const primitive_desc &pd) : primitive(pd) {}
};
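// A minimal LSTM forward training sketch (illustrative only; assumes
// `using namespace dnnl;`, a descriptor `d` built as above, an engine `eng`,
// a stream `strm`, and previously created memory objects). In training mode
// the workspace descriptor must be queried from the primitive descriptor and
// the corresponding memory allocated by the user:
//
//     lstm_forward::primitive_desc pd(d, eng);
//     memory workspace_mem(pd.workspace_desc(), eng);
//     lstm_forward lstm(pd);
//     lstm.execute(strm,
//             {{DNNL_ARG_SRC_LAYER, src_layer_mem},
//                     {DNNL_ARG_SRC_ITER, src_iter_mem},
//                     {DNNL_ARG_SRC_ITER_C, src_iter_c_mem},
//                     {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
//                     {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
//                     {DNNL_ARG_BIAS, bias_mem},
//                     {DNNL_ARG_DST_LAYER, dst_layer_mem},
//                     {DNNL_ARG_DST_ITER, dst_iter_mem},
//                     {DNNL_ARG_DST_ITER_C, dst_iter_c_mem},
//                     {DNNL_ARG_WORKSPACE, workspace_mem}});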
/// LSTM backward propagation primitive.
struct lstm_backward : public primitive {
/// Descriptor for an LSTM backward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs an LSTM (with or without peephole and with or without
/// projection) descriptor for backward propagation using @p prop_kind,
/// @p direction, and memory descriptors.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// src_iter_c_desc together with @p diff_src_iter_c_desc, @p
/// weights_peephole_desc together with @p diff_weights_peephole_desc,
/// @p bias_desc together with @p diff_bias_desc, @p dst_iter_desc
/// together with @p diff_dst_iter_desc, and @p dst_iter_c_desc
/// together with @p diff_dst_iter_c_desc, may point to a zero memory
/// descriptor. This would then indicate that the LSTM backward
/// propagation primitive should not use them and should default to
/// zero values instead.
///
/// The @p weights_projection_desc together with @p
/// diff_weights_projection_desc may point to a zero memory descriptor.
/// This would then indicate that the LSTM does not have a recurrent
/// projection layer.
///
/// @note
///     All memory descriptors can be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// Inputs:
/// - src_layer (#dnnl::primitive_desc_base::src_desc (0))
/// - src_iter (#dnnl::primitive_desc_base::src_desc (1)), if used
/// - src_iter_c (#dnnl::primitive_desc_base::src_desc (2)), if used
/// - weights_layer (#dnnl::primitive_desc_base::weights_desc (0))
/// - weights_iter (#dnnl::primitive_desc_base::weights_desc (1))
/// - weights_peephole (#dnnl::primitive_desc_base::weights_desc (2)),
/// if used
/// - weights_projection
/// (#dnnl::primitive_desc_base::weights_desc (index)), if used and
/// index is:
/// - 2, if there is no weights_peephole
/// - 3, otherwise
/// - bias (#dnnl::primitive_desc_base::weights_desc (index)), if used
/// and index is:
/// - 2, if neither weights_peephole nor weights_projection is used
/// - 3, if one of weights_peephole or weights_projection is used
/// - 4, if both weights_peephole and weights_projection are used
/// - dst_layer (#dnnl::primitive_desc_base::dst_desc (0))
/// - dst_iter (#dnnl::primitive_desc_base::dst_desc (1)), if used
/// - dst_iter_c (#dnnl::primitive_desc_base::dst_desc (2)), if used
/// - diff_dst_layer (#dnnl::primitive_desc_base::diff_dst_desc (0))
/// - diff_dst_iter
/// (#dnnl::primitive_desc_base::diff_dst_desc (1)), if used
/// - diff_dst_iter_c
/// (#dnnl::primitive_desc_base::diff_dst_desc (2)), if used
/// - workspace (#dnnl::primitive_desc_base::workspace_desc (0))
///
/// Outputs:
/// - diff_src_layer (#dnnl::primitive_desc_base::diff_src_desc (0))
/// - diff_src_iter
/// (#dnnl::primitive_desc_base::diff_src_desc (1)), if used
/// - diff_src_iter_c
/// (#dnnl::primitive_desc_base::diff_src_desc (2)), if used
/// - diff_weights_layer
/// (#dnnl::primitive_desc_base::diff_weights_desc (0))
/// - diff_weights_iter
/// (#dnnl::primitive_desc_base::diff_weights_desc (1))
/// - diff_weights_peephole
/// (#dnnl::primitive_desc_base::diff_weights_desc (2)), if used
/// - diff_weights_projection
/// (#dnnl::primitive_desc_base::diff_weights_desc (index)), if used
/// and index is:
/// - 2, if there is no diff_weights_peephole
/// - 3, otherwise
/// - diff_bias
/// (#dnnl::primitive_desc_base::diff_weights_desc (index)), if used
/// and index is:
/// - 2, if neither diff_weights_peephole nor
/// diff_weights_projection is used
/// - 3, if one of diff_weights_peephole or diff_weights_projection
/// is used
/// - 4, if both diff_weights_peephole and diff_weights_projection
/// are used
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights
/// applied to the cell states (according to the Peephole LSTM
/// formula).
/// @param weights_projection_desc Memory descriptor for the weights
/// applied to the hidden states to get the recurrent projection
/// (according to the Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
/// input recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of
/// weights applied to the cell states (according to the Peephole
/// LSTM formula).
/// @param diff_weights_projection_desc Memory descriptor for the diff
/// of weights applied to the hidden states to get the recurrent
/// projection (according to the Projection LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
/// output recurrent cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &weights_peephole_desc,
const memory::desc &weights_projection_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_src_iter_c_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_weights_peephole_desc,
const memory::desc &diff_weights_projection_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
const memory::desc &diff_dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_backward_desc_init_v3(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&weights_peephole_desc.data,
&weights_projection_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&dst_iter_c_desc.data, &diff_src_layer_desc.data,
&diff_src_iter_desc.data,
&diff_src_iter_c_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data,
&diff_weights_peephole_desc.data,
&diff_weights_projection_desc.data,
&diff_bias_desc.data, &diff_dst_layer_desc.data,
&diff_dst_iter_desc.data,
&diff_dst_iter_c_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM backward "
"propagation primitive");
}
/// Constructs an LSTM (with or without peephole) descriptor for
/// backward propagation using @p prop_kind, @p direction, and memory
/// descriptors.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// src_iter_c_desc together with @p diff_src_iter_c_desc, @p
/// weights_peephole_desc together with @p diff_weights_peephole_desc,
/// @p bias_desc together with @p diff_bias_desc, @p dst_iter_desc
/// together with @p diff_dst_iter_desc, and @p dst_iter_c_desc
/// together with @p diff_dst_iter_c_desc, may point to a zero memory
/// descriptor. This would then indicate that the LSTM backward
/// propagation primitive should not use them and should default to
/// zero values instead.
///
/// @note
///     All memory descriptors may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `weights_peephole` (#dnnl::primitive_desc_base::weights_desc(`2`)),
/// if used
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used and
/// LSTM is without peephole
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`3`)), if used and
/// LSTM is with peephole
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `diff_dst_iter`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
/// - `diff_dst_iter_c`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`2`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
///
/// Outputs:
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_src_iter`
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
/// - `diff_src_iter_c`
/// (#dnnl::primitive_desc_base::diff_src_desc(`2`)), if used
/// - `diff_weights_layer`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_weights_iter`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
/// - `diff_weights_peephole`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`2`)),
/// if used and LSTM is without peephole
/// - `diff_bias` (#dnnl::primitive_desc_base::diff_weights_desc(`3`)),
/// if used and LSTM is with peephole
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights
/// applied to the cell states (according to the Peephole LSTM
/// formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
/// input recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of
/// weights applied to the cell states (according to the Peephole
/// LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
/// output recurrent cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &weights_peephole_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_src_iter_c_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_weights_peephole_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
const memory::desc &diff_dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_backward_desc_init_v2(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&weights_peephole_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&dst_iter_c_desc.data, &diff_src_layer_desc.data,
&diff_src_iter_desc.data,
&diff_src_iter_c_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data,
&diff_weights_peephole_desc.data,
&diff_bias_desc.data, &diff_dst_layer_desc.data,
&diff_dst_iter_desc.data,
&diff_dst_iter_c_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM backward "
"propagation primitive");
}
/// Constructs an LSTM descriptor for backward propagation using @p
/// prop_kind, @p direction, and memory descriptors.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// src_iter_c_desc together with @p diff_src_iter_c_desc, @p bias_desc
/// together with @p diff_bias_desc, @p dst_iter_desc together with @p
/// diff_dst_iter_desc, and @p dst_iter_c_desc together with @p
/// diff_dst_iter_c_desc, may point to a zero memory descriptor. This
/// would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `src_iter_c` (#dnnl::primitive_desc_base::src_desc(`2`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `dst_iter_c` (#dnnl::primitive_desc_base::dst_desc(`2`)), if used
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `diff_dst_iter`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
/// - `diff_dst_iter_c`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`2`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
///
/// Outputs:
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_src_iter`
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
/// - `diff_src_iter_c`
/// (#dnnl::primitive_desc_base::diff_src_desc(`2`)), if used
/// - `diff_weights_layer`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_weights_iter`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
/// - `diff_bias`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent
/// cell state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
/// cell state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of
/// input recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of
/// output recurrent cell state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &src_iter_c_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &dst_iter_c_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_src_iter_c_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
const memory::desc &diff_dst_iter_c_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lstm_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &src_iter_c_desc.data,
&weights_layer_desc.data, &weights_iter_desc.data,
&bias_desc.data, &dst_layer_desc.data,
&dst_iter_desc.data, &dst_iter_c_desc.data,
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
&diff_src_iter_c_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data, &diff_bias_desc.data,
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
&diff_dst_iter_c_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LSTM backward "
"propagation primitive");
}
};
/// Primitive descriptor for an LSTM backward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LSTM backward propagation
/// primitive.
///
/// @param desc Descriptor for LSTM backward propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LSTM
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const lstm_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an LSTM backward propagation
/// primitive.
///
/// @param desc Descriptor for an LSTM backward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LSTM
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const lstm_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, &attr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an LSTM backward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LSTM backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
dnnl::algorithm::vanilla_lstm) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
memory::desc src_iter_c_desc() const {
return rnn_base::src_iter_c_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
memory::desc weights_peephole_desc() const {
return rnn_base::weights_peephole_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
memory::desc weights_projection_desc() const {
return rnn_base::weights_projection_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
memory::desc dst_iter_c_desc() const {
return rnn_base::dst_iter_c_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
memory::desc diff_src_layer_desc() const {
return rnn_base::diff_src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
memory::desc diff_src_iter_desc() const {
return rnn_base::diff_src_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
memory::desc diff_src_iter_c_desc() const {
return rnn_base::diff_src_iter_c_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
memory::desc diff_weights_layer_desc() const {
return rnn_base::diff_weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
memory::desc diff_weights_iter_desc() const {
return rnn_base::diff_weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
memory::desc diff_weights_peephole_desc() const {
return rnn_base::diff_weights_peephole_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
memory::desc diff_weights_projection_desc() const {
return rnn_base::diff_weights_projection_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return rnn_base::diff_bias_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
memory::desc diff_dst_layer_desc() const {
return rnn_base::diff_dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
memory::desc diff_dst_iter_desc() const {
return rnn_base::diff_dst_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
memory::desc diff_dst_iter_c_desc() const {
return rnn_base::diff_dst_iter_c_desc();
}
};
/// Default constructor. Produces an empty object.
lstm_backward() = default;
/// Constructs an LSTM backward propagation primitive.
/// @param pd Primitive descriptor for an LSTM backward propagation
/// primitive.
lstm_backward(const primitive_desc &pd) : primitive(pd) {}
};
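// A hedged sketch of the LSTM backward setup (illustrative only, names
// assumed): the backward primitive descriptor takes the forward primitive
// descriptor `fwd_pd` as a memory-format hint, and execution arguments
// combine the forward tensors, their DNNL_ARG_DIFF_* counterparts (for
// example {DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_mem} and
// {DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_mem}), and the forward-pass
// workspace via {DNNL_ARG_WORKSPACE, workspace_mem}.
//
//     lstm_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
//     lstm_backward lstm_bwd(bwd_pd);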
/// GRU forward propagation primitive.
struct gru_forward : public primitive {
/// Descriptor for a GRU forward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for a GRU forward propagation primitive.
///
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter_desc may point to
/// a zero memory descriptor. This would then indicate that the GRU
/// forward propagation primitive should not use them and should
/// default to zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
///
/// Outputs:
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @note
/// All memory descriptors except @p src_iter_desc may be
/// initialized with an #dnnl::memory::format_tag::any value of @p
/// format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_gru_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for a GRU forward "
"propagation primitive");
}
};
/// Primitive descriptor for a GRU forward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a GRU forward propagation
/// primitive.
///
/// @param desc Descriptor for a GRU forward propagation primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a GRU forward propagation
/// primitive.
///
/// @param desc Descriptor for a GRU forward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a GRU forward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for a GRU forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference,
dnnl::algorithm::vanilla_gru) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
};
/// Default constructor. Produces an empty object.
gru_forward() = default;
/// Constructs a GRU forward propagation primitive.
/// @param pd Primitive descriptor for a GRU forward propagation
/// primitive.
gru_forward(const primitive_desc &pd) : primitive(pd) {}
};
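// A minimal usage sketch for the GRU forward primitive, kept here as a
// comment. The tensor sizes, the tnc/ldigo/ldgo format tags, the
// unidirectional_left2right direction, and the DNNL_ARG_* argument indices
// are assumptions based on general oneDNN RNN conventions rather than
// definitions from this file; the dnnl::memory objects (src_layer_mem, ...)
// are assumed to be created and filled elsewhere.
//
//     using dt = dnnl::memory::data_type;
//     using tag = dnnl::memory::format_tag;
//     dnnl::engine eng(dnnl::engine::kind::cpu, 0);
//     dnnl::stream strm(eng);
//     // T = time steps, N = batch, C = channels, L = layers, D = directions,
//     // G = 3 gates for a GRU cell.
//     const dnnl::memory::dim T = 4, N = 2, C = 8, L = 1, D = 1, G = 3;
//     dnnl::memory::desc src_layer({T, N, C}, dt::f32, tag::tnc);
//     dnnl::memory::desc weights_layer({L, D, C, G, C}, dt::f32, tag::ldigo);
//     dnnl::memory::desc weights_iter({L, D, C, G, C}, dt::f32, tag::ldigo);
//     dnnl::memory::desc bias({L, D, G, C}, dt::f32, tag::ldgo);
//     dnnl::memory::desc dst_layer({T, N, C}, dt::f32, tag::tnc);
//     // Zero (default-constructed) descriptors mark src_iter and dst_iter
//     // as unused in this sketch.
//     dnnl::gru_forward::desc d(dnnl::prop_kind::forward_inference,
//             dnnl::rnn_direction::unidirectional_left2right, src_layer,
//             dnnl::memory::desc(), weights_layer, weights_iter, bias,
//             dst_layer, dnnl::memory::desc());
//     dnnl::gru_forward::primitive_desc pd(d, eng);
//     dnnl::gru_forward(pd).execute(strm,
//             {{DNNL_ARG_SRC_LAYER, src_layer_mem},
//              {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
//              {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
//              {DNNL_ARG_BIAS, bias_mem},
//              {DNNL_ARG_DST_LAYER, dst_layer_mem}});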
/// GRU backward propagation primitive.
struct gru_backward : public primitive {
/// Descriptor for a GRU backward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for a GRU backward propagation primitive.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// bias_desc together with @p diff_bias_desc, and @p dst_iter_desc
/// together with @p diff_dst_iter_desc may point to a zero memory
/// descriptor. This indicates that the GRU backward propagation
/// primitive should not use them and should default to zero values
/// instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `diff_dst_iter`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
///
/// Outputs:
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_src_iter`
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
/// - `diff_weights_layer`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_weights_iter`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
/// - `diff_bias`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
///
/// @note
/// All memory descriptors may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_gru_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data, &diff_bias_desc.data,
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for a GRU backward "
"propagation primitive");
}
};
/// Primitive descriptor for a GRU backward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a GRU backward propagation
/// primitive.
///
/// @param desc Descriptor for a GRU backward propagation primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a GRU
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const gru_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a GRU backward propagation
/// primitive.
///
/// @param desc Descriptor for a GRU backward propagation primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a GRU
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const gru_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, &attr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a GRU backward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for a GRU backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
dnnl::algorithm::vanilla_gru) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
memory::desc diff_src_layer_desc() const {
return rnn_base::diff_src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
memory::desc diff_src_iter_desc() const {
return rnn_base::diff_src_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
memory::desc diff_weights_layer_desc() const {
return rnn_base::diff_weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
memory::desc diff_weights_iter_desc() const {
return rnn_base::diff_weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return rnn_base::diff_bias_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
memory::desc diff_dst_layer_desc() const {
return rnn_base::diff_dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
memory::desc diff_dst_iter_desc() const {
return rnn_base::diff_dst_iter_desc();
}
};
/// Default constructor. Produces an empty object.
gru_backward() = default;
/// Constructs a GRU backward propagation primitive.
/// @param pd Primitive descriptor for a GRU backward propagation
/// primitive.
gru_backward(const primitive_desc &pd) : primitive(pd) {}
};
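// The backward primitive descriptor is created with the corresponding forward
// primitive descriptor as a hint so that compatible memory formats are chosen.
// A sketch of the pattern (illustrative only: fwd_pd is assumed to have been
// created with dnnl::prop_kind::forward_training, and the memory descriptors
// below are assumed to mirror the ones used for the forward pass):
//
//     dnnl::gru_backward::desc bwd_d(dnnl::prop_kind::backward, direction,
//             src_layer, src_iter, weights_layer, weights_iter, bias,
//             dst_layer, dst_iter, diff_src_layer, diff_src_iter,
//             diff_weights_layer, diff_weights_iter, diff_bias,
//             diff_dst_layer, diff_dst_iter);
//     dnnl::gru_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
//     dnnl::gru_backward gru_bwd(bwd_pd);
//     // When executing, the workspace produced by the forward training pass
//     // is passed along with the other arguments via DNNL_ARG_WORKSPACE.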
/// LBR GRU forward propagation primitive.
struct lbr_gru_forward : public primitive {
/// Descriptor for an LBR GRU forward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for an LBR GRU forward propagation primitive.
///
/// The @p src_iter_desc, @p bias_desc, and @p dst_iter_desc may point to
/// a zero memory descriptor. This indicates that the LBR
/// GRU forward propagation primitive should not use them and should
/// default to zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
///
/// Outputs:
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`)),
/// if @p prop_kind equals #dnnl::prop_kind::forward_training;
/// must be queried for using @ref
/// dnnl::primitive_desc_base::query_md() after a corresponding
/// primitive descriptor is created
///
/// @note
/// All memory descriptors except @p src_iter_desc may be
/// initialized with an #dnnl::memory::format_tag::any value of @p
/// format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lbr_gru_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LBR GRU forward "
"propagation primitive");
}
};
/// Primitive descriptor for an LBR GRU forward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LBR GRU forward
/// propagation primitive.
///
/// @param desc Descriptor for an LBR GRU forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LBR GRU forward
/// propagation primitive.
///
/// @param desc Descriptor for an LBR GRU forward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: rnn_primitive_desc_base(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an LBR GRU forward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LBR GRU forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference,
dnnl::algorithm::lbr_gru) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
};
/// Default constructor. Produces an empty object.
lbr_gru_forward() = default;
/// Constructs an LBR GRU forward propagation primitive.
/// @param pd Primitive descriptor for an LBR GRU forward propagation
/// primitive.
lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
};
/// LBR GRU backward propagation primitive.
struct lbr_gru_backward : public primitive {
/// Descriptor for an LBR GRU backward propagation primitive.
struct desc {
dnnl_rnn_desc_t data;
/// Constructs a descriptor for an LBR GRU backward propagation
/// primitive.
///
/// The @p src_iter_desc together with @p diff_src_iter_desc, @p
/// bias_desc together with @p diff_bias_desc, and @p dst_iter_desc
/// together with @p diff_dst_iter_desc may point to a zero memory
/// descriptor. This indicates that the LBR GRU backward
/// propagation primitive should not use them and should default to
/// zero values instead.
///
/// Inputs:
/// - `src_layer` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src_iter` (#dnnl::primitive_desc_base::src_desc(`1`)), if used
/// - `weights_layer` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `weights_iter` (#dnnl::primitive_desc_base::weights_desc(`1`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`2`)), if used
/// - `dst_layer` (#dnnl::primitive_desc_base::dst_desc(`0`))
/// - `dst_iter` (#dnnl::primitive_desc_base::dst_desc(`1`)), if used
/// - `diff_dst_layer` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
/// - `diff_dst_iter`
/// (#dnnl::primitive_desc_base::diff_dst_desc(`1`)), if used
/// - `workspace` (#dnnl::primitive_desc_base::workspace_desc(`0`))
///
/// Outputs:
/// - `diff_src_layer` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
/// - `diff_src_iter`
/// (#dnnl::primitive_desc_base::diff_src_desc(`1`)), if used
/// - `diff_weights_layer`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`0`))
/// - `diff_weights_iter`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`1`))
/// - `diff_bias`
/// (#dnnl::primitive_desc_base::diff_weights_desc(`2`)), if used
///
/// @note
/// All memory descriptors may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Must be
/// #dnnl::prop_kind::backward.
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
/// more info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent
/// hidden state vector.
/// @param weights_layer_desc Memory descriptor for the weights
/// applied to the layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied
/// to the recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent
/// hidden state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input
/// vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input
/// recurrent hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of
/// weights applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of
/// weights applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of
/// output vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
/// recurrent hidden state vector.
/// @param flags Unused.
desc(prop_kind prop_kind, rnn_direction direction,
const memory::desc &src_layer_desc,
const memory::desc &src_iter_desc,
const memory::desc &weights_layer_desc,
const memory::desc &weights_iter_desc,
const memory::desc &bias_desc,
const memory::desc &dst_layer_desc,
const memory::desc &dst_iter_desc,
const memory::desc &diff_src_layer_desc,
const memory::desc &diff_src_iter_desc,
const memory::desc &diff_weights_layer_desc,
const memory::desc &diff_weights_iter_desc,
const memory::desc &diff_bias_desc,
const memory::desc &diff_dst_layer_desc,
const memory::desc &diff_dst_iter_desc,
rnn_flags flags = rnn_flags::undef) {
error::wrap_c_api(
dnnl_lbr_gru_backward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
dnnl::convert_to_c(direction), &src_layer_desc.data,
&src_iter_desc.data, &weights_layer_desc.data,
&weights_iter_desc.data, &bias_desc.data,
&dst_layer_desc.data, &dst_iter_desc.data,
&diff_src_layer_desc.data, &diff_src_iter_desc.data,
&diff_weights_layer_desc.data,
&diff_weights_iter_desc.data, &diff_bias_desc.data,
&diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
dnnl::convert_to_c(flags)),
"could not create a descriptor for an LBR GRU backward "
"propagation primitive");
}
};
/// Primitive descriptor for an LBR GRU backward propagation primitive.
struct primitive_desc : public rnn_primitive_desc_base {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an LBR GRU backward
/// propagation primitive.
///
/// @param desc Descriptor for an LBR GRU backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LBR GRU
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const lbr_gru_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an LBR GRU backward
/// propagation primitive.
///
/// @param desc Descriptor for an LBR GRU backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for an LBR GRU
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const lbr_gru_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: rnn_primitive_desc_base(&desc.data, &attr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for an LBR GRU backward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for an LBR GRU backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: rnn_primitive_desc_base(
pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
memory::desc src_layer_desc() const {
return rnn_base::src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
memory::desc weights_layer_desc() const {
return rnn_base::weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
memory::desc weights_iter_desc() const {
return rnn_base::weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
memory::desc dst_layer_desc() const {
return rnn_base::dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
memory::desc workspace_desc() const {
return rnn_base::workspace_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
memory::desc diff_src_layer_desc() const {
return rnn_base::diff_src_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
memory::desc diff_src_iter_desc() const {
return rnn_base::diff_src_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
memory::desc diff_weights_layer_desc() const {
return rnn_base::diff_weights_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
memory::desc diff_weights_iter_desc() const {
return rnn_base::diff_weights_iter_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
memory::desc diff_bias_desc() const {
return rnn_base::diff_bias_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
memory::desc diff_dst_layer_desc() const {
return rnn_base::diff_dst_layer_desc();
}
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
memory::desc diff_dst_iter_desc() const {
return rnn_base::diff_dst_iter_desc();
}
};
/// Default constructor. Produces an empty object.
lbr_gru_backward() = default;
/// Constructs an LBR GRU backward propagation primitive.
/// @param pd Primitive descriptor for an LBR GRU backward propagation
/// primitive.
lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_rnn
/// @addtogroup dnnl_api_shuffle Shuffle
///
/// A primitive to shuffle tensor data along an axis.
///
/// @sa @ref dev_guide_shuffle in developer guide
///
/// @{
/// Shuffle forward propagation primitive.
struct shuffle_forward : public primitive {
/// Descriptor for a shuffle forward propagation primitive.
struct desc {
dnnl_shuffle_desc_t data;
/// Constructs a descriptor for a shuffle forward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
desc(prop_kind prop_kind, const memory::desc &data_desc, int axis,
int group_size) {
error::wrap_c_api(dnnl_shuffle_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
&data_desc.data, axis, group_size),
"could not create a descriptor for a shuffle forward "
"propagation primitive");
}
};
/// Primitive descriptor for a shuffle forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a shuffle forward
/// propagation primitive.
///
/// @param desc Descriptor for a shuffle forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const primitive_attr &attr = primitive_attr(),
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a shuffle forward propagation
/// primitive from a C API primitive descriptor that must have a
/// matching kind.
///
/// @param pd C API primitive descriptor for a shuffle forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
shuffle_forward() = default;
/// Constructs a shuffle forward propagation primitive.
/// @param pd Primitive descriptor for a shuffle forward propagation
/// primitive.
shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
};
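// A short usage sketch for the shuffle forward primitive. The shape, the nchw
// format tag, and the DNNL_ARG_* indices are illustrative assumptions; eng,
// strm, src_mem, and dst_mem are an engine, a stream, and memory objects
// assumed to be created elsewhere.
//
//     dnnl::memory::desc data_md({2, 4, 8, 8}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::nchw);
//     // Shuffle the 4 channels (axis 1) in groups of 2.
//     dnnl::shuffle_forward::desc d(
//             dnnl::prop_kind::forward_inference, data_md, 1, 2);
//     dnnl::shuffle_forward::primitive_desc pd(d, eng);
//     dnnl::shuffle_forward(pd).execute(
//             strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});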
/// Shuffle backward propagation primitive.
struct shuffle_backward : public primitive {
/// Descriptor for a shuffle backward propagation primitive.
struct desc {
dnnl_shuffle_desc_t data;
/// Constructs a descriptor for a shuffle backward propagation
/// primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param diff_data_desc Diff source and diff destination memory
/// descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
desc(const memory::desc &diff_data_desc, int axis, int group_size) {
error::wrap_c_api(dnnl_shuffle_backward_desc_init(&data,
&diff_data_desc.data, axis, group_size),
"could not create a descriptor for a shuffle backward "
"propagation primitive");
}
};
/// Primitive descriptor for a shuffle backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a shuffle backward
/// propagation primitive.
///
/// @param desc Descriptor for a shuffle backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param hint_fwd_pd Primitive descriptor for a shuffle
/// forward propagation primitive. It is used as a hint for
/// deciding which memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const shuffle_forward::primitive_desc &hint_fwd_pd,
const primitive_attr &attr = primitive_attr(),
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a shuffle backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a shuffle backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
shuffle_backward() = default;
/// Constructs a shuffle backward propagation primitive.
/// @param pd Primitive descriptor for a shuffle backward propagation
/// primitive.
shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_shuffle
/// @addtogroup dnnl_api_binary Binary
///
/// A primitive to perform elementwise binary operations over two tensors.
///
/// @sa @ref dev_guide_binary in developer guide
///
/// @{
/// Elementwise binary operator primitive.
struct binary : public primitive {
/// Descriptor for an elementwise binary operator primitive.
struct desc {
/// Underlying C operation descriptor.
dnnl_binary_desc_t data;
/// Default constructor. Produces an empty object.
desc() = default;
/// Constructs a descriptor for an elementwise binary operator
/// primitive.
///
/// Inputs:
/// - `src0` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `src1` (#dnnl::primitive_desc_base::src_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param algorithm Elementwise algorithm.
/// @param src0 Memory descriptor for source tensor #0.
/// @param src1 Memory descriptor for source tensor #1.
/// @param dst Memory descriptor for destination tensor.
desc(algorithm algorithm, const memory::desc &src0,
const memory::desc &src1, const memory::desc &dst) {
error::wrap_c_api(
dnnl_binary_desc_init(&data, dnnl::convert_to_c(algorithm),
&src0.data, &src1.data, &dst.data),
"could not create a descriptor for a binary operation "
"primitive");
}
};
/// Primitive descriptor for an elementwise binary operator primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for an elementwise binary operator
/// primitive.
///
/// @param desc Descriptor for an elementwise binary operator primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for an elementwise binary operator
/// primitive.
///
/// @param desc Descriptor for an elementwise binary operator primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a binary primitive from a C
/// API primitive descriptor that must have a matching kind.
///
/// @param pd C API primitive descriptor for a binary primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
/// Returns the memory descriptor for source #0.
memory::desc src0_desc() const { return base::src_desc(0); }
/// Returns the memory descriptor for source #1.
memory::desc src1_desc() const { return base::src_desc(1); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
binary() = default;
/// Constructs an elementwise binary operation primitive.
/// @param pd Primitive descriptor for an elementwise binary operation
/// primitive.
binary(const primitive_desc &pd) : primitive(pd) {}
};
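// A sketch of using the binary primitive to add two equally shaped tensors
// elementwise. The shape, the binary_add algorithm choice, and the DNNL_ARG_*
// indices are illustrative assumptions; eng, strm, a_mem, b_mem, and c_mem
// are assumed to exist.
//
//     dnnl::memory::desc md({1, 8, 16, 16}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::nchw);
//     dnnl::binary::desc d(dnnl::algorithm::binary_add, md, md, md);
//     dnnl::binary::primitive_desc pd(d, eng);
//     dnnl::binary(pd).execute(strm,
//             {{DNNL_ARG_SRC_0, a_mem}, {DNNL_ARG_SRC_1, b_mem},
//              {DNNL_ARG_DST, c_mem}});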
/// @} dnnl_api_binary
/// @addtogroup dnnl_api_matmul Matrix Multiplication
///
/// A primitive to perform matrix-matrix multiplication. The batched mode
/// is supported with 3D tensors.
///
/// @sa @ref dev_guide_matmul in developer guide
///
/// @{
/// Matrix multiplication (matmul) primitive.
struct matmul : public primitive {
/// Descriptor for a matmul primitive.
struct desc {
dnnl_matmul_desc_t data;
/// Constructs a descriptor for a matmul primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param src_desc Memory descriptor for source (matrix A).
/// @param weights_desc Memory descriptor for weights (matrix B).
/// @param dst_desc Memory descriptor for destination (matrix C).
desc(const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &dst_desc) {
error::wrap_c_api(
dnnl_matmul_desc_init(&data, &src_desc.data,
&weights_desc.data, nullptr, &dst_desc.data),
"could not create a descriptor for a matmul primitive");
}
/// Constructs a descriptor for a matmul primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
/// - `weights` (#dnnl::primitive_desc_base::weights_desc(`0`))
/// - `bias` (#dnnl::primitive_desc_base::weights_desc(`1`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @param src_desc Memory descriptor for source (matrix A).
/// @param weights_desc Memory descriptor for weights (matrix B).
/// @param dst_desc Memory descriptor for destination (matrix C).
/// @param bias_desc Memory descriptor for bias.
desc(const memory::desc &src_desc, const memory::desc &weights_desc,
const memory::desc &bias_desc, const memory::desc &dst_desc) {
error::wrap_c_api(dnnl_matmul_desc_init(&data, &src_desc.data,
&weights_desc.data, &bias_desc.data,
&dst_desc.data),
"could not create a descriptor for a matmul primitive");
}
};
/// Primitive descriptor for a matmul primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a matmul primitive.
///
/// @param desc Descriptor for a matmul primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a matmul primitive.
///
/// @param desc Descriptor for a matmul primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a matmul primitive from a C
/// API primitive descriptor that must have a matching kind.
///
/// @param pd C API primitive descriptor for a matmul primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return query_md(query::src_md, 0); }
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
memory::desc weights_desc() const {
return query_md(query::weights_md, 0);
}
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
memory::desc bias_desc() const {
return query_md(query::weights_md, 1);
}
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
};
/// Default constructor. Produces an empty object.
matmul() = default;
/// Constructs a matmul primitive.
/// @param pd Primitive descriptor for a matmul primitive.
matmul(const primitive_desc &pd) : primitive(pd) {}
};
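// A sketch of a plain f32 matmul computing C = A x B for 2D row-major
// matrices (format_tag::ab). The sizes and the DNNL_ARG_* indices are
// illustrative assumptions; eng, strm, a_mem, b_mem, and c_mem are assumed
// to exist.
//
//     const dnnl::memory::dim M = 16, K = 32, N = 8;
//     dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::ab);
//     dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::ab);
//     dnnl::memory::desc c_md({M, N}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::ab);
//     dnnl::matmul::desc d(a_md, b_md, c_md);
//     dnnl::matmul::primitive_desc pd(d, eng);
//     dnnl::matmul(pd).execute(strm,
//             {{DNNL_ARG_SRC, a_mem}, {DNNL_ARG_WEIGHTS, b_mem},
//              {DNNL_ARG_DST, c_mem}});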
/// @} dnnl_api_matmul
/// @addtogroup dnnl_api_resampling Resampling
///
/// A primitive to compute a resampling operation on a 1D, 2D, or 3D data
/// tensor using the Nearest Neighbor or Linear (Bilinear, Trilinear)
/// interpolation method.
///
/// @sa @ref dev_guide_resampling in developer guide
///
/// @{
/// Resampling forward propagation primitive.
struct resampling_forward : public primitive {
/// Descriptor for a resampling forward propagation primitive.
struct desc {
dnnl_resampling_desc_t data;
/// Constructs a descriptor for a resampling forward propagation
/// primitive using source and destination memory descriptors.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// The destination memory descriptor may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Resampling algorithm kind: either
/// #dnnl::algorithm::resampling_nearest or
/// #dnnl::algorithm::resampling_linear.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
desc(prop_kind prop_kind, algorithm algorithm,
const memory::desc &src_desc, const memory::desc &dst_desc) {
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), nullptr,
&src_desc.data, &dst_desc.data),
"could not create a resampling forward descriptor");
}
/// Constructs a descriptor for a resampling forward propagation
/// primitive using a source memory descriptor and scaling factors.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Resampling algorithm kind: either
/// #dnnl::algorithm::resampling_nearest or
/// #dnnl::algorithm::resampling_linear.
/// @param factors Vector of scaling factors for the spatial dimensions.
/// @param src_desc Source memory descriptor.
desc(prop_kind prop_kind, algorithm algorithm,
const std::vector<float> &factors,
const memory::desc &src_desc) {
memory::validate_dims(factors, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), &factors[0],
&src_desc.data, nullptr),
"could not create a resampling forward descriptor");
}
/// Constructs a descriptor for a resampling forward propagation
/// primitive.
///
/// Inputs:
/// - `src` (#dnnl::primitive_desc_base::src_desc(`0`))
///
/// Outputs:
/// - `dst` (#dnnl::primitive_desc_base::dst_desc(`0`))
///
/// @note
/// The destination memory descriptor may be initialized with an
/// #dnnl::memory::format_tag::any value of @p format_tag.
///
/// @param prop_kind Propagation kind. Possible values are
/// #dnnl::prop_kind::forward_training, and
/// #dnnl::prop_kind::forward_inference.
/// @param algorithm Resampling algorithm kind: either
/// #dnnl::algorithm::resampling_nearest or
/// #dnnl::algorithm::resampling_linear.
/// @param factors Vector of scaling factors for the spatial dimensions.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
desc(prop_kind prop_kind, algorithm algorithm,
const std::vector<float> &factors, const memory::desc &src_desc,
const memory::desc &dst_desc) {
if (!factors.empty())
memory::validate_dims(factors, src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_resampling_forward_desc_init(&data,
dnnl::convert_to_c(prop_kind),
convert_to_c(algorithm), factors.data(),
&src_desc.data, &dst_desc.data),
"could not create a resampling forward descriptor");
}
};
/// Primitive descriptor for a resampling forward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a resampling forward
/// propagation primitive.
///
/// @param desc Descriptor for a resampling forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, nullptr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a resampling forward
/// propagation primitive.
///
/// @param desc Descriptor for a resampling forward propagation
/// primitive.
/// @param engine Engine to use.
/// @param attr Primitive attributes to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine, bool allow_empty = false)
: dnnl::primitive_desc(
&desc.data, &attr, engine, nullptr, allow_empty) {}
/// Constructs a primitive descriptor for a resampling forward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a resampling forward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
dnnl::prop_kind::forward_training,
dnnl::prop_kind::forward_inference) {}
/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};
/// Default constructor. Produces an empty object.
resampling_forward() = default;
/// Constructs a resampling forward propagation primitive.
/// @param pd Primitive descriptor for a resampling forward propagation
/// primitive.
resampling_forward(const primitive_desc &pd) : primitive(pd) {}
};
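// A sketch of nearest-neighbor 2x upsampling using the factors-based
// constructor; the destination descriptor is derived by the library and can
// be queried from the primitive descriptor. The shape, factors, and
// DNNL_ARG_* indices are illustrative assumptions; eng, strm, and src_mem
// are assumed to exist.
//
//     dnnl::memory::desc src_md({1, 3, 16, 16}, dnnl::memory::data_type::f32,
//             dnnl::memory::format_tag::nchw);
//     dnnl::resampling_forward::desc d(dnnl::prop_kind::forward_inference,
//             dnnl::algorithm::resampling_nearest, {2.0f, 2.0f}, src_md);
//     dnnl::resampling_forward::primitive_desc pd(d, eng);
//     dnnl::memory dst_mem(pd.dst_desc(), eng); // 1x3x32x32 destination
//     dnnl::resampling_forward(pd).execute(
//             strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});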
/// Resampling backward propagation primitive.
struct resampling_backward : public primitive {
/// Descriptor for a resampling backward propagation primitive.
struct desc {
dnnl_resampling_desc_t data;
/// Constructs a descriptor for a resampling backward propagation
/// primitive using source and destination memory descriptors.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param algorithm Resampling algorithm kind: either
/// #dnnl::algorithm::resampling_nearest or
/// #dnnl::algorithm::resampling_linear.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
desc(algorithm algorithm, const memory::desc &diff_src_desc,
const memory::desc &diff_dst_desc) {
error::wrap_c_api(dnnl_resampling_backward_desc_init(&data,
convert_to_c(algorithm), nullptr,
&diff_src_desc.data, &diff_dst_desc.data),
"could not create a resampling backward data descriptor");
}
/// Constructs a descriptor for a resampling backward propagation
/// primitive.
///
/// Inputs:
/// - `diff_dst` (#dnnl::primitive_desc_base::diff_dst_desc(`0`))
///
/// Outputs:
/// - `diff_src` (#dnnl::primitive_desc_base::diff_src_desc(`0`))
///
/// @param algorithm Resampling algorithm kind: either
/// #dnnl::algorithm::resampling_nearest or
/// #dnnl::algorithm::resampling_linear.
/// @param factors Vector of scaling factors for the spatial dimensions.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
desc(algorithm algorithm, const std::vector<float> &factors,
const memory::desc &diff_src_desc,
const memory::desc &diff_dst_desc) {
if (!factors.empty())
memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
error::wrap_c_api(dnnl_resampling_backward_desc_init(&data,
convert_to_c(algorithm), factors.data(),
&diff_src_desc.data, &diff_dst_desc.data),
"could not create a resampling backward data descriptor");
}
};
/// Primitive descriptor for a resampling backward propagation primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;
/// Constructs a primitive descriptor for a resampling backward
/// propagation primitive.
///
/// @param desc Descriptor for a resampling backward propagation
/// primitive.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a resampling forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const engine &engine,
const resampling_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, nullptr, engine,
hint_fwd_pd.get(), allow_empty) {}
/// Constructs a primitive descriptor for a resampling backward
/// propagation primitive.
///
/// @param desc Descriptor for a resampling backward propagation
/// primitive.
/// @param attr Primitive attributes to use.
/// @param engine Engine to use.
/// @param hint_fwd_pd Primitive descriptor for a resampling forward
/// propagation primitive. It is used as a hint for deciding which
/// memory format to use.
/// @param allow_empty A flag signifying whether construction is
/// allowed to fail without throwing an exception. In this case an
/// empty object will be produced. This flag is optional and
/// defaults to false.
primitive_desc(const desc &desc, const primitive_attr &attr,
const engine &engine,
const resampling_forward::primitive_desc &hint_fwd_pd,
bool allow_empty = false)
: dnnl::primitive_desc(&desc.data, &attr, engine, hint_fwd_pd.get(),
allow_empty) {}
/// Constructs a primitive descriptor for a resampling backward
/// propagation primitive from a C API primitive descriptor that must
/// have a matching kind.
///
/// @param pd C API primitive descriptor for a resampling backward
/// propagation primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
dnnl::prop_kind::backward_data) {}
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
};
/// Default constructor. Produces an empty object.
resampling_backward() = default;
/// Constructs a resampling backward propagation primitive.
/// @param pd Primitive descriptor for a resampling backward propagation
/// primitive.
resampling_backward(const primitive_desc &pd) : primitive(pd) {}
};
/// @} dnnl_api_resampling
/// @} dnnl_api_primitives
/// @addtogroup dnnl_api_service Service
///
/// A set of functions that aid in oneDNN debugging and profiling.
///
/// @{
/// @copydoc dnnl_version_t
using version_t = dnnl_version_t;
/// Status values returned by the library functions.
enum class status {
/// @copydoc dnnl_success
success = dnnl_success,
/// @copydoc dnnl_out_of_memory
out_of_memory = dnnl_out_of_memory,
/// @copydoc dnnl_invalid_arguments
invalid_arguments = dnnl_invalid_arguments,
/// @copydoc dnnl_unimplemented
unimplemented = dnnl_unimplemented,
/// @copydoc dnnl_iterator_ends
iterator_ends = dnnl_iterator_ends,
/// @copydoc dnnl_runtime_error
runtime_error = dnnl_runtime_error,
/// @copydoc dnnl_not_required
not_required = dnnl_not_required,
};
/// @copydoc dnnl_set_verbose()
inline status set_verbose(int level) {
return static_cast<status>(dnnl_set_verbose(level));
}
/// @copydoc dnnl_version()
inline const version_t *version() {
return dnnl_version();
}
/// @copydoc dnnl_set_jit_dump()
inline status set_jit_dump(int enable) {
return static_cast<status>(dnnl_set_jit_dump(enable));
}
/// @copydoc dnnl_set_jit_profiling_flags()
inline status set_jit_profiling_flags(unsigned flags) {
return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
}
/// @copydoc dnnl_set_jit_profiling_jitdumpdir()
inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
}
/// @copydoc dnnl_cpu_isa_t
enum class cpu_isa {
/// @copydoc dnnl_cpu_isa_all
all = dnnl_cpu_isa_all,
/// @copydoc dnnl_cpu_isa_sse41
sse41 = dnnl_cpu_isa_sse41,
/// @copydoc dnnl_cpu_isa_avx
avx = dnnl_cpu_isa_avx,
/// @copydoc dnnl_cpu_isa_avx2
avx2 = dnnl_cpu_isa_avx2,
/// @copydoc dnnl_cpu_isa_avx512_mic
avx512_mic = dnnl_cpu_isa_avx512_mic,
/// @copydoc dnnl_cpu_isa_avx512_mic_4ops
avx512_mic_4ops = dnnl_cpu_isa_avx512_mic_4ops,
/// @copydoc dnnl_cpu_isa_avx512_core
avx512_core = dnnl_cpu_isa_avx512_core,
/// @copydoc dnnl_cpu_isa_avx512_core_vnni
avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
/// @copydoc dnnl_cpu_isa_avx512_core_bf16
avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
};
/// @copydoc dnnl_set_max_cpu_isa()
inline status set_max_cpu_isa(cpu_isa isa) {
return static_cast<status>(
dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
}
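// A small sketch of the service API (values are illustrative). set_verbose()
// enables runtime tracing, version() reports the library build, set_jit_dump()
// dumps generated kernels, and set_max_cpu_isa() caps JIT code generation;
// the ISA cap generally must be set before any JIT code is generated.
//
//     if (dnnl::set_verbose(1) != dnnl::status::success) {
//         // could not enable verbose output (e.g. invalid level)
//     }
//     const dnnl::version_t *v = dnnl::version();
//     // v->major, v->minor, and v->patch identify the library build.
//     dnnl::set_jit_dump(1);
//     dnnl::set_max_cpu_isa(dnnl::cpu_isa::avx2);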
/// @} dnnl_api_service
/// @addtogroup dnnl_api_blas BLAS functions
///
/// A subset of Basic Linear Algebra Subprograms (BLAS) functions that
/// perform matrix-matrix multiplication.
///
/// @{
/// @copydoc dnnl_sgemm()
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
return static_cast<status>(dnnl_sgemm(
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
}
/// @copydoc dnnl_gemm_u8s8s32()
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
}
/// @copydoc dnnl_gemm_s8s8s32()
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
}
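// A tiny sketch of the sgemm wrapper computing C = 1.0 * A * B + 0.0 * C for
// a 2x3 by 3x2 product. The values are illustrative; note that, unlike the
// classic Fortran BLAS interface, these functions take row-major matrices.
//
//     float A[2 * 3] = {1, 2, 3, 4, 5, 6}; // 2x3, lda = 3
//     float B[3 * 2] = {1, 0, 0, 1, 1, 1}; // 3x2, ldb = 2
//     float C[2 * 2] = {0};                // 2x2, ldc = 2
//     dnnl::status st = dnnl::sgemm('N', 'N', /*M=*/2, /*N=*/2, /*K=*/3,
//             1.0f, A, 3, B, 2, 0.0f, C, 2);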
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
/// @copydoc dnnl_sgemm_tp()
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc,
dnnl::threadpool_iface *tp) {
return static_cast<status>(dnnl_sgemm_tp(
transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, tp));
}
/// @copydoc dnnl_gemm_u8s8s32_tp()
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
dnnl::threadpool_iface *tp) {
return static_cast<status>(dnnl_gemm_u8s8s32_tp(transa, transb, offsetc, M,
N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
}
/// @copydoc dnnl_gemm_s8s8s32_tp()
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
dnnl::threadpool_iface *tp) {
return static_cast<status>(dnnl_gemm_s8s8s32_tp(transa, transb, offsetc, M,
N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co, tp));
}
#endif
/// @} dnnl_api_blas
// implementation section
/// @cond DO_NOT_DOCUMENT_THIS
inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
dnnl_primitive_t result;
error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
"could not create a primitive");
reset(result);
}
inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
inline void primitive::execute(const stream &stream,
const std::unordered_map<int, memory> &args) const {
std::vector<dnnl_exec_arg_t> c_args;
c_args.reserve(args.size());
for (const auto &a : args)
c_args.push_back({a.first, a.second.get(true)});
error::wrap_c_api(dnnl_primitive_execute(get(), stream.get(),
(int)c_args.size(), c_args.data()),
"could not execute a primitive");
}
/// @endcond
#undef DNNL_DEFINE_BITMASK_OPS
} // namespace dnnl
/// @} dnnl_api
#endif