Compressing GNMT Models


In recent years, Neural Machine Translation[1],[2] (NMT) has become the most common approach to machine translation and to other sequence-to-sequence (seq2seq) learning tasks such as automatic summarization and grammatical error correction. NMT models learn a mapping from an input text to an appropriate output text. Common seq2seq architectures consist of two recurrent neural networks (RNNs) with an attention mechanism. The first network is the encoder, which maps the input sequence (in the source language) to feature vectors; the second is the decoder, which generates the translated text (in the target language) from the features produced by the encoder. An attention mechanism[3] allows the decoder to weight different regions of the input sentence according to the relevance of the source words to the target sentence. Attention is applied at every decoding step and greatly improves model accuracy when the input sequence is long. More recently, Vaswani et al. proposed the Transformer[2], an architecture based entirely on attention mechanisms for modeling the dependencies between input and output sequences.
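To make the attention step concrete, here is a minimal sketch of how a decoder state can weight a set of encoder feature vectors. It assumes plain dot-product scoring for brevity; the cited attention work[3] uses a learned scoring function, and the names `softmax` and `attend` are illustrative, not taken from any library.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(decoder_state, encoder_states):
    """Return (weights, context) for a single decoding step.

    Scores are plain dot products (an assumption made for brevity).
    """
    scores = [
        sum(d * e for d, e in zip(decoder_state, enc))
        for enc in encoder_states
    ]
    weights = softmax(scores)
    dim = len(decoder_state)
    # The context vector is the attention-weighted sum of encoder states.
    context = [
        sum(w * enc[i] for w, enc in zip(weights, encoder_states))
        for i in range(dim)
    ]
    return weights, context


# Three source positions with 2-dimensional features (toy values).
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, context = attend([2.0, 0.0], encoder_states)
```

The weights sum to one, and source positions whose features align with the decoder state receive the larger share.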

Figure 1 shows an example of the most common NMT architecture.

Figure 1: Common Encoder-Decoder architecture with Attention

Pruning Neural Networks

Neural networks are typically over-parameterized. Recent works[4],[5] proposed methods for pruning NMT models, that is, reducing the number of non-zero parameters of a model. In these works, pruning is applied during the training stage and results in highly sparse models with minimal or no loss of accuracy. In the method that Zhu and Gupta proposed[5], each layer of the model is assigned a binary mask that represents the sparsity pattern of that layer. In both the forward and backward passes, the effective weights of a layer are obtained by multiplying the layer weights element-wise with the layer’s mask. The masks are updated during training according to a pruning policy. At each pruning step, a target sparsity for the current step is calculated and a new magnitude threshold is derived from it. The masks are then updated to zero out all the weights whose magnitude is below the new threshold. The pruning policy that leads to the final model sparsity is dictated by the following formula:

$$s_t = s_f \left(1 - \left(1 - \frac{t - t_0}{t_1 - t_0}\right)^{3}\right), \qquad t_0 \le t \le t_1$$

where s_t is the temporal (current step) model sparsity, t is the global training step, s_f is the final model sparsity, t_0 is the global step at which pruning starts and t_1 is the global step at which pruning ends. This is the schedule of Zhu and Gupta[5] with the initial sparsity taken to be zero. Figure 2 depicts an example of a pruning process for a network where the target sparsity level is set to reach 97.5%. For more details on the pruning process, see Michael Zhu and Suyog Gupta’s paper on the efficacy of pruning for model compression[5].
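The pruning step described above can be sketched in a few lines. This is a minimal illustration, not the NLP Architect implementation: it assumes the cubic schedule of Zhu and Gupta[5] with zero initial sparsity, treats a layer as a flat list of weights, and uses the hypothetical helper names `target_sparsity` and `update_mask`.

```python
def target_sparsity(t, s_f, t0, t1):
    """Current-step sparsity under a cubic schedule (assumes initial sparsity 0)."""
    if t <= t0:
        return 0.0
    if t >= t1:
        return s_f
    progress = (t - t0) / (t1 - t0)
    return s_f * (1.0 - (1.0 - progress) ** 3)


def update_mask(weights, sparsity):
    """Binary mask that zeroes the smallest-magnitude fraction of the weights.

    Ties at the threshold are pruned too, so slightly more than the
    requested fraction can be removed when magnitudes repeat.
    """
    n_prune = int(sparsity * len(weights))
    if n_prune == 0:
        return [1] * len(weights)
    threshold = sorted(abs(w) for w in weights)[n_prune - 1]
    return [0 if abs(w) <= threshold else 1 for w in weights]


# Toy layer with 8 weights; the schedule values are illustrative only.
layer = [0.5, -0.1, 0.3, -0.8, 0.05, 0.9, -0.2, 0.4]
mask = update_mask(layer, 0.5)                   # [1, 0, 0, 1, 0, 1, 0, 1]
effective = [w * m for w, m in zip(layer, mask)]  # weights used in both passes
mid_sparsity = target_sparsity(150, s_f=0.9, t0=100, t1=200)
```

The schedule prunes quickly at first, when redundant weights are plentiful, and slows down as the network approaches its final sparsity.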

Other works[6],[7] propose structured pruning where low-energy blocks are zeroed out. Masking blocks instead of single elements during the pruning process may lead to additional accuracy loss depending on the block size and the size of the network.
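As a rough illustration of block pruning, the sketch below zeroes the lowest-energy blocks of a weight matrix. It assumes that a block's energy is the sum of its squared entries, which is only one possible criterion; the cited works[6],[7] may define it differently. The function name `block_prune` is hypothetical.

```python
def block_prune(matrix, block, sparsity):
    """Zero the lowest-energy `block` x `block` tiles of a 2-D weight matrix.

    Energy is the sum of squared entries in a tile (an assumption).
    """
    rows, cols = len(matrix), len(matrix[0])
    tiles = []
    for r in range(0, rows, block):
        for c in range(0, cols, block):
            energy = sum(
                matrix[i][j] ** 2
                for i in range(r, min(r + block, rows))
                for j in range(c, min(c + block, cols))
            )
            tiles.append((energy, r, c))
    tiles.sort()  # lowest-energy tiles first
    n_prune = int(sparsity * len(tiles))
    pruned = [row[:] for row in matrix]
    for _, r, c in tiles[:n_prune]:
        for i in range(r, min(r + block, rows)):
            for j in range(c, min(c + block, cols)):
                pruned[i][j] = 0.0
    return pruned


weights = [
    [0.9, 0.8, 0.1, 0.0],
    [0.7, 0.6, 0.0, 0.1],
    [0.1, 0.1, 0.5, 0.4],
    [0.0, 0.2, 0.3, 0.6],
]
pruned = block_prune(weights, block=2, sparsity=0.5)
```

Because a whole tile is removed at once, a few individually large weights can be discarded along with their small neighbors, which is one reason block pruning may cost more accuracy than element-wise pruning.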

Figure 2: An example of a pruning policy. The temporal sparsity level is denoted as s_t and the target sparsity as s_f; the pruning process starts at step t_0 and ends at step t_1.

Sparse GNMT Model

We integrated the pruning mechanism proposed by Zhu and Gupta[5] into Google’s NMT architecture[8]. The model consists of an encoder with 1 bidirectional LSTM layer followed by 3 LSTM layers, and a decoder with 4 LSTM layers followed by a linear layer with softmax (the projection layer). Both the RNN hidden state size and the embedding vector size are 1024.

NLP Architect release 0.3 includes our implementation for creating sparse GNMT models and two sparse pre-trained models (Table 1). The implementation provides advanced controls for training a sparse NMT model, such as the option to prune any of the following network components: the embedding layers, the LSTM cells, or the final linear layer (classifier). The models were trained to translate German to English using Europarl-v7[9], Common Crawl and News Commentary 11, provided by the Shared Task: Machine Translation of News, as the training dataset. The dataset was pre-processed with Byte Pair Encoding[10] to reduce the size of the vocabulary and to handle out-of-vocabulary words. To test the accuracy of our models, we used the newstest2015 test set provided by the Shared Task. For full details of our published models, see the model documentation.

Post Training Weight Quantization

The weights of the pretrained GNMT model are represented in 32-bit floating-point format. We show how to further compress the highly sparse models by uniformly quantizing the weights to 8-bit integer format, gaining a further compression ratio of 4x with negligible accuracy loss. To implement the weight quantization, we used the TensorFlow API. We also added an option to run inference with the sparse and quantized models. Before inference, we de-quantize the compressed int8 weights back to fp32.
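The quantize and de-quantize round trip can be sketched as follows. This is a generic symmetric, per-tensor uniform quantizer written in plain Python, not the TensorFlow calls used in the actual implementation; the scale choice (largest magnitude mapped to 127) and the function names are assumptions.

```python
def quantize(weights):
    """Uniformly quantize floats to integers in [-127, 127].

    Symmetric per-tensor scheme: the largest magnitude maps to 127.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Map int8 codes back to floating point before inference."""
    return [q * scale for q in quantized]


weights = [0.0, 0.5, -1.27, 0.01, 1.0]
codes, scale = quantize(weights)
restored = dequantize(codes, scale)
```

With a symmetric scheme, a pruned weight of exactly zero maps to code 0 and back to zero, so the sparsity pattern survives quantization. Each restored weight differs from the original by at most half a quantization step.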

For full details on how to train a sparse GNMT model, and for comprehensive details on the model itself, visit the model’s page on NLP Architect.

Results and Summary

Using the pruning and quantization techniques as described above, we successfully trained several sparse models with minimal accuracy loss. We summarize the results in Table 1.

| Model | Sparsity (1) | BLEU (2) | Non-Zero Parameters | Data Type |
|---|---|---|---|---|
| Baseline | 0% | 29.9 | ~210M | Float32 |
| Sparse | 90% | 28.4 | ~22M | Float32 |
| 2×2 Block Sparse | 90% | 27.8 | ~22M | Float32 |
| Quantized Sparse | 90% | 28.4 (3) | ~22M | Integer8 |
| Quantized 2×2 Block Sparse | 90% | 27.6 (3) | ~22M | Integer8 |

Table 1: Sparse and quantized NMT model performance.
  1. The pruning is applied to embedding, decoder’s projection layer and all LSTM layers in both the encoder and decoder.
  2. BLEU score is measured using newstest2015 test set provided by the Shared Task.
  3. The accuracy of the quantized models was measured with the 8-bit weights converted back to floating point during inference.
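For a sense of scale, the parameter counts in Table 1 can be turned into a back-of-the-envelope storage estimate. The sketch counts only the stored weight values and deliberately ignores the index overhead that any sparse storage format needs, so the real footprint depends on the format chosen; the helper name `value_megabytes` is hypothetical.

```python
BYTES_FP32 = 4
BYTES_INT8 = 1


def value_megabytes(n_params, bytes_per_value):
    """Size of the stored weight values only, in megabytes (10^6 bytes)."""
    return n_params * bytes_per_value / 1e6


# Approximate parameter counts taken from Table 1.
baseline_mb = value_megabytes(210e6, BYTES_FP32)
sparse_fp32_mb = value_megabytes(22e6, BYTES_FP32)
sparse_int8_mb = value_megabytes(22e6, BYTES_INT8)
quantization_ratio = sparse_fp32_mb / sparse_int8_mb
```

Under these assumptions the non-zero values shrink from roughly 840 MB for the dense baseline to about 88 MB after pruning and about 22 MB after quantization, the last step being the 4x ratio that comes from moving from 32-bit to 8-bit values.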

We hope that our compressed NMT models will encourage developers to apply pruning and quantization techniques in their models, and inspire future work on sparsity in deep neural networks and hardware-aware pruning methods.

We expect that the new compressed model will serve as a good baseline for the developer community to develop optimized software kernels that will leverage sparsity for efficient inference on limited-resource devices.