screwdriver/cuda/include/cudnn_adv_infer.h

/*
* Copyright 1993-2020 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
/* cudnn_adv_infer : cuDNN's advanced and experimental features.
*/
#if !defined(CUDNN_ADV_INFER_H_)
#define CUDNN_ADV_INFER_H_
#include <cuda_runtime.h>
#include <stdint.h>
#include "cudnn_version.h"
#include "cudnn_ops_infer.h"
/* These version numbers are autogenerated, do not edit manually. */
#define CUDNN_ADV_INFER_MAJOR 8
#define CUDNN_ADV_INFER_MINOR 1
#define CUDNN_ADV_INFER_PATCH 0
#if (CUDNN_ADV_INFER_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_INFER_MINOR != CUDNN_MINOR) || \
(CUDNN_ADV_INFER_PATCH != CUDNN_PATCHLEVEL)
#error Version mismatch in cuDNN ADV INFER!!!
#endif
#if defined(__cplusplus)
extern "C" {
#endif
/* BASIC RNN API */
typedef enum {
CUDNN_FWD_MODE_INFERENCE = 0,
CUDNN_FWD_MODE_TRAINING = 1,
} cudnnForwardMode_t;
typedef enum {
CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLU activation */
CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
CUDNN_LSTM = 2, /* LSTM with optional recurrent projection and clipping */
CUDNN_GRU = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
} cudnnRNNMode_t;
typedef enum {
CUDNN_RNN_NO_BIAS = 0, /* rnn cell formulas do not use biases */
CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
CUDNN_RNN_DOUBLE_BIAS = 2, /* default, rnn cell formulas use two bias vectors */
CUDNN_RNN_SINGLE_REC_BIAS = 3 /* rnn cell formulas use one recurrent bias in recurrent GEMM */
} cudnnRNNBiasMode_t;
typedef enum {
CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
CUDNN_BIDIRECTIONAL = 1, /* output concatenation at each layer */
} cudnnDirectionMode_t;
typedef enum {
CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
CUDNN_SKIP_INPUT = 1, /* fixed identity matrix in the first layer input GEMM */
} cudnnRNNInputMode_t;
typedef enum {
CUDNN_RNN_CLIP_NONE = 0, /* disables LSTM cell clipping */
CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
} cudnnRNNClipMode_t;
typedef enum {
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, /* padded, outer stride from one time-step to the next */
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, /* sequence length sorted and packed as in the basic RNN API */
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
} cudnnRNNDataLayout_t;
/* Legacy type for backward compatibility */
typedef unsigned cudnnRNNPaddingMode_t;
/* For auxFlags in cudnnSetRNNDescriptor_v8() and cudnnSetRNNPaddingMode() */
#define CUDNN_RNN_PADDED_IO_DISABLED 0
#define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)
struct cudnnRNNStruct;
typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;
struct cudnnPersistentRNNPlan;
typedef struct cudnnPersistentRNNPlan *cudnnPersistentRNNPlan_t;
struct cudnnRNNDataStruct;
typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;
cudnnStatus_t CUDNNWINAPI
cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);
cudnnStatus_t CUDNNWINAPI
cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);
cudnnStatus_t CUDNNWINAPI
cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
cudnnRNNAlgo_t algo,
cudnnRNNMode_t cellMode,
cudnnRNNBiasMode_t biasMode,
cudnnDirectionMode_t dirMode,
cudnnRNNInputMode_t inputMode,
cudnnDataType_t dataType,
cudnnDataType_t mathPrec,
cudnnMathType_t mathType,
int32_t inputSize,
int32_t hiddenSize,
int32_t projSize,
int32_t numLayers,
cudnnDropoutDescriptor_t dropoutDesc,
uint32_t auxFlags);
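/*
 * Usage sketch (illustrative, not part of the API): a minimal v8 descriptor
 * setup for a 2-layer unidirectional LSTM. All sizes are arbitrary example
 * values; error checking, handle creation, and dropout configuration via
 * cudnnSetDropoutDescriptor() are omitted. CUDNN_RNN_ALGO_STANDARD,
 * CUDNN_DATA_FLOAT and CUDNN_DEFAULT_MATH are declared in other cuDNN headers.
 *
 *   cudnnRNNDescriptor_t rnnDesc;
 *   cudnnDropoutDescriptor_t dropoutDesc;
 *   cudnnCreateRNNDescriptor(&rnnDesc);
 *   cudnnCreateDropoutDescriptor(&dropoutDesc);  // configure before use
 *   cudnnSetRNNDescriptor_v8(rnnDesc,
 *                            CUDNN_RNN_ALGO_STANDARD,
 *                            CUDNN_LSTM,
 *                            CUDNN_RNN_DOUBLE_BIAS,
 *                            CUDNN_UNIDIRECTIONAL,
 *                            CUDNN_LINEAR_INPUT,
 *                            CUDNN_DATA_FLOAT,   // dataType (weight/data storage)
 *                            CUDNN_DATA_FLOAT,   // mathPrec
 *                            CUDNN_DEFAULT_MATH,
 *                            512,                // inputSize
 *                            512,                // hiddenSize
 *                            512,                // projSize == hiddenSize: no projection
 *                            2,                  // numLayers
 *                            dropoutDesc,
 *                            CUDNN_RNN_PADDED_IO_ENABLED);
 */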
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
cudnnRNNAlgo_t *algo,
cudnnRNNMode_t *cellMode,
cudnnRNNBiasMode_t *biasMode,
cudnnDirectionMode_t *dirMode,
cudnnRNNInputMode_t *inputMode,
cudnnDataType_t *dataType,
cudnnDataType_t *mathPrec,
cudnnMathType_t *mathType,
int32_t *inputSize,
int32_t *hiddenSize,
int32_t *projSize,
int32_t *numLayers,
cudnnDropoutDescriptor_t *dropoutDesc,
uint32_t *auxFlags);
/*
* mathPrec in cudnnSetRNNDescriptor_v6() specifies compute precision
* compute precision is further modified by cudnnSetRNNMatrixMathType()
* dataType in cudnnGetRNNParamsSize() and wDesc specify weight storage
* dropout is between RNN layers, not between recurrent steps
*/
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNDescriptor_v6(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
const int hiddenSize,
const int numLayers,
cudnnDropoutDescriptor_t dropoutDesc,
cudnnRNNInputMode_t inputMode,
cudnnDirectionMode_t direction,
cudnnRNNMode_t cellMode,
cudnnRNNAlgo_t algo,
cudnnDataType_t mathPrec);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNDescriptor_v6(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
int *hiddenSize,
int *numLayers,
cudnnDropoutDescriptor_t *dropoutDesc,
cudnnRNNInputMode_t *inputMode,
cudnnDirectionMode_t *direction,
cudnnRNNMode_t *cellMode,
cudnnRNNAlgo_t *algo,
cudnnDataType_t *mathPrec);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t biasMode);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t *biasMode);
cudnnStatus_t CUDNNWINAPI
cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
cudnnRNNClipMode_t clipMode,
cudnnNanPropagation_t clipNanOpt,
double lclip,
double rclip);
cudnnStatus_t CUDNNWINAPI
cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
cudnnRNNClipMode_t *clipMode,
cudnnNanPropagation_t *clipNanOpt,
double *lclip,
double *rclip);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNSetClip(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
cudnnRNNClipMode_t clipMode,
cudnnNanPropagation_t clipNanOpt,
double lclip,
double rclip);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNGetClip(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
cudnnRNNClipMode_t *clipMode,
cudnnNanPropagation_t *clipNanOpt,
double *lclip,
double *rclip);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNProjectionLayers(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
const int recProjSize,
const int outProjSize);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNProjectionLayers(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
int *recProjSize,
int *outProjSize);
/* Expensive. Creates the plan for the specific settings. */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc,
const int minibatch,
const cudnnDataType_t dataType,
cudnnPersistentRNNPlan_t *plan);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan);
cudnnStatus_t CUDNNWINAPI
cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);
/* dataType in weight descriptors and input descriptors is used to describe storage */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNWorkspaceSize(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
size_t *sizeInBytes);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
size_t *sizeInBytes);
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
cudnnForwardMode_t fMode,
cudnnRNNDataDescriptor_t xDesc,
size_t *workSpaceSize,
size_t *reserveSpaceSize);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNParamsSize(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const cudnnTensorDescriptor_t xDesc,
size_t *sizeInBytes,
cudnnDataType_t dataType);
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int pseudoLayer,
const cudnnTensorDescriptor_t xDesc,
const cudnnFilterDescriptor_t wDesc,
const void *w,
const int linLayerID,
cudnnFilterDescriptor_t linLayerMatDesc,
void **linLayerMat);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int pseudoLayer,
const cudnnTensorDescriptor_t xDesc,
const cudnnFilterDescriptor_t wDesc,
const void *w,
const int linLayerID,
cudnnFilterDescriptor_t linLayerBiasDesc,
void **linLayerBias);
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNWeightParams(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
int32_t pseudoLayer,
size_t weightSpaceSize,
const void *weightSpace,
int32_t linLayerID,
cudnnTensorDescriptor_t mDesc,
void **mAddr,
cudnnTensorDescriptor_t bDesc,
void **bAddr);
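/*
 * Usage sketch (illustrative): retrieving the address and layout of one
 * weight matrix / bias pair inside a previously allocated weight space.
 * 'handle', 'rnnDesc', 'weightSpace' and 'weightSpaceSize' are assumed to
 * come from earlier calls (cudnnGetRNNWeightSpaceSize() plus cudaMalloc());
 * the linLayerID numbering per cell type is described in the cuDNN docs.
 * Error checks omitted.
 *
 *   cudnnTensorDescriptor_t mDesc, bDesc;
 *   void *mAddr = NULL, *bAddr = NULL;
 *   cudnnCreateTensorDescriptor(&mDesc);
 *   cudnnCreateTensorDescriptor(&bDesc);
 *   cudnnGetRNNWeightParams(handle, rnnDesc,
 *                           0,                            // pseudoLayer
 *                           weightSpaceSize, weightSpace,
 *                           0,                            // linLayerID
 *                           mDesc, &mAddr, bDesc, &bAddr);
 *   // mAddr/bAddr point inside weightSpace; NULL means the parameter is absent.
 */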
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNForwardInference(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
const void *x,
const cudnnTensorDescriptor_t hxDesc,
const void *hx,
const cudnnTensorDescriptor_t cxDesc,
const void *cx,
const cudnnFilterDescriptor_t wDesc,
const void *w,
const cudnnTensorDescriptor_t *yDesc,
void *y,
const cudnnTensorDescriptor_t hyDesc,
void *hy,
const cudnnTensorDescriptor_t cyDesc,
void *cy,
void *workSpace,
size_t workSpaceSizeInBytes);
/* RNN EX API */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, unsigned paddingMode);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, unsigned *paddingMode);
cudnnStatus_t CUDNNWINAPI
cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);
cudnnStatus_t CUDNNWINAPI
cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);
cudnnStatus_t CUDNNWINAPI
cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
cudnnDataType_t dataType,
cudnnRNNDataLayout_t layout,
int maxSeqLength,
int batchSize,
int vectorSize,
const int seqLengthArray[], /* length of each sequence in the batch */
void *paddingFill); /* symbol for filling padding position in output */
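/*
 * Usage sketch (illustrative): describing a padded, batch-major RNN input of
 * three sequences with different lengths. Sizes are arbitrary examples; the
 * same pattern applies to the output descriptor. Error checks omitted.
 *
 *   const int seqLengths[3] = {5, 3, 4};
 *   float padValue = 0.0f;
 *   cudnnRNNDataDescriptor_t xDesc;
 *   cudnnCreateRNNDataDescriptor(&xDesc);
 *   cudnnSetRNNDataDescriptor(xDesc,
 *                             CUDNN_DATA_FLOAT,
 *                             CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
 *                             5,           // maxSeqLength
 *                             3,           // batchSize
 *                             512,         // vectorSize; for x this matches inputSize
 *                             seqLengths,
 *                             &padValue);
 */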
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
cudnnDataType_t *dataType,
cudnnRNNDataLayout_t *layout,
int *maxSeqLength,
int *batchSize,
int *vectorSize,
int arrayLengthRequested,
int seqLengthArray[],
void *paddingFill);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNForwardInferenceEx(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const cudnnRNNDataDescriptor_t xDesc,
const void *x,
const cudnnTensorDescriptor_t hxDesc,
const void *hx,
const cudnnTensorDescriptor_t cxDesc,
const void *cx,
const cudnnFilterDescriptor_t wDesc,
const void *w,
const cudnnRNNDataDescriptor_t yDesc,
void *y,
const cudnnTensorDescriptor_t hyDesc,
void *hy,
const cudnnTensorDescriptor_t cyDesc,
void *cy,
const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */
const void *keys, /* reserved, should pass NULL */
const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */
void *cAttn, /* reserved, should pass NULL */
const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */
void *iAttn, /* reserved, should pass NULL */
const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */
void *queries, /* reserved, should pass NULL */
void *workSpace,
size_t workSpaceSizeInBytes);
cudnnStatus_t CUDNNWINAPI
cudnnRNNForward(cudnnHandle_t handle,
cudnnRNNDescriptor_t rnnDesc,
cudnnForwardMode_t fwdMode,
const int32_t devSeqLengths[],
cudnnRNNDataDescriptor_t xDesc,
const void *x,
cudnnRNNDataDescriptor_t yDesc,
void *y,
cudnnTensorDescriptor_t hDesc,
const void *hx,
void *hy,
cudnnTensorDescriptor_t cDesc,
const void *cx,
void *cy,
size_t weightSpaceSize,
const void *weightSpace,
size_t workSpaceSize,
void *workSpace,
size_t reserveSpaceSize,
void *reserveSpace);
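/*
 * Usage sketch (illustrative): the typical v8 inference sequence - query the
 * buffer sizes, allocate them, then run the forward pass. 'handle', 'rnnDesc',
 * the RNN data descriptors xDesc/yDesc, the 3-D state descriptors hDesc/cDesc,
 * the device buffers x/y/hx/cx/hy/cy and the GPU-resident devSeqLengths array
 * are assumed to be set up elsewhere; error checks omitted. A reserve space is
 * not needed in inference mode.
 *
 *   size_t weightSize = 0, workSize = 0, reserveSize = 0;
 *   void *weightSpace = NULL, *workSpace = NULL;
 *   cudnnGetRNNWeightSpaceSize(handle, rnnDesc, &weightSize);
 *   cudnnGetRNNTempSpaceSizes(handle, rnnDesc, CUDNN_FWD_MODE_INFERENCE,
 *                             xDesc, &workSize, &reserveSize);
 *   cudaMalloc(&weightSpace, weightSize);
 *   if (workSize > 0) cudaMalloc(&workSpace, workSize);
 *   // ... fill weightSpace, x, hx, cx ...
 *   cudnnRNNForward(handle, rnnDesc, CUDNN_FWD_MODE_INFERENCE,
 *                   devSeqLengths,
 *                   xDesc, x, yDesc, y,
 *                   hDesc, hx, hy,
 *                   cDesc, cx, cy,
 *                   weightSize, weightSpace,
 *                   workSize, workSpace,
 *                   0, NULL);               // reserve space unused at inference
 */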
/* RNN FIND API */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count);
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnnDesc,
const int seqLength,
const cudnnTensorDescriptor_t *xDesc,
const void *x,
const cudnnTensorDescriptor_t hxDesc,
const void *hx,
const cudnnTensorDescriptor_t cxDesc,
const void *cx,
const cudnnFilterDescriptor_t wDesc,
const void *w,
const cudnnTensorDescriptor_t *yDesc,
void *y,
const cudnnTensorDescriptor_t hyDesc,
void *hy,
const cudnnTensorDescriptor_t cyDesc,
void *cy,
const float findIntensity,
const int requestedAlgoCount,
int *returnedAlgoCount,
cudnnAlgorithmPerformance_t *perfResults,
void *workspace,
size_t workSpaceSizeInBytes);
/* Sequence data descriptor */
typedef enum {
CUDNN_SEQDATA_TIME_DIM = 0, /* index in time */
CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
CUDNN_SEQDATA_BEAM_DIM = 2, /* index in beam */
CUDNN_SEQDATA_VECT_DIM = 3 /* index in vector */
} cudnnSeqDataAxis_t;
struct cudnnSeqDataStruct;
typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t;
#define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */
cudnnStatus_t CUDNNWINAPI
cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);
cudnnStatus_t CUDNNWINAPI
cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);
cudnnStatus_t CUDNNWINAPI
cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
cudnnDataType_t dataType,
int nbDims,
const int dimA[],
const cudnnSeqDataAxis_t axes[],
size_t seqLengthArraySize,
const int seqLengthArray[],
void *paddingFill);
cudnnStatus_t CUDNNWINAPI
cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
cudnnDataType_t *dataType,
int *nbDims,
int nbDimsRequested,
int dimA[],
cudnnSeqDataAxis_t axes[],
size_t *seqLengthArraySize,
size_t seqLengthSizeRequested,
int seqLengthArray[],
void *paddingFill);
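/*
 * Usage sketch (illustrative): a sequence data descriptor for multi-head
 * attention with batch = 2, beam = 1, a maximum of 4 time steps and an
 * embedding size of 64. The axes array lists the dimensions from outermost
 * to innermost memory order, and one sequence length is supplied per
 * (batch, beam) entry. Values are arbitrary examples; error checks omitted.
 *
 *   int dimA[CUDNN_SEQDATA_DIM_COUNT];
 *   cudnnSeqDataAxis_t axes[CUDNN_SEQDATA_DIM_COUNT] = {
 *       CUDNN_SEQDATA_BATCH_DIM, CUDNN_SEQDATA_BEAM_DIM,
 *       CUDNN_SEQDATA_TIME_DIM,  CUDNN_SEQDATA_VECT_DIM};
 *   dimA[CUDNN_SEQDATA_BATCH_DIM] = 2;
 *   dimA[CUDNN_SEQDATA_BEAM_DIM]  = 1;
 *   dimA[CUDNN_SEQDATA_TIME_DIM]  = 4;
 *   dimA[CUDNN_SEQDATA_VECT_DIM]  = 64;
 *   int seqLengths[2] = {4, 3};              // batchSize * beamSize entries
 *   cudnnSeqDataDescriptor_t qDesc;
 *   cudnnCreateSeqDataDescriptor(&qDesc);
 *   cudnnSetSeqDataDescriptor(qDesc, CUDNN_DATA_FLOAT,
 *                             CUDNN_SEQDATA_DIM_COUNT, dimA, axes,
 *                             2, seqLengths, NULL);
 */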
/* Multihead Attention */
/* Legacy type for backward compatibility */
typedef unsigned cudnnAttnQueryMap_t;
/*
* Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
* Use the bitwise OR operator to combine several settings listed below. Additional
* minor options can be added here without changing or introducing new API functions.
*/
#define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 /* multiple Q-s map to a single (K,V) set when beam size > 1 */
#define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
#define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 /* no biases in attention input and output projections */
#define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) /* use biases in attention input and output projections */
struct cudnnAttnStruct;
typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t;
cudnnStatus_t CUDNNWINAPI
cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);
cudnnStatus_t CUDNNWINAPI
cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);
cudnnStatus_t CUDNNWINAPI
cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
unsigned attnMode,
int nHeads,
double smScaler,
cudnnDataType_t dataType,
cudnnDataType_t computePrec,
cudnnMathType_t mathType,
cudnnDropoutDescriptor_t attnDropoutDesc,
cudnnDropoutDescriptor_t postDropoutDesc,
int qSize,
int kSize,
int vSize,
int qProjSize,
int kProjSize,
int vProjSize,
int oProjSize,
int qoMaxSeqLength,
int kvMaxSeqLength,
int maxBatchSize,
int maxBeamSize);
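/*
 * Usage sketch (illustrative): configuring a 4-head attention descriptor.
 * Individual attnMode options are combined with bitwise OR as described
 * above. All sizes are arbitrary examples; 'attnDropout' and 'postDropout'
 * are cudnnDropoutDescriptor_t handles assumed to be created and configured
 * elsewhere. Error checks omitted.
 *
 *   cudnnAttnDescriptor_t attnDesc;
 *   cudnnCreateAttnDescriptor(&attnDesc);
 *   unsigned attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE | CUDNN_ATTN_DISABLE_PROJ_BIASES;
 *   cudnnSetAttnDescriptor(attnDesc, attnMode,
 *                          4,                   // nHeads
 *                          1.0,                 // smScaler (softmax scaling)
 *                          CUDNN_DATA_FLOAT,    // dataType
 *                          CUDNN_DATA_FLOAT,    // computePrec
 *                          CUDNN_DEFAULT_MATH,
 *                          attnDropout, postDropout,
 *                          64, 64, 64,          // qSize, kSize, vSize
 *                          64, 64, 64, 64,      // q/k/v/oProjSize
 *                          4, 4,                // qoMaxSeqLength, kvMaxSeqLength
 *                          2, 1);               // maxBatchSize, maxBeamSize
 */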
cudnnStatus_t CUDNNWINAPI
cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
unsigned *attnMode,
int *nHeads,
double *smScaler,
cudnnDataType_t *dataType,
cudnnDataType_t *computePrec,
cudnnMathType_t *mathType,
cudnnDropoutDescriptor_t *attnDropoutDesc,
cudnnDropoutDescriptor_t *postDropoutDesc,
int *qSize,
int *kSize,
int *vSize,
int *qProjSize,
int *kProjSize,
int *vProjSize,
int *oProjSize,
int *qoMaxSeqLength,
int *kvMaxSeqLength,
int *maxBatchSize,
int *maxBeamSize);
cudnnStatus_t CUDNNWINAPI
cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
const cudnnAttnDescriptor_t attnDesc,
size_t *weightSizeInBytes,
size_t *workSpaceSizeInBytes,
size_t *reserveSpaceSizeInBytes);
typedef enum {
CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
CUDNN_MH_ATTN_Q_BIASES = 4, /* input projection bias tensor for 'queries' */
CUDNN_MH_ATTN_K_BIASES = 5, /* input projection bias for 'keys' */
CUDNN_MH_ATTN_V_BIASES = 6, /* input projection bias for 'values' */
CUDNN_MH_ATTN_O_BIASES = 7, /* output projection biases */
} cudnnMultiHeadAttnWeightKind_t;
#define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */
cudnnStatus_t CUDNNWINAPI
cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
const cudnnAttnDescriptor_t attnDesc,
cudnnMultiHeadAttnWeightKind_t wKind,
size_t weightSizeInBytes,
const void *weights,
cudnnTensorDescriptor_t wDesc,
void **wAddr);
cudnnStatus_t CUDNNWINAPI
cudnnMultiHeadAttnForward(cudnnHandle_t handle,
const cudnnAttnDescriptor_t attnDesc,
int currIdx,
const int loWinIdx[],
const int hiWinIdx[],
const int devSeqLengthsQO[],
const int devSeqLengthsKV[],
const cudnnSeqDataDescriptor_t qDesc,
const void *queries,
const void *residuals,
const cudnnSeqDataDescriptor_t kDesc,
const void *keys,
const cudnnSeqDataDescriptor_t vDesc,
const void *values,
const cudnnSeqDataDescriptor_t oDesc,
void *out,
size_t weightSizeInBytes,
const void *weights,
size_t workSpaceSizeInBytes,
void *workSpace,
size_t reserveSpaceSizeInBytes,
void *reserveSpace);
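/*
 * Usage sketch (illustrative): querying buffer sizes and running attention
 * inference over all time steps (currIdx = -1). devSeqLengthsQO/KV are
 * arrays in GPU memory, loWinIdx/hiWinIdx are host arrays defining the
 * per-step attention window, and qDesc/kDesc/vDesc/oDesc plus the data
 * buffers are assumed to be set up as shown above. Allocation of weights,
 * workSpace and data, and error checks, are omitted. Reserve space is only
 * needed for training, so 0/NULL is passed here.
 *
 *   size_t weightSize = 0, workSize = 0, reserveSize = 0;
 *   cudnnGetMultiHeadAttnBuffers(handle, attnDesc,
 *                                &weightSize, &workSize, &reserveSize);
 *   // allocate 'weights' and 'workSpace' with cudaMalloc(), load the weights
 *   cudnnMultiHeadAttnForward(handle, attnDesc,
 *                             -1,                   // currIdx: process all steps
 *                             loWinIdx, hiWinIdx,
 *                             devSeqLengthsQO, devSeqLengthsKV,
 *                             qDesc, queries,
 *                             NULL,                 // residuals disabled
 *                             kDesc, keys, vDesc, values,
 *                             oDesc, out,
 *                             weightSize, weights,
 *                             workSize, workSpace,
 *                             0, NULL);             // no reserve space at inference
 */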
/*
* \brief Cross-library version checker.
* This function is implemented differently in each sub-library. Each sub-library
* checks whether its own version matches that of its dependencies.
* \returns CUDNN_STATUS_SUCCESS if the version check passes,
* CUDNN_STATUS_VERSION_MISMATCH if the versions are inconsistent.
*/
cudnnStatus_t CUDNNWINAPI
cudnnAdvInferVersionCheck(void);
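/*
 * Usage sketch (illustrative): verifying at startup that the loaded
 * advanced-inference sub-library matches the other cuDNN sub-libraries.
 *
 *   if (cudnnAdvInferVersionCheck() != CUDNN_STATUS_SUCCESS) {
 *       // handle the CUDNN_STATUS_VERSION_MISMATCH case, e.g. abort
 *   }
 */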
#if defined(__cplusplus)
}
#endif
#endif /* CUDNN_ADV_INFER_H_ */