How to visualize RNN/LSTM gradients in Keras/TensorFlow?


Question:


I’ve come across research publications and Q&As discussing the need to inspect RNN gradients per backpropagation through time (BPTT), i.e. the gradient for each timestep. The main use is introspection: how do we know whether an RNN is learning long-term dependencies? That is a topic of its own, but the most important insight is gradient flow:

  • If a non-zero gradient flows through every timestep, then every timestep contributes to learning – i.e., resultant gradients stem from accounting for every input timestep, so the entire sequence influences weight updates
  • Per the above, the RNN no longer ignores portions of long sequences, and is forced to learn from them

… but how do I actually visualize these gradients in Keras / TensorFlow? Some related answers are in the right direction, but they seem to fail for bidirectional RNNs, and only show how to get a layer’s gradients, not how to meaningfully visualize them (the output is a 3D tensor – how do I plot it?)
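To make the gradient-flow idea concrete, here is a toy sketch (my own illustration, not part of the answer below): a scalar linear RNN h_t = w·h_{t-1} + x_t with loss L = h_{T-1}. Backpropagating through time gives ∂L/∂x_t = w^(T-1-t), so for |w| < 1 the gradient shrinks geometrically toward the earliest timesteps, and for |w| > 1 it grows. Real RNNs multiply Jacobians of the hidden-state update instead of a scalar w, but the product structure is the same. The function name toy_bptt_grads is hypothetical.

```python
import numpy as np

def toy_bptt_grads(x, w):
    """Scalar linear RNN h_t = w * h_{t-1} + x_t (h_{-1} = 0), loss L = h_{T-1}.
    Returns (L, dL/dx), where dL/dx[t] is the gradient reaching input timestep t."""
    h = 0.0
    for x_t in x:                      # forward pass
        h = w * h + x_t
    grads = np.zeros(len(x))
    dh = 1.0                           # dL/dh_{T-1}
    for t in reversed(range(len(x))):  # backward pass: last timestep to first
        grads[t] = dh                  # dh_t/dx_t = 1
        dh *= w                        # dL/dh_{t-1} = w * dL/dh_t
    return h, grads

loss, grads = toy_bptt_grads(np.ones(10), w=0.5)
print(grads)  # grads[t] == 0.5 ** (9 - t): 1.0 at the last step, 0.5**9 ≈ 0.002 at the first
```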

Answer #1:

Gradients can be fetched w.r.t. weights or outputs; we’ll need the latter. Further, for best results, an architecture-specific treatment is desirable. The code and explanations below cover every possible case of a Keras/TF RNN, and should be easily extendable to future API changes.


Completeness: the code shown is a simplified version; the full version can be found in my repository, See RNN (which includes this post with bigger images). It includes:

  • Greater visual customizability
  • Docstrings explaining all functionality
  • Support for Eager, Graph, TF1, and TF2, with both keras and tf.keras imports
  • Activations visualization
  • Weights gradients visualization (coming soon)
  • Weights visualization (coming soon)

I/O dimensionalities (all RNNs):

  • Input: (batch_size, timesteps, channels) – or, equivalently, (samples, timesteps, features)
  • Output: same as Input, except:
    • channels/features is now the # of RNN units, and:
    • return_sequences=True → timesteps_out = timesteps_in (output a prediction for each input timestep)
    • return_sequences=False → timesteps_out = 1 (output a prediction only at the last timestep processed); Keras drops the time axis here, so the output shape is (batch_size, units)
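A NumPy-only sketch of a SimpleRNN forward pass that follows the same shape conventions. simple_rnn is a hypothetical helper with random weights and a tanh activation; it is not Keras's implementation, only an illustration of where the output shapes come from:

```python
import numpy as np

def simple_rnn(x, units, return_sequences=True, seed=0):
    """Toy tanh RNN: h_t = tanh(x_t @ W + h_{t-1} @ U + b).
    x: (batch_size, timesteps, channels)."""
    rng = np.random.default_rng(seed)
    batch_size, timesteps, channels = x.shape
    W = rng.normal(scale=0.1, size=(channels, units))  # input kernel
    U = rng.normal(scale=0.1, size=(units, units))     # recurrent kernel
    b = np.zeros(units)
    h = np.zeros((batch_size, units))
    outputs = []
    for t in range(timesteps):
        h = np.tanh(x[:, t] @ W + h @ U + b)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (batch_size, timesteps, units)
    return h                               # (batch_size, units): last timestep only

x = np.zeros((16, 100, 8))                             # (samples, timesteps, features)
print(simple_rnn(x, 6).shape)                          # (16, 100, 6)
print(simple_rnn(x, 6, return_sequences=False).shape)  # (16, 6)
```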

Visualization methods:

  • 1D plot grid: plot gradient vs. timesteps for each of the channels
  • 2D heatmap: plot channels vs. timesteps w/ gradient intensity heatmap
  • 0D aligned scatter: plot gradient for each channel per sample
  • histogram: no good way to represent “vs. timesteps” relations
  • One sample: do each of above for a single sample
  • Entire batch: do each of above for all samples in a batch; requires careful treatment

```python
# for the examples below
grads = get_rnn_gradients(model, x, y, layer_idx=1)  # return_sequences=True
grads = get_rnn_gradients(model, x, y, layer_idx=2)  # return_sequences=False
```
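The visualization methods above differ mainly in how the 3D grads tensor is sliced. A NumPy sketch of those slices on random placeholder data; the shape (16 samples, 100 timesteps, 6 units) is an assumption loosely matching the examples, and the 100 timesteps are made up:

```python
import numpy as np

# placeholder gradients: (samples, timesteps, channels)
grads = np.random.default_rng(0).normal(size=(16, 100, 6))

one_sample = grads[0]                # (timesteps, channels): one sample's 1D grid or heatmap
one_channel = grads[:, :, 0]         # (samples, timesteps): every sample's curve for channel 0
heatmaps = grads.transpose(0, 2, 1)  # (samples, channels, timesteps): channels on y, time on x
flat = grads.ravel()                 # all values for a histogram; timestep order is lost

for name, arr in [("one_sample", one_sample), ("one_channel", one_channel),
                  ("heatmaps", heatmaps), ("flat", flat)]:
    print(name, arr.shape)
```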

EX 1: one sample, uni-LSTM, 6 units, return_sequences=True, trained for 20 iterations
```python
show_features_1D(grads[0], n_rows=2)
```

  • Note: gradients are to be read right-to-left, as they’re computed (from last timestep to first)
  • Rightmost (latest) timesteps consistently have a higher gradient
  • Vanishing gradient: the leftmost ~75% of timesteps have zero gradient, indicating poor learning of time dependencies
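To put a number on "~75% of timesteps have zero gradient", one option (my own heuristic, not part of the original code) is the fraction of timesteps whose largest absolute gradient across channels is below a tolerance. vanished_fraction and tol are hypothetical; a sensible tolerance depends on the gradient scale of your model:

```python
import numpy as np

def vanished_fraction(grads, tol=1e-8):
    """Fraction of timesteps whose gradient is below `tol` in every channel.
    grads: (timesteps, channels) -> float, or (samples, timesteps, channels) -> (samples,)."""
    per_step = np.abs(np.asarray(grads)).max(axis=-1)  # strongest channel per timestep
    return (per_step < tol).mean(axis=-1)

# toy example: 100 timesteps, 6 channels, only the last 25 timesteps receive gradient
g = np.zeros((100, 6))
g[75:] = 1e-3
print(vanished_fraction(g))  # 0.75
```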



EX 2: all (16) samples, uni-LSTM, 6 units, return_sequences=True, trained for 20 iterations
```python
show_features_1D(grads, n_rows=2)
show_features_2D(grads, n_rows=4, norm=(-.01, .01))
```

  • Each sample shown in a different color (but same color per sample across channels)
  • Some samples fare better than the one shown above, but not by much
  • The heatmap plots channels (y-axis) vs. timesteps (x-axis); blue=-0.01, red=0.01, white=0 (gradient values)
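With norm=(-.01, .01) and cmap='bwr', each value is placed linearly on the color scale, and out-of-range values saturate at the end colors (matplotlib's default under/over behavior). A small NumPy sketch of that mapping (color_position is a hypothetical helper) shows why a symmetric norm keeps zero at white, and why a gradient of 0.05 looks the same as 0.01 in the heatmap:

```python
import numpy as np

def color_position(values, vmin=-0.01, vmax=0.01):
    """Where each value lands on a linear [vmin, vmax] color scale, clipped to [0, 1].
    For cmap='bwr': 0 -> blue, 0.5 -> white, 1 -> red."""
    return np.clip((np.asarray(values, dtype=float) - vmin) / (vmax - vmin), 0.0, 1.0)

print(color_position([-0.05, -0.01, 0.0, 0.01, 0.05]))  # -> 0, 0, 0.5, 1, 1
```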



EX 3: all (16) samples, uni-LSTM, 6 units, return_sequences=True, trained for 200 iterations
```python
show_features_1D(grads, n_rows=2)
show_features_2D(grads, n_rows=4, norm=(-.01, .01))
```

  • Both plots show the LSTM performing clearly better after 180 additional iterations
  • Gradient still vanishes for about half the timesteps
  • All LSTM units better capture time dependencies of one particular sample (blue curve, all plots) – which we can tell from the heatmap to be the first sample. We can plot that sample vs. other samples to try to understand the difference
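One way to compare that sample against the rest is to reduce each sample to a per-timestep gradient magnitude and average over the other samples. A sketch under that assumption; compare_sample is a hypothetical helper and the data below are placeholders:

```python
import numpy as np

def compare_sample(grads, idx):
    """Per-timestep gradient magnitude (mean |grad| over channels) for sample `idx`
    vs. the average over all other samples. grads: (samples, timesteps, channels)."""
    mag = np.abs(np.asarray(grads)).mean(axis=-1)      # (samples, timesteps)
    others = np.delete(mag, idx, axis=0).mean(axis=0)  # (timesteps,)
    return mag[idx], others

rng = np.random.default_rng(0)
grads = rng.normal(scale=1e-3, size=(16, 100, 6))      # placeholder data
best, rest = compare_sample(grads, idx=0)
print(best.shape, rest.shape)  # (100,) (100,)
```

Plotting best and rest on the same axes shows where along the sequence the chosen sample keeps receiving gradient that the others lose.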



EX 4: 2D vs. 1D, uni-LSTM: 256 units, return_sequences=True, trained for 200 iterations
```python
show_features_1D(grads[0])
show_features_2D(grads[:, :, 0], norm=(-.0001, .0001))
```

  • 2D is better suited for comparing many channels across few samples
  • 1D is better suited for comparing many samples across a few channels



EX 5: bi-GRU, 256 units (512 total), return_sequences=True, trained for 400 iterations
```python
show_features_2D(grads[0], norm=(-.0001, .0001), reflect_half=True)
```

  • Backward layer’s gradients are flipped for consistency w.r.t. time axis
  • Plot reveals a lesser-known advantage of Bi-RNNs – information utility: the collective gradient covers about twice the data. However, this isn’t a free lunch: each layer is an independent feature extractor, so learning isn’t really complemented
  • The smaller norm for more units (±.0001 here vs. ±.01 in the 6-unit examples) is expected, as approximately the same loss-derived gradient is distributed across more parameters (hence the mean squared value is smaller)
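A NumPy sketch of the flip that reflect_half=True applies, assuming Keras's default merge_mode='concat' for Bidirectional, where the first half of the channels comes from the forward layer and the second half from the backward layer. reflect_backward_half is a hypothetical helper:

```python
import numpy as np

def reflect_backward_half(grads):
    """Flip the backward layer's channels along the time axis.
    grads: (..., timesteps, channels); channels[:units] are the forward layer's,
    channels[units:] the backward layer's (merge_mode='concat')."""
    grads = np.array(grads, copy=True)
    half = grads.shape[-1] // 2
    grads[..., half:] = grads[..., ::-1, half:].copy()  # copy: source overlaps target
    return grads

g = np.arange(12).reshape(3, 4)  # 3 timesteps, 2 forward + 2 backward channels
print(reflect_backward_half(g))
```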


EX 6: 0D, all (16) samples, uni-LSTM, 6 units, return_sequences=False, trained for 200 iterations
```python
show_features_0D(grads)
```

  • return_sequences=False utilizes only the last timestep’s gradient (which is still derived from all timesteps, unless using truncated BPTT), requiring a new approach
  • Plot color-codes each RNN unit consistently across samples for comparison (can use one color instead)
  • Evaluating gradient flow is less direct and more theoretically involved. One simple approach is to compare distributions at beginning vs. later in training: if the difference isn’t significant, the RNN does poorly in learning long-term dependencies
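The answer doesn't name a specific way to compare the distributions; the two-sample Kolmogorov-Smirnov statistic is one option I'm substituting here (scipy.stats.ks_2samp computes the same statistic plus a p-value). A NumPy sketch, with placeholder "early" and "late" gradients of shape (samples, channels):

```python
import numpy as np

def ks_statistic(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: the largest gap between the
    empirical CDFs of `a` and `b` (0 = identical, 1 = fully separated)."""
    a, b = np.sort(np.ravel(a)), np.sort(np.ravel(b))
    points = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, points, side='right') / len(a)
    cdf_b = np.searchsorted(b, points, side='right') / len(b)
    return np.max(np.abs(cdf_a - cdf_b))

rng = np.random.default_rng(0)
grads_early = rng.normal(scale=1e-4, size=(16, 6))  # placeholder
grads_late = rng.normal(scale=1e-3, size=(16, 6))   # placeholder
print([round(ks_statistic(grads_early[:, c], grads_late[:, c]), 3) for c in range(6)])
```

Comparing per channel (as above) or pooled over all channels are both reasonable; a statistic close to 0 means the distribution barely moved during training.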


EX 7: LSTM vs. GRU vs. SimpleRNN, unidir, 256 units, return_sequences=True, trained for 250 iterations
```python
show_features_2D(grads, n_rows=8, norm=(-.0001, .0001), show_xy_ticks=[0,0], show_title=False)
```

  • Note: the comparison isn’t very meaningful; each network thrives with different hyperparameters, whereas the same ones were used for all. LSTM, for one, has the most parameters per unit, drowning out SimpleRNN
  • In this setup, LSTM definitively stomps GRU and SimpleRNN
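For a sense of scale, a sketch of per-layer parameter counts following Keras's weight layout (kernel, recurrent_kernel, bias). rnn_param_count is a hypothetical helper, and input_dim=8 below is a made-up value since the examples don't state the input size; reset_after selects the GRU variant from the bonus fact at the end of this answer:

```python
def rnn_param_count(rnn_type, input_dim, units, reset_after=True):
    """Trainable parameters of one unidirectional Keras RNN layer:
    kernel (input_dim x g*units) + recurrent_kernel (units x g*units) + bias,
    where g = number of gate blocks (SimpleRNN 1, GRU 3, LSTM 4)."""
    gates = {'SimpleRNN': 1, 'GRU': 3, 'LSTM': 4}[rnn_type]
    n_bias = gates * units
    if rnn_type == 'GRU' and reset_after:
        n_bias *= 2  # separate input and recurrent biases, shape (2, 3 * units)
    return gates * units * (input_dim + units) + n_bias

for rnn_type in ('SimpleRNN', 'GRU', 'LSTM'):
    print(rnn_type, rnn_param_count(rnn_type, input_dim=8, units=256))
```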



Visualization functions:

```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K  # or: from keras import backend as K

# NOTE: this simplified get_rnn_gradients relies on graph-mode Keras internals
# (model.total_loss, model.sample_weights, model._feed_targets, K.learning_phase),
# i.e. TF1, or TF2 with eager execution disabled before the model is built.
# The full version in the repository also covers Eager mode.


def get_rnn_gradients(model, input_data, labels, layer_idx=None, layer_name=None,
                      layer=None, sample_weights=None):
    """Gradients of the loss w.r.t. an RNN layer's outputs, e.g. of shape
    (samples, timesteps, units) for return_sequences=True."""
    if layer is None:
        layer = _get_layer(model, layer_idx, layer_name)

    grads_fn = _make_grads_fn(model, layer)
    if sample_weights is None:  # `or` would fail on a NumPy array
        sample_weights = np.ones(len(input_data))
    grads = grads_fn([input_data, sample_weights, labels, 1])  # 1 = training phase

    while isinstance(grads, list):  # unwrap nested single-element lists
        grads = grads[0]
    return grads


def _make_grads_fn(model, layer):
    grads = model.optimizer.get_gradients(model.total_loss, layer.output)
    return K.function(inputs=[model.inputs[0], model.sample_weights[0],
                              model._feed_targets[0], K.learning_phase()],
                      outputs=grads)


def _get_layer(model, layer_idx=None, layer_name=None):
    if layer_idx is not None:
        return model.layers[layer_idx]
    if layer_name is None:
        raise ValueError("pass `layer_idx`, `layer_name`, or `layer`")

    layers = [layer for layer in model.layers if layer_name in layer.name]
    if not layers:
        raise ValueError("no layer name contains %r" % layer_name)
    if len(layers) > 1:
        print("WARNING: multiple matching layer names found; picking earliest")
    return layers[0]


def show_features_1D(data, n_rows=None, label_channels=True,
                     equate_axes=True, max_timesteps=None, color=None,
                     show_title=True, show_borders=True, show_xy_ticks=(1, 1),
                     title_fontsize=14, channel_axis=-1,
                     scale_width=1, scale_height=1, dpi=76):
    """Plot gradient vs. timesteps, one subplot per channel. `data` is
    (timesteps, channels) for one sample or (samples, timesteps, channels)."""
    data = np.asarray(data)
    if data.ndim not in (2, 3):
        raise ValueError("`data` must be 2D or 3D")

    def _get_feature_outputs(channel_idx):
        # one curve per sample for this channel
        if data.ndim == 3:
            return [entry[:max_timesteps, channel_idx] for entry in data]
        return [data[:max_timesteps, channel_idx]]

    n_features = (data[0].shape[channel_axis] if data.ndim == 3
                  else data.shape[channel_axis])
    if n_rows is None:  # default to a near-square grid
        n_rows = max(1, int(np.sqrt(n_features)))
    n_cols = int(np.ceil(n_features / n_rows))

    if color is None:  # matplotlib's default color cycle, one entry per sample
        color = [None] * (len(data) if data.ndim == 3 else 1)

    fig, axes = plt.subplots(n_rows, n_cols, sharey=equate_axes, dpi=dpi,
                             squeeze=False)  # always a 2D array of axes
    if show_title:
        title = ("((Gradients vs. Timesteps) vs. Samples) vs. Channels"
                 if data.ndim == 3 else "(Gradients vs. Timesteps) vs. Channels")
        plt.suptitle(title, weight='bold', fontsize=title_fontsize)
    fig.set_size_inches(12*scale_width, 8*scale_height)

    used_axes = list(axes.flat[:n_features])
    for ax_idx, ax in enumerate(used_axes):  # channel index == subplot index
        feature_outputs = _get_feature_outputs(ax_idx)
        for idx, feature_output in enumerate(feature_outputs):
            ax.plot(feature_output, color=color[idx])

        ax.set_xlim(0, len(feature_outputs[0]))
        if not show_xy_ticks[0]:
            ax.set_xticks([])
        if not show_xy_ticks[1]:
            ax.set_yticks([])
        if label_channels:
            ax.annotate(str(ax_idx), weight='bold', color='g',
                        xycoords='axes fraction', fontsize=16, xy=(.03, .9))
        if not show_borders:
            ax.set_frame_on(False)
    for ax in axes.flat[n_features:]:  # hide leftover grid cells
        ax.set_visible(False)

    if equate_axes:  # symmetric y-limits shared across subplots
        y_max = max(np.max(np.abs(ax.get_ylim())) for ax in used_axes)
        for ax in used_axes:
            ax.set_ylim(-y_max, y_max)
    plt.show()


def show_features_2D(data, n_rows=None, norm=None, cmap='bwr', reflect_half=False,
                     timesteps_xaxis=True, max_timesteps=None, show_title=True,
                     show_colorbar=False, show_borders=True,
                     title_fontsize=14, show_xy_ticks=(1, 1),
                     scale_width=1, scale_height=1, dpi=76):
    """Gradient heatmap per sample. `data` is (timesteps, channels) for one
    sample or (samples, timesteps, channels)."""
    data = np.asarray(data)
    if data.ndim == 2:
        data = data[np.newaxis]  # single sample -> batch of one
    data = data[:, :max_timesteps]
    if reflect_half:
        # Bidirectional (merge_mode='concat'): flip the backward layer's half
        # of the channels along time, for consistency with the forward half
        data = data.copy()
        half = data.shape[-1] // 2
        data[:, :, half:] = data[:, ::-1, half:].copy()
    if timesteps_xaxis:
        data = data.transpose(0, 2, 1)  # (samples, channels, timesteps)

    vmin, vmax = norm or (None, None)
    n_samples = len(data)
    if n_rows is None:
        n_rows = max(1, int(np.sqrt(n_samples)))
    n_cols = int(np.ceil(n_samples / n_rows))

    fig, axes = plt.subplots(n_rows, n_cols, dpi=dpi, squeeze=False)
    if show_title:
        order = ("Channels vs. Timesteps" if timesteps_xaxis
                 else "Timesteps vs. Channels")
        plt.suptitle("((%s) vs. Samples) -- norm=(%s, %s)" % (order, vmin, vmax),
                     weight='bold', fontsize=title_fontsize)

    for ax_idx, ax in enumerate(axes.flat):
        if ax_idx >= n_samples:  # hide leftover grid cells
            ax.set_visible(False)
            continue
        img = ax.imshow(data[ax_idx], cmap=cmap, vmin=vmin, vmax=vmax)
        if not show_xy_ticks[0]:
            ax.set_xticks([])
        if not show_xy_ticks[1]:
            ax.set_yticks([])
        ax.axis('tight')
        if not show_borders:
            ax.set_frame_on(False)

    if show_colorbar:
        fig.colorbar(img, ax=axes.ravel().tolist())

    fig.set_size_inches(8*scale_width, 8*scale_height)
    plt.show()


def show_features_0D(data, marker='o', cmap='bwr', color=None,
                     show_y_zero=True, show_borders=False, show_title=True,
                     title_fontsize=14, markersize=15, markerwidth=2,
                     channel_axis=-1, scale_width=1, scale_height=1):
    """Scatter each channel's gradient per sample. `data` is (samples, channels),
    e.g. gradients of a return_sequences=False layer."""
    data = np.asarray(data)
    if color is None:
        # one color per channel (sampled evenly from `cmap`), repeated per sample
        colors = plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[channel_axis]))
        color = np.vstack([colors] * data.shape[0])
    # x-position of every point = its 1-based sample index
    x = np.ones(data.shape) * np.arange(1, len(data) + 1)[:, np.newaxis]

    if show_y_zero:
        plt.axhline(0, color='k', linewidth=1)
    plt.scatter(x.flatten(), data.flatten(), marker=marker,
                s=markersize, linewidth=markerwidth, c=color)
    plt.gca().set_xticks(np.arange(1, len(data) + 1), minor=True)
    plt.gca().tick_params(which='minor', length=4)

    if show_title:
        plt.title("(Gradients vs. Samples) vs. Channels",
                  weight='bold', fontsize=title_fontsize)
    if not show_borders:
        plt.box(False)
    plt.gcf().set_size_inches(12*scale_width, 4*scale_height)
    plt.show()
```

Full minimal example: see the repository’s README


Bonus code:

  • How can I check weight/gate ordering without reading source code?
```python
rnn_cell = model.layers[1].cell           # unidirectional
rnn_cell = model.layers[1].forward_layer  # bidirectional; also `backward_layer`
print(rnn_cell.__dict__)
```

For more convenient code, see the repo’s rnn_summary
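Once the ordering is known, splitting a weight array into per-gate blocks is a single np.split. Keras concatenates gates along the last axis, as [input, forget, cell, output] for LSTM and [update (z), reset (r), candidate (h)] for GRU. split_gates is a hypothetical helper, shown on a placeholder array:

```python
import numpy as np

# Keras's gate ordering along the last axis of kernel, recurrent_kernel and bias
GATE_ORDER = {'LSTM': ('i', 'f', 'c', 'o'),  # input, forget, cell candidate, output
              'GRU': ('z', 'r', 'h')}        # update, reset, candidate

def split_gates(weights, rnn_type='LSTM'):
    """Split a Keras LSTM/GRU weight array whose last axis is n_gates * units
    into a {gate_name: block} dict. Hypothetical helper, not a Keras API."""
    names = GATE_ORDER[rnn_type]
    blocks = np.split(np.asarray(weights), len(names), axis=-1)
    return dict(zip(names, blocks))

kernel = np.arange(2 * 12).reshape(2, 12)  # placeholder: input_dim=2, units=3, LSTM
for name, block in split_gates(kernel).items():
    print(name, block.shape)                # each (2, 3)
```

For a trained unidirectional layer, kernel, recurrent_kernel, bias = model.layers[1].get_weights() gives the arrays to pass in. A GRU bias of shape (2, 3 * units) splits into (2, units) blocks, one row for the input bias and one for the recurrent bias.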


Bonus fact: if you run the above on a GRU, you may notice that the bias has no gates; why so? From the docs:

There are two variants. The default one is based on 1406.1078v3 and has reset gate applied to hidden state before matrix multiplication. The other one is based on original 1406.1078v1 and has the order reversed.

The second variant is compatible with CuDNNGRU (GPU-only) and allows inference on CPU. Thus it has separate biases for kernel and recurrent_kernel. Use reset_after=True and recurrent_activation='sigmoid'.
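In equations (candidate state only, tanh activation, following the Keras GRU cell): reset_after=False computes tanh(x·W_h + (r ⊙ h)·U_h + b_h), while reset_after=True computes tanh(x·W_h + b_h_in + r ⊙ (h·U_h + b_h_rec)). A NumPy sketch of both; gru_candidate is a hypothetical helper, and W, U stand for the candidate blocks of kernel and recurrent_kernel:

```python
import numpy as np

def gru_candidate(x, h, r, W, U, b_in, b_rec=0.0, reset_after=True):
    """Candidate state of one GRU step (tanh activation), both Keras variants.
    x: (batch, input_dim); h: previous state (batch, units); r: reset gate (batch, units);
    W: (input_dim, units); U: (units, units)."""
    if reset_after:
        # reset gate applied AFTER the recurrent matmul; recurrent bias sits inside r * (...)
        return np.tanh(x @ W + b_in + r * (h @ U + b_rec))
    # reset gate applied to h BEFORE the matmul; a single bias suffices
    return np.tanh(x @ W + (r * h) @ U + b_in)

rng = np.random.default_rng(0)
x, h, r = rng.normal(size=(2, 3)), rng.normal(size=(2, 4)), rng.uniform(size=(2, 4))
W, U, b = rng.normal(size=(3, 4)), rng.normal(size=(4, 4)), rng.normal(size=4)
print(gru_candidate(x, h, r, W, U, b, reset_after=True))
print(gru_candidate(x, h, r, W, U, b, reset_after=False))  # generally differs
```

Because b_rec is multiplied by r, it cannot be folded into b_in; that is why the reset_after=True variant stores two bias vectors (the GRU bias then has shape (2, 3 * units)).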

Answered By: OverLordGoldDragon
