UnetDecoder

class UnetDecoder

UNet decoder with skip connections.

This class implements the decoder portion of the UNet architecture. It reconstructs high-resolution feature maps from encoder outputs using multiple decoder blocks with skip connections.

center

Center block (currently Identity).

Type: nn.Module

blocks

List of decoder blocks for upsampling and feature fusion.

Type: nn.ModuleList

Example

>>> import torch
>>> decoder = UnetDecoder()
>>> # Generate dummy feature maps for testing
>>> features = [
...     torch.randn(1, 32, 112, 112),
...     torch.randn(1, 24, 56, 56),
...     torch.randn(1, 40, 28, 28),
...     torch.randn(1, 112, 14, 14),
...     torch.randn(1, 320, 7, 7)
... ]
>>> output = decoder(features)
>>> output.shape
torch.Size([1, 16, 224, 224])
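
The shapes in this example are consistent with a decoder that starts from the deepest feature map and doubles the spatial resolution in each block, concatenating the matching encoder feature map where one exists. The plain-Python sketch below traces only the shapes. The helper name `trace_decoder_shapes`, the five blocks, the 2x upsampling factor, and the per-block output channels `(256, 128, 64, 32, 16)` are assumptions chosen to reproduce the example's output; the source does not confirm them.

```python
def trace_decoder_shapes(feature_shapes, decoder_channels=(256, 128, 64, 32, 16)):
    """Trace shapes through a hypothetical UNet decoder.

    feature_shapes: encoder shapes as (C, H, W), ordered from high to low resolution.
    decoder_channels: assumed output channels of each decoder block.
    Returns one (in_channels, out_channels, height, width) tuple per block.
    """
    # The deepest feature map is the starting point; the rest act as skips.
    head, *skips = feature_shapes[::-1]
    channels, height, width = head
    trace = []
    for i, out_channels in enumerate(decoder_channels):
        # Assumption: every block upsamples by a factor of 2.
        height, width = height * 2, width * 2
        skip = skips[i] if i < len(skips) else None
        if skip is not None:
            # A skip can only be concatenated if its resolution matches.
            assert skip[1:] == (height, width), "skip resolution mismatch"
            in_channels = channels + skip[0]
        else:
            # The last block has no encoder feature left to fuse.
            in_channels = channels
        trace.append((in_channels, out_channels, height, width))
        channels = out_channels
    return trace


# Shapes (without the batch dimension) taken from the example above.
shapes = [(32, 112, 112), (24, 56, 56), (40, 28, 28), (112, 14, 14), (320, 7, 7)]
trace = trace_decoder_shapes(shapes)
for in_ch, out_ch, h, w in trace:
    print(f"in={in_ch:4d} out={out_ch:4d} size={h}x{w}")
```

Under these assumptions the last block emits 16 channels at 224 x 224, matching the `torch.Size([1, 16, 224, 224])` shown in the example.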

Initialize UnetDecoder.

Sets up the decoder blocks with skip connections for UNet architecture.

Methods

forward: Forward pass through UNet decoder.

Attributes

training

forward(features)

Forward pass through UNet decoder.

Reconstructs high-resolution feature maps from encoder outputs using skip connections and multiple decoder blocks.

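Parameters:

features: Encoder feature maps (a list of tensors), ordered from highest to lowest spatial resolution as in the class example above.

Returns: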

Decoded output tensor with spatial resolution restored.

Return type: torch.Tensor
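
As a rough illustration of the forward pass described above, here is a minimal PyTorch sketch. `SketchDecoderBlock` and `SketchUnetDecoder` are hypothetical names, and the nearest-neighbor upsampling, the single 3x3 convolution per block, and the channel widths are all assumptions; this is not the actual implementation behind `UnetDecoder`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDecoderBlock(nn.Module):
    """Hypothetical block: upsample 2x, concatenate the skip, then convolve."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            # Feature fusion: join decoder and encoder features along channels.
            x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class SketchUnetDecoder(nn.Module):
    """Hypothetical decoder; channel widths are assumed, not taken from the source."""

    def __init__(self, encoder_channels=(32, 24, 40, 112, 320),
                 decoder_channels=(256, 128, 64, 32, 16)):
        super().__init__()
        reversed_channels = encoder_channels[::-1]  # deepest first
        in_channels = [reversed_channels[0]] + list(decoder_channels[:-1])
        skip_channels = list(reversed_channels[1:]) + [0]  # last block has no skip
        self.center = nn.Identity()
        self.blocks = nn.ModuleList([
            SketchDecoderBlock(i, s, o)
            for i, s, o in zip(in_channels, skip_channels, decoder_channels)
        ])

    def forward(self, features):
        features = features[::-1]  # start from the lowest-resolution map
        x = self.center(features[0])
        skips = features[1:]
        for i, block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = block(x, skip)
        return x


# Same dummy feature maps as in the class example.
features = [
    torch.randn(1, 32, 112, 112),
    torch.randn(1, 24, 56, 56),
    torch.randn(1, 40, 28, 28),
    torch.randn(1, 112, 14, 14),
    torch.randn(1, 320, 7, 7),
]
output = SketchUnetDecoder()(features)
print(output.shape)
```

Each block doubles the spatial size, so five blocks take the 7 x 7 input to 224 x 224; the final block has no remaining encoder feature and runs without a skip connection.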