Alex Shapiro

Squeeze & Unsqueeze

I see a lot of PyTorch code that calls unsqueeze. What does it do? Is there a squeeze? Yes.

Take a 10 x 1 x 10 tensor:

>>> import torch
>>> z = torch.zeros(10, 1, 10)
>>> z.shape
torch.Size([10, 1, 10])

Squeezing dimension 0 does nothing, because only a dimension with exactly 1 element can be squeezed. PyTorch returns the tensor with its shape unchanged instead of raising an error.

>>> z.squeeze(0).shape
torch.Size([10, 1, 10])

Dimension 1 has 1 element, so it can be squeezed:

>>> z.squeeze(1).shape
torch.Size([10, 10])

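Squeezing dimension N removes it from the shape without changing or copying the data. The result is a view that shares storage with the original. In nested-array terms, squeezing removes the 1-element array at depth N whose only element holds dimension N+1.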
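The following is a small sketch of that claim. It uses an illustrative 2 x 1 x 3 tensor, with the names `t` and `s` chosen only for this example, so the nesting stays readable. The sketch prints the nested lists before and after the squeeze. It then compares `data_ptr()` to show that the squeezed tensor points at the same memory.

```python
import torch

# A small 2 x 1 x 3 tensor, so the nested structure is easy to read.
t = torch.arange(6).reshape(2, 1, 3)
print(t.tolist())      # [[[0, 1, 2]], [[3, 4, 5]]]

# Squeezing dimension 1 removes the 1-element wrapper around each row.
s = t.squeeze(1)
print(s.shape)         # torch.Size([2, 3])
print(s.tolist())      # [[0, 1, 2], [3, 4, 5]]

# The squeezed tensor is a view: it starts at the same memory address as t.
print(s.data_ptr() == t.data_ptr())   # True
```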

Unsqueeze does the opposite: it inserts a new dimension of size 1 at position N, wrapping the old dimension N inside a 1-element array:

>>> z.unsqueeze(0).shape
torch.Size([1, 10, 1, 10])
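This sketch shows what the wrapping looks like at two different positions on an illustrative 2 x 3 tensor. The names `t`, `u` and `v` are only for this example. It also checks that squeezing the newly added dimension gives back the original tensor.

```python
import torch

t = torch.arange(6).reshape(2, 3)

# Position 0: the whole tensor is wrapped in one 1-element array.
u = t.unsqueeze(0)
print(u.shape)       # torch.Size([1, 2, 3])
print(u.tolist())    # [[[0, 1, 2], [3, 4, 5]]]

# Position 1: each row is wrapped in its own 1-element array.
v = t.unsqueeze(1)
print(v.shape)       # torch.Size([2, 1, 3])
print(v.tolist())    # [[[0, 1, 2]], [[3, 4, 5]]]

# Squeezing the dimension that was just added undoes the unsqueeze.
print(torch.equal(u.squeeze(0), t))   # True
```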

Unsqueeze can be called at any position, from 0 through z.dim() inclusive, because adding a 1-element wrapper is legal no matter how many elements the existing dimensions have.
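Here is a quick way to see every legal position for the 10 x 1 x 10 tensor `z` from above. The sketch assumes PyTorch's usual dimension wrapping, where negative positions count from the end. It also assumes that an out-of-range position raises an error. To stay safe across versions, the sketch catches both `IndexError` and `RuntimeError`.

```python
import torch

z = torch.zeros(10, 1, 10)

# For a 3-D tensor, valid positions run from 0 through z.dim() == 3.
shapes = {n: tuple(z.unsqueeze(n).shape) for n in range(z.dim() + 1)}
for n, shape in shapes.items():
    print(n, shape)
# 0 (1, 10, 1, 10)
# 1 (10, 1, 1, 10)
# 2 (10, 1, 1, 10)
# 3 (10, 1, 10, 1)

# Negative positions count from the end: -1 appends a trailing size-1 dimension.
print(tuple(z.unsqueeze(-1).shape))   # (10, 1, 10, 1)

# A position past z.dim() is out of range and is rejected.
try:
    z.unsqueeze(4)
except (IndexError, RuntimeError) as err:
    print("error:", err)
```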