# What’s the difference between torch.stack() and torch.cat() functions?


### Question :

What’s the difference between torch.stack() and torch.cat() functions?

OpenAI’s REINFORCE and actor-critic example for reinforcement learning has the following code:

REINFORCE:

```python
policy_loss = torch.cat(policy_loss).sum()
```

Actor-critic:

```python
loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
```

One is using `torch.cat`, the other uses `torch.stack`.

As far as my understanding goes, the doc doesn’t give any clear distinction between them.

I would be happy to know the differences between the functions.

### Answer :

The documentation describes the two functions as follows.

`stack`

Concatenates sequence of tensors along a new dimension.

`cat`

Concatenates the given sequence of `seq` tensors in the given dimension.

So if `A` and `B` are of shape (3, 4), `torch.cat([A, B], dim=0)` will be of shape (6, 4) and `torch.stack([A, B], dim=0)` will be of shape (2, 3, 4).
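To make the shape rule concrete, here is a minimal sketch that checks the shapes quoted above. The tensors `A`, `B`, and `C` are placeholders filled with zeros and ones, chosen only for illustration. It also shows that stacking gives the same result as unsqueezing each tensor and then concatenating, and that `cat` only needs the sizes to match outside the chosen dimension.

```python
import torch

A = torch.zeros(3, 4)
B = torch.ones(3, 4)

# cat joins along an existing dimension, so the number of dimensions is unchanged
cat_rows = torch.cat([A, B], dim=0)
cat_cols = torch.cat([A, B], dim=1)
# stack inserts a new dimension, so the result has one more dimension than the inputs
stacked = torch.stack([A, B], dim=0)

print(cat_rows.shape)  # torch.Size([6, 4])
print(cat_cols.shape)  # torch.Size([3, 8])
print(stacked.shape)   # torch.Size([2, 3, 4])

# stack is equivalent to giving each tensor a new axis and then concatenating
via_cat = torch.cat([A.unsqueeze(0), B.unsqueeze(0)], dim=0)
same_result = torch.equal(stacked, via_cat)
print(same_result)  # True

# cat accepts different sizes along the concatenation dimension
C = torch.zeros(5, 4)
mixed = torch.cat([A, C], dim=0)
print(mixed.shape)  # torch.Size([8, 4])

# stack requires every tensor to have exactly the same shape
try:
    torch.stack([A, C])
    stack_accepts_mixed = True
except RuntimeError:
    stack_accepts_mixed = False
print(stack_accepts_mixed)  # False
```

In short, `cat` grows an axis that is already there, while `stack` creates a new one whose length equals the number of tensors passed in.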

The original answer lacks a self-contained example, so here is one:

```python
# %%
import torch

# stack vs cat
# cat "extends" a tensor along an existing dimension, e.g. adds more rows or columns
x = torch.randn(2, 3)
print(x.size())  # torch.Size([2, 3])
# add more rows: dim 0 grows from 2 to 6
xnew_from_cat = torch.cat((x, x, x), 0)
print(xnew_from_cat.size())  # torch.Size([6, 3])
# add more columns: dim 1 grows from 3 to 9
xnew_from_cat = torch.cat((x, x, x), 1)
print(xnew_from_cat.size())  # torch.Size([2, 9])
print()
# stack plays the same role as append on a list, i.e. it leaves the original
# dimensions alone and adds a new index to the result, so you can still get
# back each original tensor by indexing along the new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(xnew_from_stack.size())  # torch.Size([4, 2, 3])
xnew_from_stack = torch.stack((x, x, x, x), 1)
print(xnew_from_stack.size())  # torch.Size([2, 4, 3])
xnew_from_stack = torch.stack((x, x, x, x), 2)
print(xnew_from_stack.size())  # torch.Size([2, 3, 4])
# by default (dim=0) the new dimension is added at the front
xnew_from_stack = torch.stack((x, x, x, x))
print(xnew_from_stack.size())  # torch.Size([4, 2, 3])
print('I like to think of xnew_from_stack as a "tensor list" that you can pop from the front')
```
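As for why the two snippets in the question differ, one plausible explanation is the shape of the list entries. The question does not show how the lists are built, so the sketch below assumes that one list holds tensors of shape `(1,)` and the other holds zero-dimensional (scalar) tensors. The loss values and variable names are made up for illustration. Under that assumption, `cat` works for the first list, but it rejects zero-dimensional tensors because they have no dimension to concatenate along, so `stack` is needed for the second.

```python
import torch

# Hypothetical per-step losses; the values are made up for illustration.
one_element_losses = [torch.tensor([0.5]), torch.tensor([1.5])]  # each has shape (1,)
scalar_losses = [torch.tensor(0.5), torch.tensor(1.5)]           # each has shape ()

# cat joins the (1,) tensors into a single tensor of shape (2,)
cat_total = torch.cat(one_element_losses).sum()
# stack creates the dimension that the scalar tensors lack, giving shape (2,)
stack_total = torch.stack(scalar_losses).sum()
print(cat_total.item(), stack_total.item())  # 2.0 2.0

# cat cannot concatenate zero-dimensional tensors
try:
    torch.cat(scalar_losses)
    cat_accepts_scalars = True
except RuntimeError:
    cat_accepts_scalars = False
print(cat_accepts_scalars)  # False
```

When the result is summed right away, as in both snippets, either function gives the same total whenever it accepts the inputs. The choice only matters for what shapes the list entries are allowed to have.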