Visible to Intel only — GUID: GUID-8968B514-A322-43BC-9A73-CD7C622DE1EE
Batch Normalization
General
The batch normalization primitive performs a forward or backward batch normalization operation on tensors with number of dimensions equal to 2 or more.
Forward
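The batch normalization operation is defined by the following formulas. We show formulas only for 2D spatial data which are straightforward to generalize to cases of higher and lower dimensions. Variable names follow the standard Naming Conventions.

\[\mathrm{dst}(n, c, h, w) = \gamma(c) \cdot \frac{\mathrm{src}(n, c, h, w) - \mu(c)}{\sqrt{\sigma^2(c) + \varepsilon}} + \beta(c),\]

where
- \(\gamma(c)\), \(\beta(c)\) are optional scale and shift for a channel (see dnnl_use_scale and dnnl_use_shift flags),
- \(\mu(c)\), \(\sigma^2(c)\) are mean and variance for a channel (see dnnl_use_global_stats flag), and
- \(\varepsilon\) is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:

\[\mu(c) = \frac{1}{NHW} \sum_{n, h, w} \mathrm{src}(n, c, h, w),\]

\[\sigma^2(c) = \frac{1}{NHW} \sum_{n, h, w} \left(\mathrm{src}(n, c, h, w) - \mu(c)\right)^2.\]

The \(\gamma(c)\) and \(\beta(c)\) tensors are considered learnable.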
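The forward definition can be illustrated with a small reference routine. This is a minimal sketch in plain C++, not the library's implementation: it assumes a dense buffer laid out as NCHW, and the names `ChannelStats`, `compute_stats`, and `batchnorm_forward_ref` are hypothetical helpers introduced only for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-channel statistics over the N, H, W axes (population variance).
struct ChannelStats {
    std::vector<float> mean;     // mu(c)
    std::vector<float> variance; // sigma^2(c)
};

// Computes mu(c) and sigma^2(c) for a dense buffer laid out as NCHW
// (the layout is an assumption of this sketch).
ChannelStats compute_stats(const std::vector<float> &src, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W) {
    assert(src.size() == N * C * H * W);
    const std::size_t SP = H * W;
    const float count = static_cast<float>(N * SP);
    ChannelStats s {std::vector<float>(C, 0.f), std::vector<float>(C, 0.f)};
    for (std::size_t c = 0; c < C; ++c) {
        float sum = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < SP; ++i)
                sum += src[(n * C + c) * SP + i];
        s.mean[c] = sum / count;

        float sq_sum = 0.f;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < SP; ++i) {
                const float d = src[(n * C + c) * SP + i] - s.mean[c];
                sq_sum += d * d;
            }
        s.variance[c] = sq_sum / count; // divide by N*H*W, not N*H*W - 1
    }
    return s;
}

// Applies dst = gamma(c) * (src - mu(c)) / sqrt(sigma^2(c) + eps) + beta(c).
std::vector<float> batchnorm_forward_ref(const std::vector<float> &src,
        const ChannelStats &stats, const std::vector<float> &gamma,
        const std::vector<float> &beta, std::size_t N, std::size_t C,
        std::size_t H, std::size_t W, float eps) {
    assert(src.size() == N * C * H * W);
    assert(gamma.size() == C && beta.size() == C);
    const std::size_t SP = H * W;
    std::vector<float> dst(src.size());
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c) {
            const float inv_std = 1.f / std::sqrt(stats.variance[c] + eps);
            for (std::size_t i = 0; i < SP; ++i) {
                const std::size_t idx = (n * C + c) * SP + i;
                dst[idx] = gamma[c] * (src[idx] - stats.mean[c]) * inv_std
                        + beta[c];
            }
        }
    return dst;
}
```

Passing statistics computed elsewhere instead of calling `compute_stats` corresponds to the case where mean and variance are provided by a user.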
In training mode, the primitive also optionally supports:
- fusion with ReLU activation with zero negative slope applied to the result (see dnnl_fuse_norm_relu flag).
- fusion with binary addition and ReLU activation with zero negative slope applied to the result (see dnnl_fuse_norm_add_relu flag).
- The batch normalization primitive computes population mean and variance and not the sample or unbiased versions that are typically used to compute running mean and variance.
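Using the mean and variance computed by the batch normalization primitive, running mean and variance \(\hat\mu\) and \(\hat\sigma^2\) can be computed as

\[\hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu,\]

\[\hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2.\]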
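A running-statistics update of this kind might be sketched as follows. The sketch assumes the common exponential-moving-average form with a smoothing coefficient `alpha` chosen by the caller. It also shows how a population variance can be converted to the unbiased version before the update, if a framework expects that. `unbiased_variance` and `update_running` are hypothetical helper names, not library functions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Converts a population variance over `count` elements into the unbiased
// (sample) variance by applying Bessel's correction.
float unbiased_variance(float population_variance, std::size_t count) {
    assert(count > 1);
    return population_variance * static_cast<float>(count)
            / static_cast<float>(count - 1);
}

// Per-channel update: running := alpha * running + (1 - alpha) * batch.
// `alpha` is an assumed, caller-chosen smoothing coefficient.
void update_running(std::vector<float> &running,
        const std::vector<float> &batch, float alpha) {
    assert(running.size() == batch.size());
    for (std::size_t c = 0; c < running.size(); ++c)
        running[c] = alpha * running[c] + (1.f - alpha) * batch[c];
}
```

Here `count` would be the number of elements reduced per channel, that is N·H·W for 2D spatial data.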
Difference Between Forward Training and Forward Inference
If mean and variance are computed at runtime (i.e., dnnl_use_global_stats is not set), they become outputs for the propagation kind dnnl_forward_training (because they would be required during the backward propagation) and are not exposed for the propagation kind dnnl_forward_inference.
If batch normalization is created with ReLU fusion (i.e., dnnl_fuse_norm_relu or dnnl_fuse_norm_add_relu is set), for the propagation kind dnnl_forward_training the primitive produces a workspace memory as one extra output. This memory is required to compute the backward propagation. When the primitive is executed with propagation kind dnnl_forward_inference, the workspace is not produced. The behavior is then the same as creating a batch normalization primitive with ReLU as a post-op (see section below).
Backward
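The backward propagation computes \(\mathrm{diff\_src}(n, c, h, w)\), \(\mathrm{diff\_}\gamma(c)^*\), and \(\mathrm{diff\_}\beta(c)^*\) based on \(\mathrm{diff\_dst}(n, c, h, w)\), \(\mathrm{src}(n, c, h, w)\), \(\mu(c)\), \(\sigma^2(c)\), \(\gamma(c)^*\), and \(\beta(c)^*\).
The tensors marked with an asterisk are used only when the primitive is configured to use \(\gamma(c)\) and \(\beta(c)\) (i.e., dnnl_use_scale or dnnl_use_shift are set).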
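This page does not spell out the backward formulas. The following is a minimal reference sketch that assumes the standard batch-normalization gradients and a dense NCHW buffer. `BnGrads` and `batchnorm_backward_ref` are hypothetical names, and the `use_global_stats` parameter models the case where mean and variance are treated as constants rather than as functions of the source.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct BnGrads {
    std::vector<float> diff_src;   // same shape as src
    std::vector<float> diff_gamma; // one value per channel
    std::vector<float> diff_beta;  // one value per channel
};

// Reference backward pass for a dense NCHW buffer (layout assumed).
BnGrads batchnorm_backward_ref(const std::vector<float> &diff_dst,
        const std::vector<float> &src, const std::vector<float> &mean,
        const std::vector<float> &variance, const std::vector<float> &gamma,
        std::size_t N, std::size_t C, std::size_t H, std::size_t W, float eps,
        bool use_global_stats) {
    assert(src.size() == N * C * H * W && diff_dst.size() == src.size());
    assert(mean.size() == C && variance.size() == C && gamma.size() == C);
    const std::size_t SP = H * W;
    const float count = static_cast<float>(N * SP);
    BnGrads g {std::vector<float>(src.size()), std::vector<float>(C, 0.f),
            std::vector<float>(C, 0.f)};
    for (std::size_t c = 0; c < C; ++c) {
        const float inv_std = 1.f / std::sqrt(variance[c] + eps);
        // First pass: per-channel reductions for diff_beta and diff_gamma.
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < SP; ++i) {
                const std::size_t idx = (n * C + c) * SP + i;
                const float x_hat = (src[idx] - mean[c]) * inv_std;
                g.diff_beta[c] += diff_dst[idx];
                g.diff_gamma[c] += diff_dst[idx] * x_hat;
            }
        // Second pass: gradient with respect to the source.
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t i = 0; i < SP; ++i) {
                const std::size_t idx = (n * C + c) * SP + i;
                const float x_hat = (src[idx] - mean[c]) * inv_std;
                float d = diff_dst[idx];
                // When statistics depend on src, subtract their contribution.
                if (!use_global_stats)
                    d -= (g.diff_beta[c] + x_hat * g.diff_gamma[c]) / count;
                g.diff_src[idx] = gamma[c] * inv_std * d;
            }
    }
    return g;
}
```

A constant `diff_dst` is a convenient sanity check: with batch statistics, `diff_src` should come out as zero, because shifting every output by the same amount cannot be achieved by changing the normalized input.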
Execution Arguments
Depending on the flags and propagation kind, the batch normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.
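Flags | dnnl_forward_inference | dnnl_forward_training | dnnl_backward | dnnl_backward_data |
---|---|---|---|---|
dnnl_normalization_flags_none | Inputs: \(\mathrm{src}\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{src}\); Outputs: \(\mathrm{dst}\), \(\mu\), \(\sigma^2\) | Inputs: \(\mathrm{diff\_dst}\), \(\mathrm{src}\), \(\mu\), \(\sigma^2\); Outputs: \(\mathrm{diff\_src}\) | Same as for dnnl_backward |
dnnl_use_global_stats | Inputs: \(\mathrm{src}\), \(\mu\), \(\sigma^2\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{src}\), \(\mu\), \(\sigma^2\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{diff\_dst}\), \(\mathrm{src}\), \(\mu\), \(\sigma^2\); Outputs: \(\mathrm{diff\_src}\) | Same as for dnnl_backward |
dnnl_use_scale | Inputs: \(\mathrm{src}\), \(\gamma\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{src}\), \(\gamma\); Outputs: \(\mathrm{dst}\), \(\mu\), \(\sigma^2\) | Inputs: \(\mathrm{diff\_dst}\), \(\mathrm{src}\), \(\mu\), \(\sigma^2\), \(\gamma\); Outputs: \(\mathrm{diff\_src}\), \(\mathrm{diff\_}\gamma\) | Not supported |
dnnl_use_shift | Inputs: \(\mathrm{src}\), \(\beta\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{src}\), \(\beta\); Outputs: \(\mathrm{dst}\), \(\mu\), \(\sigma^2\) | Inputs: \(\mathrm{diff\_dst}\), \(\mathrm{src}\), \(\mu\), \(\sigma^2\); Outputs: \(\mathrm{diff\_src}\), \(\mathrm{diff\_}\beta\) | Not supported |
dnnl_use_global_stats \| dnnl_use_scale \| dnnl_use_shift | Inputs: \(\mathrm{src}\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{src}\), \(\mu\), \(\sigma^2\), \(\gamma\), \(\beta\); Outputs: \(\mathrm{dst}\) | Inputs: \(\mathrm{diff\_dst}\), \(\mathrm{src}\), \(\mu\), \(\sigma^2\), \(\gamma\); Outputs: \(\mathrm{diff\_src}\), \(\mathrm{diff\_}\gamma\), \(\mathrm{diff\_}\beta\) | Not supported |
flags \| dnnl_fuse_norm_relu | Inputs: same as with flags; Outputs: same as with flags | Inputs: same as with flags; Outputs: same as with flags, Workspace | Inputs: same as with flags, Workspace; Outputs: same as with flags | Same as for dnnl_backward if flags do not contain dnnl_use_scale or dnnl_use_shift; not supported otherwise |
flags \| dnnl_fuse_norm_add_relu | Inputs: same as with flags and \(\mathrm{src\_1}\) for fused binary addition; Outputs: same as with flags | Inputs: same as with flags and \(\mathrm{src\_1}\) for fused binary addition; Outputs: same as with flags, Workspace | Inputs: same as with flags, Workspace; Outputs: same as with flags and \(\mathrm{diff\_src\_1}\) for fused binary addition | Same as for dnnl_backward if flags do not contain dnnl_use_scale or dnnl_use_shift; not supported otherwise |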
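The rules in the summary table can also be written as a small decision function, which may be easier to scan than the table itself. This is an illustrative sketch only: the `Flags` and `Prop` enums are hypothetical stand-ins for the dnnl_* flags and propagation kinds, their numeric values are not taken from the library headers, and `required_args` is not a library function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the dnnl_* normalization flags; the numeric
// values are illustrative and are not taken from the library headers.
enum Flags : unsigned {
    none = 0u,
    use_global_stats = 1u << 0,
    use_scale = 1u << 1,
    use_shift = 1u << 2,
    fuse_norm_relu = 1u << 3,
    fuse_norm_add_relu = 1u << 4,
};

enum class Prop { forward_inference, forward_training, backward, backward_data };

struct Args {
    bool supported = true;
    std::vector<std::string> inputs, outputs;
};

// Returns the tensors the table lists for a given flag set and propagation kind.
Args required_args(unsigned flags, Prop prop) {
    const bool stats_given = (flags & use_global_stats) != 0;
    const bool scale = (flags & use_scale) != 0;
    const bool shift = (flags & use_shift) != 0;
    const bool add_relu = (flags & fuse_norm_add_relu) != 0;
    const bool fused = add_relu || (flags & fuse_norm_relu) != 0;

    Args a;
    // backward_data cannot be combined with scale or shift.
    if (prop == Prop::backward_data && (scale || shift)) {
        a.supported = false;
        return a;
    }
    if (prop == Prop::forward_inference || prop == Prop::forward_training) {
        a.inputs = {"src"};
        if (stats_given) {
            a.inputs.push_back("mean");
            a.inputs.push_back("variance");
        }
        if (scale) a.inputs.push_back("scale");
        if (shift) a.inputs.push_back("shift");
        if (add_relu) a.inputs.push_back("src_1");
        a.outputs = {"dst"};
        if (prop == Prop::forward_training) {
            // Statistics computed at runtime become outputs in training.
            if (!stats_given) {
                a.outputs.push_back("mean");
                a.outputs.push_back("variance");
            }
            if (fused) a.outputs.push_back("workspace");
        }
    } else {
        a.inputs = {"diff_dst", "src", "mean", "variance"};
        if (scale) a.inputs.push_back("scale");
        if (fused) a.inputs.push_back("workspace");
        a.outputs = {"diff_src"};
        if (scale) a.outputs.push_back("diff_scale");
        if (shift) a.outputs.push_back("diff_shift");
        if (add_relu) a.outputs.push_back("diff_src_1");
    }
    return a;
}
```

For example, `required_args(use_scale | use_shift, Prop::forward_training)` reports `src`, `scale`, and `shift` as inputs and `dst`, `mean`, and `variance` as outputs, matching the corresponding table rows.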
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
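Primitive Input/Output | Execution Argument Index |
---|---|
\(\mathrm{src}\) | DNNL_ARG_SRC |
\(\mathrm{src\_1}\) | DNNL_ARG_SRC_1 |
\(\gamma\) | DNNL_ARG_SCALE |
\(\beta\) | DNNL_ARG_SHIFT |
mean (\(\mu\)) | DNNL_ARG_MEAN |
variance (\(\sigma^2\)) | DNNL_ARG_VARIANCE |
\(\mathrm{dst}\) | DNNL_ARG_DST |
workspace | DNNL_ARG_WORKSPACE |
\(\mathrm{diff\_dst}\) | DNNL_ARG_DIFF_DST |
\(\mathrm{diff\_src}\) | DNNL_ARG_DIFF_SRC |
\(\mathrm{diff\_src\_1}\) | DNNL_ARG_DIFF_SRC_1 |
\(\mathrm{diff\_}\gamma\) | DNNL_ARG_DIFF_SCALE |
\(\mathrm{diff\_}\beta\) | DNNL_ARG_DIFF_SHIFT |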
Implementation Details
General Notes
The different flavors of the primitive are partially controlled by the flags parameter that is passed to the primitive descriptor creation function (e.g., dnnl::batch_normalization_forward::primitive_desc()). Multiple flags can be set using the bitwise OR operator (|).
For forward propagation, the mean and variance might be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the dnnl_use_global_stats flag. For the backward propagation, the mean and variance are always input parameters.
Both forward and backward propagation support in-place operations, meaning that \(\mathrm{src}\) can be used as input and output for forward propagation, and \(\mathrm{diff\_dst}\) can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten. Note, however, that backward propagation requires the original \(\mathrm{src}\), hence the corresponding forward propagation should not be performed in-place.
As mentioned above, the batch normalization primitive can be fused with binary addition and ReLU activation (dnnl_fuse_norm_add_relu). In this case:
- on the forward propagation the primitive has one additional input, \(\mathrm{src\_1}\), that should have memory descriptor equal to primitive dst_desc memory descriptor.
- on the backward propagation the primitive has one additional output, \(\mathrm{diff\_src\_1}\), that should have memory descriptor equal to primitive diff_dst_desc memory descriptor.
As mentioned above, the batch normalization primitive can be fused with ReLU activation (dnnl_fuse_norm_relu) or binary addition and ReLU activation (dnnl_fuse_norm_add_relu) even in the training mode. In this case, on the forward propagation the primitive has one additional output, workspace, that should be passed during the backward propagation.
Data Type Support
The operation supports the following combinations of data types:
Propagation | Source / Destination | Mean / Variance / ScaleShift |
---|---|---|
forward / backward | f32, bf16, f16 | f32 |
forward | s8 | f32 |
Data Representation
Mean and Variance
The mean (\(\mu\)) and variance (\(\sigma^2\)) are separate 1D tensors of size \(C\).
The format of the corresponding memory object must be dnnl_x (dnnl_a).
Scale and Shift
If dnnl_use_scale or dnnl_use_shift are used, the scale (\(\gamma\)) and shift (\(\beta\)) are separate 1D tensors of shape \(C\).
The format of the corresponding memory object must be dnnl_a.
Source, Destination, and Their Gradients
Like other CNN primitives, the batch normalization primitive expects data to be an \(N \times C \times SP_n \times \cdots \times SP_0\) tensor.
The batch normalization primitive is optimized for the following memory formats:
Spatial | Logical tensor | Implementations optimized for memory formats |
---|---|---|
0D | NC | |
1D | NCW | |
2D | NCHW | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
3D | NCDHW | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |
Here optimized^ means the format that comes out of any preceding compute-intensive primitive.
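To make the difference between the two plain 2D formats concrete, the sketch below computes the linear offset of a logical element (n, c, h, w) in a dense nchw buffer and in a dense nhwc buffer. It assumes no padding or custom strides, and `Dims`, `offset_nchw`, and `offset_nhwc` are hypothetical helpers introduced for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Logical tensor sizes; the logical order is always N, C, H, W.
struct Dims {
    std::size_t N, C, H, W;
};

// Dense nchw (abcd): W is innermost, channels are H*W elements apart.
std::size_t offset_nchw(const Dims &d, std::size_t n, std::size_t c,
        std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.C + c) * d.H + h) * d.W + w;
}

// Dense nhwc (acdb): C is innermost, channels are adjacent in memory.
std::size_t offset_nhwc(const Dims &d, std::size_t n, std::size_t c,
        std::size_t h, std::size_t w) {
    assert(n < d.N && c < d.C && h < d.H && w < d.W);
    return ((n * d.H + h) * d.W + w) * d.C + c;
}
```

Because batch normalization reduces over every axis except channels, the channel stride (H·W elements in nchw, one element in nhwc) determines how an implementation walks memory, which is one reason the choice of format affects performance.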
Post-Ops and Attributes
Post-ops and attributes enable you to modify the behavior of the batch normalization primitive by chaining certain operations after the batch normalization operation. The following post-ops are supported by batch normalization primitives:
Propagation | Type | Operation | Description |
---|---|---|---|
forward | post-op | eltwise | Applies an Eltwise operation to the result (currently only the dnnl_eltwise_relu algorithm is supported) |
Implementation Limitations
Refer to Data Types for limitations related to data type support.
For the data types that have forward propagation support only, mean and variance must be provided by a user (i.e., dnnl_use_global_stats is set).
CPU implementations do not support the fusion with binary addition and ReLU activation (dnnl_fuse_norm_add_relu).
Performance Tips
For backward propagation, use the same memory format for \(\mathrm{src}\), \(\mathrm{diff\_dst}\), and \(\mathrm{diff\_src}\) (the formats of \(\mathrm{diff\_dst}\) and \(\mathrm{diff\_src}\) are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.
Use in-place operations whenever possible (see caveats in General Notes).
GPU implementations support an experimental algorithm with single pass statistics calculations. Please review experimental features for more details.
Examples
Batch Normalization Primitive Example
This C++ API example demonstrates how to create and execute a Batch Normalization primitive in forward training propagation mode.
Key optimizations included in this example:
In-place primitive execution;
Source memory format for an optimized primitive implementation;
Fused post-ops via operation descriptor flags;