UnetDecoder
- class UnetDecoder
UNet decoder with skip connections.
This class implements the decoder portion of the UNet architecture. It reconstructs high-resolution feature maps from encoder outputs using multiple decoder blocks with skip connections.
- center
Center block (currently Identity).
- Type:
nn.Module
- blocks
List of decoder blocks for upsampling and feature fusion.
- Type:
nn.ModuleList
Example
>>> decoder = UnetDecoder()
>>> # Generate dummy feature maps for testing
>>> features = [
...     torch.randn(1, 32, 112, 112),
...     torch.randn(1, 24, 56, 56),
...     torch.randn(1, 40, 28, 28),
...     torch.randn(1, 112, 14, 14),
...     torch.randn(1, 320, 7, 7)
... ]
>>> output = decoder(features)
>>> output.shape
torch.Size([1, 16, 224, 224])
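The shapes in the example can be traced by hand. The sketch below is a pure-Python illustration, not the library's code: `trace_decoder_shapes` is a hypothetical helper, and it assumes that each decoder block upsamples by a factor of 2, concatenates the matching skip connection along the channel axis, and that the decoder widths are `(256, 128, 64, 32, 16)`. Only the final width of 16 is confirmed by the example output above; the other widths are placeholders.

```python
# Hypothetical shape trace for a UNet-style decoder (not the library's code).
# Assumptions: 2x upsampling per block, channel-wise concatenation of the skip,
# and decoder widths of (256, 128, 64, 32, 16).

def trace_decoder_shapes(encoder_shapes, decoder_channels=(256, 128, 64, 32, 16)):
    """Return (in_channels, out_channels, height, width) for each decoder block.

    encoder_shapes: (channels, height, width) tuples ordered shallow to deep.
    """
    # The deepest feature map is the starting point; the rest act as skips,
    # consumed from deepest to shallowest.
    *skips, head = encoder_shapes
    skips = skips[::-1]

    channels, height, width = head
    stages = []
    for i, out_channels in enumerate(decoder_channels):
        height, width = height * 2, width * 2  # assumed 2x upsampling
        skip_channels = 0
        if i < len(skips):
            skip_channels, skip_h, skip_w = skips[i]
            if (skip_h, skip_w) != (height, width):
                raise ValueError(f"skip {i} has size {(skip_h, skip_w)}, expected {(height, width)}")
        # Concatenation adds the skip's channels to the block input.
        stages.append((channels + skip_channels, out_channels, height, width))
        channels = out_channels
    return stages


# Feature shapes from the example, without the batch dimension.
encoder_shapes = [(32, 112, 112), (24, 56, 56), (40, 28, 28), (112, 14, 14), (320, 7, 7)]
stages = trace_decoder_shapes(encoder_shapes)
for stage in stages:
    print(stage)
```

Under these assumptions the last block receives no skip connection and ends at 16 channels and 224 × 224 pixels, which agrees with the `torch.Size([1, 16, 224, 224])` shown in the example.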
Initialize UnetDecoder.
Sets up the decoder blocks with skip connections for the UNet architecture.
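To make the setup concrete, here is a minimal PyTorch sketch of how such a decoder could be assembled. It is an illustration under stated assumptions, not the actual implementation: `SketchDecoderBlock` and `SketchUnetDecoder` are hypothetical names, and the encoder channels, decoder widths, nearest-neighbour upsampling, and conv/batch-norm/ReLU block layout are all assumed. Only the `center` (Identity) and `blocks` (`nn.ModuleList`) attributes and the shallow-to-deep input ordering come from the documentation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDecoderBlock(nn.Module):
    """Hypothetical block: upsample 2x, concatenate the skip, then convolve."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)  # fuse along the channel axis
        return self.conv(x)


class SketchUnetDecoder(nn.Module):
    """Hypothetical decoder; channel widths are assumptions for illustration."""

    def __init__(self, encoder_channels=(32, 24, 40, 112, 320),
                 decoder_channels=(256, 128, 64, 32, 16)):
        super().__init__()
        deep_first = list(encoder_channels[::-1])
        in_channels = [deep_first[0], *decoder_channels[:-1]]
        skip_channels = [*deep_first[1:], 0]  # the last block has no skip
        self.center = nn.Identity()
        self.blocks = nn.ModuleList([
            SketchDecoderBlock(i, s, o)
            for i, s, o in zip(in_channels, skip_channels, decoder_channels)
        ])

    def forward(self, features):
        features = features[::-1]  # deepest feature map first
        x = self.center(features[0])
        skips = features[1:]
        for i, block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = block(x, skip)
        return x


decoder = SketchUnetDecoder().eval()
features = [torch.randn(1, c, s, s)
            for c, s in zip((32, 24, 40, 112, 320), (112, 56, 28, 14, 7))]
with torch.no_grad():
    output = decoder(features)
print(output.shape)
```

With these assumed widths, the sketch accepts the same five feature maps as the documented example and returns a tensor of shape `(1, 16, 224, 224)`.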
Methods
forward(features) – Forward pass through UNet decoder.
Attributes
training
- forward(features)
Forward pass through UNet decoder.
Reconstructs high-resolution feature maps from encoder outputs using skip connections and multiple decoder blocks.
- Parameters:
features (list[torch.Tensor]) – List of feature maps from the encoder, ordered from shallow to deep.
- Returns:
Decoded output tensor with spatial resolution restored.
- Return type: