
Model documentation

models

Model architectures and components.

Classes

VQWithProjection

VQWithProjection(
    input_dim,
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Vector Quantization (VQ-VAE) with projections

Uses EMA for codebook updates (no gradients needed for the codebook). ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py (lines 56-79)
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection (e.g., 2048 -> 64)
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Vector Quantization
    self.vq = VectorQuantize(
        dim=bottleneck_dim,
        codebook_size=codebook_size,
        decay=decay,  # EMA decay for codebook
        commitment_weight=commitment_weight  # Commitment loss weight
    )

    # Up projection (64 -> 2048)
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
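
A minimal usage sketch (not part of the generated docs). It assumes a 2048-dimensional backbone feature map and that forward runs project_in -> vq -> project_out, as the layers above suggest; quantize_spatial is the inherited BaseQuantizer helper documented further down this page.

import torch
from embeddings_squeeze.models.quantizers import VQWithProjection

# Hypothetical feature map: batch of 2, 2048 channels, 8x8 spatial grid
features = torch.randn(2, 2048, 8, 8)

quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)

# quantize_spatial reshapes [B, C, H, W] -> [B, H*W, C], quantizes each
# spatial vector, and reshapes the result back to [B, C, H, W]
quantized, vq_loss = quantizer.quantize_spatial(features)
print(quantized.shape)  # torch.Size([2, 2048, 8, 8])
print(vq_loss)          # scalar quantization/commitment loss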

FSQWithProjection

FSQWithProjection(input_dim, levels=None)

Bases: BaseQuantizer

Finite Scalar Quantization (FSQ)

Quantization without a codebook - each dimension is quantized independently. ~10 bits per vector at levels=[8,5,5,5].

Source code in embeddings_squeeze\models\quantizers.py (lines 95-109)
def __init__(self, input_dim: int, levels: list = None):
    super().__init__(input_dim)
    if levels is None:
        levels = [8, 5, 5, 5]  # 8*5*5*5 = 1000 codes ≈ 2^10

    self.num_levels = len(levels)

    # Projection to quantization space
    self.project_in = nn.Linear(input_dim, self.num_levels)

    # FSQ quantization
    self.fsq = FSQ(levels=levels, dim=self.num_levels)

    # Projection back
    self.project_out = nn.Linear(self.num_levels, input_dim)
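
The "~10 bits per vector" figure follows directly from the product of the levels; a quick check with the defaults above:

import math

levels = [8, 5, 5, 5]                  # default FSQ levels
num_codes = math.prod(levels)          # 8 * 5 * 5 * 5 = 1000 distinct codes
bits_per_vector = math.log2(num_codes)
print(num_codes, round(bits_per_vector, 2))  # 1000 9.97  (~10 bits)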

LFQWithProjection

LFQWithProjection(
    input_dim,
    codebook_size=512,
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
)

Bases: BaseQuantizer

Lookup-Free Quantization (LFQ)

Uses an entropy loss to encourage code diversity. ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py (lines 128-156)
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512,
    entropy_loss_weight: float = 0.1, 
    diversity_gamma: float = 0.1, 
    spherical: bool = False
):
    super().__init__(input_dim)
    # Quantization dimension = log2(codebook_size)
    self.quant_dim = int(math.log2(codebook_size))

    # Projection with normalization
    self.project_in = nn.Sequential(
        nn.Linear(input_dim, self.quant_dim),
        nn.LayerNorm(self.quant_dim)
    )

    # LFQ quantization
    self.lfq = LFQ(
        dim=self.quant_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        spherical=spherical
    )

    # Projection back
    self.project_out = nn.Linear(self.quant_dim, input_dim)
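
The quantization dimension is tied directly to the codebook size: LFQ carries one binary value per dimension, so codebook_size = 2**quant_dim. A quick check with the default codebook_size=512 (matching the __init__ above):

import math

codebook_size = 512
quant_dim = int(math.log2(codebook_size))  # 9, as computed in __init__
print(quant_dim, 2 ** quant_dim)           # 9 512 -> 9 bits per vector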

ResidualVQWithProjection

ResidualVQWithProjection(
    input_dim,
    num_quantizers=4,
    codebook_size=256,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Residual Vector Quantization (RVQ)

Multi-level quantization - each level quantizes the residual of the previous one. 32 bits per vector at num_quantizers=4, codebook_size=256 (4 x 8 bits).

Source code in embeddings_squeeze\models\quantizers.py (lines 172-197)
def __init__(
    self, 
    input_dim: int, 
    num_quantizers: int = 4,
    codebook_size: int = 256, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Residual VQ
    self.residual_vq = ResidualVQ(
        dim=bottleneck_dim,
        num_quantizers=num_quantizers,  # Number of levels
        codebook_size=codebook_size,
        decay=decay,
        commitment_weight=commitment_weight
    )

    # Up projection
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
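
The 32-bit figure is simply num_quantizers * log2(codebook_size): each of the 4 levels emits one 8-bit index. A quick check with the defaults above:

import math

num_quantizers = 4
codebook_size = 256
bits_per_level = math.log2(codebook_size)      # 8 bits per index
total_bits = num_quantizers * bits_per_level   # 4 * 8 = 32 bits per vector
print(bits_per_level, total_bits)              # 8.0 32.0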

BaseQuantizer

BaseQuantizer(input_dim)

Bases: Module

Base class for all quantizers

Source code in embeddings_squeeze\models\quantizers.py (lines 14-16)
def __init__(self, input_dim: int):
    super().__init__()
    self.input_dim = input_dim
Functions
quantize_spatial
quantize_spatial(features)

Quantize spatial features [B, C, H, W]

Parameters:

    features (Tensor): Tensor of shape [B, C, H, W]. Required.

Returns:

    quantized: Quantized features [B, C, H, W]
    loss: Quantization loss (scalar)

Source code in embeddings_squeeze\models\quantizers.py (lines 21-46)
def quantize_spatial(self, features: torch.Tensor):
    """
    Quantize spatial features [B, C, H, W]

    Args:
        features: Tensor of shape [B, C, H, W]

    Returns:
        quantized: Quantized features [B, C, H, W]
        loss: Quantization loss (scalar)
    """
    B, C, H, W = features.shape
    # Transform [B, C, H, W] -> [B, H*W, C]
    seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Quantize
    quantized, indices, loss = self.forward(seq)

    # Transform back [B, H*W, C] -> [B, C, H, W]
    quantized = quantized.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Handle loss (may be tensor with multiple elements)
    if isinstance(loss, torch.Tensor) and loss.numel() > 1:
        loss = loss.mean()

    return quantized, loss

DiceLoss

DiceLoss(smooth=1.0)

Bases: Module

Dice Loss for multi-class segmentation

Source code in embeddings_squeeze\models\losses.py (lines 13-15)
def __init__(self, smooth: float = 1.0):
    super().__init__()
    self.smooth = smooth

FocalLoss

FocalLoss(alpha=1.0, gamma=2.0, reduction='mean')

Bases: Module

Focal Loss for handling class imbalance (multi-class via CE per-pixel)

Source code in embeddings_squeeze\models\losses.py (lines 40-44)
def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction

CombinedLoss

CombinedLoss(
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.5,
    class_weights=None,
)

Bases: Module

Combined loss: CE + Dice + Focal. Returns (total, ce, dice, focal).

Source code in embeddings_squeeze\models\losses.py (lines 62-83)
def __init__(
    self, 
    ce_weight: float = 1.0, 
    dice_weight: float = 1.0, 
    focal_weight: float = 0.5, 
    class_weights=None
):
    super().__init__()
    # class_weights can be None or a tensor/list
    if class_weights is not None:
        # Leave tensor creation to forward (to place on correct device) but store raw
        self._class_weights = class_weights
    else:
        self._class_weights = None

    self.ce_weight = ce_weight
    self.dice_weight = dice_weight
    self.focal_weight = focal_weight

    # Instantiate component losses
    self.dice_loss = DiceLoss()
    self.focal_loss = FocalLoss()
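
A sketch of how the weighted total is presumably assembled (the forward body is not shown here, so the exact composition is an assumption based on the stored weights and the documented (total, ce, dice, focal) return):

# Assumed composition; the loss values are placeholders for illustration only
ce_loss, dice_loss, focal_loss = 1.0, 0.4, 0.2

ce_weight, dice_weight, focal_weight = 1.0, 1.0, 0.5
total = ce_weight * ce_loss + dice_weight * dice_loss + focal_weight * focal_loss
print(total)  # 1.5 -> returned as (total, ce_loss, dice_loss, focal_loss)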

SegmentationBackbone

SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py (lines 17-18)
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py (lines 20-32)
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py (lines 34-45)
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass

VQSqueezeModule

VQSqueezeModule(
    backbone,
    quantizer=None,
    num_classes=21,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    loss_type="ce",
    class_weights=None,
    add_adapter=False,
    feature_dim=2048,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for VQ compression training.

Features:

- Multiple quantizer support (VQ, FSQ, LFQ, RVQ)
- Adapter layers for fine-tuning frozen backbones
- Advanced loss functions (CE, Dice, Focal, Combined)
- Embedding extraction and saving

Source code in embeddings_squeeze\models\lightning_module.py (lines 31-96)
def __init__(
    self,
    backbone: SegmentationBackbone,
    quantizer: Optional[nn.Module] = None,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    vq_loss_weight: float = 0.1,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    add_adapter: bool = False,
    feature_dim: int = 2048,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'quantizer', 'clearml_logger'])

    self.num_classes = num_classes
    self.learning_rate = learning_rate
    self.vq_loss_weight = vq_loss_weight
    self.loss_type = loss_type
    self.add_adapter = add_adapter
    self.feature_dim = feature_dim

    # Setup backbone with optional adapters
    self.backbone = backbone
    self._setup_backbone_with_adapters(feature_dim, add_adapter)

    # Quantizer (optional)
    self.quantizer = quantizer

    # Loss function
    self.criterion = self._init_loss(loss_type, class_weights)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

    # Embedding storage (per-epoch, first batch only)
    self.embedding_dir = "embeddings"
    os.makedirs(self.embedding_dir, exist_ok=True)
    self._first_val_batch_features = None

    # UMAP visualization storage
    self._val_backbone_embeddings = []
    self._val_quantized_embeddings = []
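
A minimal training-setup sketch (not from the package docs): it wires together the backbone, quantizer, and Lightning module documented on this page. The Trainer settings are illustrative, and the dataloaders are placeholders that are not defined here.

import pytorch_lightning as pl

from embeddings_squeeze.models.backbones.deeplab import DeepLabV3SegmentationBackbone
from embeddings_squeeze.models.quantizers import VQWithProjection
from embeddings_squeeze.models.lightning_module import VQSqueezeModule

# Frozen DeepLabV3 backbone (2048-dim features) + VQ bottleneck
backbone = DeepLabV3SegmentationBackbone(num_classes=21, freeze_backbone=True)
quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)

module = VQSqueezeModule(
    backbone=backbone,
    quantizer=quantizer,
    num_classes=21,
    learning_rate=1e-4,
    vq_loss_weight=0.1,
    add_adapter=True,      # train a small adapter on top of the frozen backbone
    feature_dim=2048,
)

trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
# train_loader / val_loader: (images, masks) DataLoaders -- placeholders, not defined here
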
Functions
forward
forward(images)

Forward pass through backbone + optional quantizer + decoder.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]
    quant_loss: Quantization loss (0 if no quantizer)
    original_features: Extracted features (before quantization)
    quantized_features: Features after quantization (same as original if no quantizer)

Source code in embeddings_squeeze\models\lightning_module.py (lines 137-171)
def forward(self, images):
    """
    Forward pass through backbone + optional quantizer + decoder.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
        quant_loss: Quantization loss (0 if no quantizer)
        original_features: Extracted features (before quantization)
        quantized_features: Features after quantization (same as original if no quantizer)
    """
    # Extract features
    features = self.backbone.extract_features(images, detach=self.feature_adapter is not None)

    # Apply adapter if present
    if self.feature_adapter is not None:
        features = features + self.feature_adapter(features)

    # Store original features for embedding extraction
    original_features = features

    # Quantize if quantizer is present
    quant_loss = torch.tensor(0.0, device=images.device)
    quantized_features = original_features  # Default to original if no quantizer
    if self.quantizer is not None:
        features, quant_loss = self.quantizer.quantize_spatial(features)
        quantized_features = features

    # Decode to segmentation logits
    output = self.backbone.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear', align_corners=False)

    return output, quant_loss, original_features, quantized_features
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\lightning_module.py (lines 196-228)
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, _, _ = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\lightning_module.py (lines 230-268)
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, backbone_features, quantized_features = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    # Accumulate embeddings for UMAP visualization
    self._val_backbone_embeddings.append(backbone_features.detach().cpu())
    self._val_quantized_embeddings.append(quantized_features.detach().cpu())

    # Save only first batch features for this epoch
    if batch_idx == 0:
        self._first_val_batch_features = backbone_features.detach().cpu()

    return loss
on_validation_epoch_start
on_validation_epoch_start()

Clear accumulated embeddings at the start of each validation epoch.

Source code in embeddings_squeeze\models\lightning_module.py (lines 270-273)
def on_validation_epoch_start(self):
    """Clear accumulated embeddings at the start of each validation epoch."""
    self._val_backbone_embeddings.clear()
    self._val_quantized_embeddings.clear()
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations and save embeddings.

Source code in embeddings_squeeze\models\lightning_module.py (lines 275-468)
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations and save embeddings."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    keys = [
        "train/loss", "val/loss", "train/iou", "val/iou",
        "train/precision", "val/precision", "train/recall", "val/recall",
        "train/f1", "val/f1"
    ]
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )

    # Generate UMAP visualizations on even epochs
    if self.current_epoch % 2 == 0:
        try:
            import umap.umap_ as umap_module

            # Only proceed if we have embeddings
            if len(self._val_backbone_embeddings) > 0 and len(self._val_quantized_embeddings) > 0:
                # Concatenate all accumulated embeddings
                backbone_emb_flat = torch.cat(self._val_backbone_embeddings, dim=0)
                quantized_emb_flat = torch.cat(self._val_quantized_embeddings, dim=0)

                # Flatten spatial dimensions: [B, C, H, W] -> [B*H*W, C]
                backbone_emb_flat = backbone_emb_flat.permute(0, 2, 3, 1).reshape(-1, backbone_emb_flat.shape[1])
                quantized_emb_flat = quantized_emb_flat.permute(0, 2, 3, 1).reshape(-1, quantized_emb_flat.shape[1])

                # Convert to numpy
                backbone_emb_np = backbone_emb_flat.numpy()
                quantized_emb_np = quantized_emb_flat.numpy()

                # Limit samples for performance (take subset if too large)
                max_samples = 10000
                if len(backbone_emb_np) > max_samples:
                    indices = np.random.choice(len(backbone_emb_np), max_samples, replace=False)
                    backbone_emb_np = backbone_emb_np[indices]
                    quantized_emb_np = quantized_emb_np[indices]

                # Generate 2D UMAP
                fig_2d, axs_2d = plt.subplots(1, 2, figsize=(12, 6))

                proj_2d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(backbone_emb_np)
                axs_2d[0].scatter(proj_2d_backbone[:, 0], proj_2d_backbone[:, 1], alpha=0.3)
                axs_2d[0].set_title('2D UMAP: Backbone Embeddings')

                proj_2d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(quantized_emb_np)
                axs_2d[1].scatter(proj_2d_quantized[:, 0], proj_2d_quantized[:, 1], alpha=0.3)
                axs_2d[1].set_title('2D UMAP: Quantized Embeddings')

                # Convert 2D plot to image and log
                fig_2d.canvas.draw()
                img_2d = np.frombuffer(fig_2d.canvas.tostring_rgb(), dtype=np.uint8)
                img_2d = img_2d.reshape(fig_2d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_2d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"2d_embeddings_epoch_{self.current_epoch}", 
                        img_2d, 
                        iteration=self.current_epoch
                    )

                # Generate 3D UMAP
                fig_3d = plt.figure(figsize=(12, 6))
                ax1 = fig_3d.add_subplot(121, projection='3d')
                ax2 = fig_3d.add_subplot(122, projection='3d')

                proj_3d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(backbone_emb_np)
                ax1.scatter(proj_3d_backbone[:, 0], proj_3d_backbone[:, 1], proj_3d_backbone[:, 2], alpha=0.3)
                ax1.set_title('3D UMAP: Backbone Embeddings')

                proj_3d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(quantized_emb_np)
                ax2.scatter(proj_3d_quantized[:, 0], proj_3d_quantized[:, 1], proj_3d_quantized[:, 2], alpha=0.3)
                ax2.set_title('3D UMAP: Quantized Embeddings')

                # Convert 3D plot to image and log
                fig_3d.canvas.draw()
                img_3d = np.frombuffer(fig_3d.canvas.tostring_rgb(), dtype=np.uint8)
                img_3d = img_3d.reshape(fig_3d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_3d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"3d_embeddings_epoch_{self.current_epoch}", 
                        img_3d, 
                        iteration=self.current_epoch
                    )

            # Clear accumulated embeddings after logging
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

        except Exception as e:
            if self.clearml_logger:
                self.clearml_logger.report_text(
                    f"UMAP visualization failed at epoch {self.current_epoch}: {e}"
                )
            # Clear embeddings even if visualization failed
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

    # Save per-epoch embedding (first validation batch only)
    try:
        if self._first_val_batch_features is not None:
            emb_path = os.path.join(
                self.embedding_dir,
                f"val_embedding_epoch{self.current_epoch}.pt"
            )
            torch.save(self._first_val_batch_features, emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Saved small embedding: {emb_path}")
            # Reset for next epoch
            self._first_val_batch_features = None
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed saving epoch embedding: {e}")
configure_optimizers
configure_optimizers()

Configure optimizer - only trainable parameters.

Source code in embeddings_squeeze\models\lightning_module.py (lines 470-492)
def configure_optimizers(self):
    """Configure optimizer - only trainable parameters."""
    params = []

    # Add adapter parameters if present
    if self.feature_adapter is not None:
        params += list(self.feature_adapter.parameters())

    # Add quantizer parameters if present
    if self.quantizer is not None:
        params += list(self.quantizer.parameters())

    # Add backbone parameters if not frozen
    if self.feature_adapter is None:
        params += [p for p in self.backbone.parameters() if p.requires_grad]

    # Remove duplicates
    params = list({id(p): p for p in params}.values())

    if not params:
        raise ValueError("No trainable parameters found!")

    return torch.optim.AdamW(params, lr=self.learning_rate)
on_train_start
on_train_start()

Ensure frozen backbone stays in eval mode.

Source code in embeddings_squeeze\models\lightning_module.py (lines 494-497)
def on_train_start(self):
    """Ensure frozen backbone stays in eval mode."""
    if self.feature_adapter is not None:
        self.backbone.eval()

BaselineSegmentationModule

BaselineSegmentationModule(
    backbone,
    num_classes=21,
    learning_rate=0.0001,
    loss_type="ce",
    class_weights=None,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for baseline segmentation training.

Wraps segmentation backbone without Vector Quantization for comparison.

Source code in embeddings_squeeze\models\baseline_module.py (lines 23-69)
def __init__(
    self,
    backbone: SegmentationBackbone,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'clearml_logger'])

    self.backbone = backbone
    self.num_classes = num_classes
    self.learning_rate = learning_rate

    # Segmentation loss
    if class_weights is not None:
        weight = torch.tensor(class_weights, dtype=torch.float32)
        self.seg_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    else:
        self.seg_criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger
Functions
forward
forward(images)

Forward pass through backbone.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 71-85)
def forward(self, images):
    """
    Forward pass through backbone.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    output = self.backbone(images)
    # Handle both dict and tensor returns
    if isinstance(output, dict):
        return output['out']
    return output
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\baseline_module.py (lines 87-115)
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', seg_loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\baseline_module.py (lines 117-143)
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations.

Source code in embeddings_squeeze\models\baseline_module.py (lines 145-226)
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )
configure_optimizers
configure_optimizers()

Configure optimizer.

Source code in embeddings_squeeze\models\baseline_module.py (lines 228-232)
def configure_optimizers(self):
    """Configure optimizer."""
    # Optimize only trainable params
    params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=self.learning_rate)
predict
predict(images)

Predict segmentation masks.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    predictions: Segmentation predictions [B, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 234-248)
def predict(self, images):
    """
    Predict segmentation masks.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        predictions: Segmentation predictions [B, H, W]
    """
    self.eval()
    with torch.no_grad():
        output = self(images)
        predictions = output.argmax(dim=1)
    return predictions
predict_logits
predict_logits(images)

Predict segmentation logits.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    logits: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 250-263)
def predict_logits(self, images):
    """
    Predict segmentation logits.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        logits: Segmentation logits [B, num_classes, H, W]
    """
    self.eval()
    with torch.no_grad():
        logits = self(images)
    return logits
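
A short inference sketch (not from the package docs), assuming the DeepLab backbone documented below and an untrained module built in place of a trained checkpoint; the input size is an illustrative assumption.

import torch
from embeddings_squeeze.models.backbones.deeplab import DeepLabV3SegmentationBackbone
from embeddings_squeeze.models.baseline_module import BaselineSegmentationModule

backbone = DeepLabV3SegmentationBackbone(num_classes=21)
module = BaselineSegmentationModule(backbone=backbone, num_classes=21)

images = torch.randn(2, 3, 224, 224)     # hypothetical batch
preds = module.predict(images)           # hard class indices, shape [2, 224, 224]
logits = module.predict_logits(images)   # raw logits, shape [2, 21, 224, 224]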

Modules

backbones

Segmentation backbone implementations.

Classes
SegmentationBackbone
SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py (lines 17-18)
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py (lines 20-32)
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py (lines 34-45)
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
ViTSegmentationBackbone
ViTSegmentationBackbone(
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

ViT-based segmentation backbone.

Uses ViT-B/32 as backbone with custom segmentation head.

Source code in embeddings_squeeze\models\backbones\vit.py (lines 66-78)
def __init__(
    self,
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()
    base_vit = model_fn(weights=weights)
    self.backbone = _ViTBackboneWrapper(base_vit, freeze=freeze_backbone)
    self.classifier = _ViTSegmentationHead(self.backbone.hidden_dim, num_classes)

    self._num_classes = num_classes
Attributes
feature_dim property
feature_dim

Return ViT hidden dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract ViT backbone feature maps.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, hidden_dim, H/patch, W/patch]

Source code in embeddings_squeeze\models\backbones\vit.py (lines 80-92)
def extract_features(self, images, detach=True):
    """
    Extract ViT backbone feature maps.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, hidden_dim, H/patch, W/patch]
    """
    feats = self.backbone(images)['out']
    return feats.detach() if detach else feats
forward
forward(images)

Full ViT segmentation forward pass.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\vit.py (lines 94-107)
def forward(self, images):
    """
    Full ViT segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    logits = self.classifier(features)
    logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
    return {'out': logits}
DeepLabV3SegmentationBackbone
DeepLabV3SegmentationBackbone(
    weights_name="COCO_WITH_VOC_LABELS_V1",
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

DeepLabV3-ResNet50 segmentation backbone.

Uses pre-trained DeepLabV3-ResNet50 for segmentation.

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 20-39)
def __init__(
    self,
    weights_name='COCO_WITH_VOC_LABELS_V1',
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()

    weights = getattr(DeepLabV3_ResNet50_Weights, weights_name)
    model = deeplabv3_resnet50(weights=weights)

    self.backbone = model.backbone
    self.classifier = model.classifier

    self._num_classes = num_classes

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
Attributes
feature_dim property
feature_dim

Return DeepLabV3 feature dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract DeepLabV3 backbone features.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, 2048, H/8, W/8]

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 41-54)
def extract_features(self, images, detach=True):
    """
    Extract DeepLabV3 backbone features.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, 2048, H/8, W/8]
    """
    with torch.set_grad_enabled(not detach):
        features = self.backbone(images)['out']
    return features
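
A quick shape check for the documented [B, 2048, H/8, W/8] feature map; the 224x224 input size is an illustrative assumption.

import torch
from embeddings_squeeze.models.backbones.deeplab import DeepLabV3SegmentationBackbone

backbone = DeepLabV3SegmentationBackbone(num_classes=21, freeze_backbone=True)
images = torch.randn(1, 3, 224, 224)
features = backbone.extract_features(images, detach=True)
print(features.shape)  # torch.Size([1, 2048, 28, 28]) -> 224 / 8 = 28
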
forward
forward(images)

Full DeepLabV3 segmentation forward pass.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 56-69)
def forward(self, images):
    """
    Full DeepLabV3 segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    output = self.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear')
    return {'out': output}
Modules
base

Abstract base class for segmentation backbones.

Classes
SegmentationBackbone
SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py (lines 17-18)
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py (lines 20-32)
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py (lines 34-45)
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
deeplab

DeepLabV3-ResNet50 segmentation backbone implementation.

Classes
DeepLabV3SegmentationBackbone
DeepLabV3SegmentationBackbone(
    weights_name="COCO_WITH_VOC_LABELS_V1",
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

DeepLabV3-ResNet50 segmentation backbone.

Uses pre-trained DeepLabV3-ResNet50 for segmentation.

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 20-39)
def __init__(
    self,
    weights_name='COCO_WITH_VOC_LABELS_V1',
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()

    weights = getattr(DeepLabV3_ResNet50_Weights, weights_name)
    model = deeplabv3_resnet50(weights=weights)

    self.backbone = model.backbone
    self.classifier = model.classifier

    self._num_classes = num_classes

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
Attributes
feature_dim property
feature_dim

Return DeepLabV3 feature dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract DeepLabV3 backbone features.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, 2048, H/8, W/8]

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 41-54)
def extract_features(self, images, detach=True):
    """
    Extract DeepLabV3 backbone features.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, 2048, H/8, W/8]
    """
    with torch.set_grad_enabled(not detach):
        features = self.backbone(images)['out']
    return features
forward
forward(images)

Full DeepLabV3 segmentation forward pass.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\deeplab.py (lines 56-69)
def forward(self, images):
    """
    Full DeepLabV3 segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    output = self.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear')
    return {'out': output}
vit

ViT-based segmentation backbone implementation.

Classes
ViTSegmentationBackbone
ViTSegmentationBackbone(
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

ViT-based segmentation backbone.

Uses ViT-B/32 as backbone with custom segmentation head.

Source code in embeddings_squeeze\models\backbones\vit.py (lines 66-78)
def __init__(
    self,
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()
    base_vit = model_fn(weights=weights)
    self.backbone = _ViTBackboneWrapper(base_vit, freeze=freeze_backbone)
    self.classifier = _ViTSegmentationHead(self.backbone.hidden_dim, num_classes)

    self._num_classes = num_classes
Attributes
feature_dim property
feature_dim

Return ViT hidden dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract ViT backbone feature maps.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, hidden_dim, H/patch, W/patch]

Source code in embeddings_squeeze\models\backbones\vit.py (lines 80-92)
def extract_features(self, images, detach=True):
    """
    Extract ViT backbone feature maps.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, hidden_dim, H/patch, W/patch]
    """
    feats = self.backbone(images)['out']
    return feats.detach() if detach else feats
forward
forward(images)

Full ViT segmentation forward pass.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\vit.py (lines 94-107)
def forward(self, images):
    """
    Full ViT segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    logits = self.classifier(features)
    logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
    return {'out': logits}

baseline_module

PyTorch Lightning module for baseline segmentation training without VQ.

Classes
BaselineSegmentationModule
BaselineSegmentationModule(
    backbone,
    num_classes=21,
    learning_rate=0.0001,
    loss_type="ce",
    class_weights=None,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for baseline segmentation training.

Wraps segmentation backbone without Vector Quantization for comparison.

Source code in embeddings_squeeze\models\baseline_module.py (lines 23-69)
def __init__(
    self,
    backbone: SegmentationBackbone,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'clearml_logger'])

    self.backbone = backbone
    self.num_classes = num_classes
    self.learning_rate = learning_rate

    # Segmentation loss
    if class_weights is not None:
        weight = torch.tensor(class_weights, dtype=torch.float32)
        self.seg_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    else:
        self.seg_criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger
Functions
forward
forward(images)

Forward pass through backbone.

Parameters:

    images: Input images [B, C, H, W] (required)

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 71-85)
def forward(self, images):
    """
    Forward pass through backbone.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    output = self.backbone(images)
    # Handle both dict and tensor returns
    if isinstance(output, dict):
        return output['out']
    return output
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\baseline_module.py (lines 87-115)
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', seg_loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\baseline_module.py (lines 117-143)
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations.

Source code in embeddings_squeeze\models\baseline_module.py (lines 145-226)
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )
configure_optimizers
configure_optimizers()

Configure optimizer.

Source code in embeddings_squeeze\models\baseline_module.py (lines 228-232)
def configure_optimizers(self):
    """Configure optimizer."""
    # Optimize only trainable params
    params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=self.learning_rate)
predict
predict(images)

Predict segmentation masks.

Parameters:

    images: Input images [B, C, H, W] (required)

Returns:

    predictions: Segmentation predictions [B, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 234-248)
def predict(self, images):
    """
    Predict segmentation masks.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        predictions: Segmentation predictions [B, H, W]
    """
    self.eval()
    with torch.no_grad():
        output = self(images)
        predictions = output.argmax(dim=1)
    return predictions
predict_logits
predict_logits(images)

Predict segmentation logits.

Parameters:

    images: Input images [B, C, H, W] (required)

Returns:

    logits: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py (lines 250-263)
def predict_logits(self, images):
    """
    Predict segmentation logits.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        logits: Segmentation logits [B, num_classes, H, W]
    """
    self.eval()
    with torch.no_grad():
        logits = self(images)
    return logits
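
A hedged end-to-end usage sketch of this module on random data. The ToyBackbone, tensor shapes, Trainer flags, and the import path embeddings_squeeze.models.baseline_module are assumptions for illustration; any backbone that returns {'out': logits} (or a plain logits tensor) works with forward() above.

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from embeddings_squeeze.models.baseline_module import BaselineSegmentationModule  # assumed path

class ToyBackbone(nn.Module):
    """Stand-in backbone: returns {'out': logits} like torchvision segmentation models."""
    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        return {'out': self.head(x)}

images = torch.randn(8, 3, 64, 64)
masks = torch.randint(0, 21, (8, 1, 64, 64))
loader = DataLoader(TensorDataset(images, masks), batch_size=4)

module = BaselineSegmentationModule(backbone=ToyBackbone(), num_classes=21, learning_rate=1e-4)
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False, enable_checkpointing=False)
trainer.fit(module, train_dataloaders=loader, val_dataloaders=loader)

preds = module.predict(images[:2])  # [2, 64, 64] class indices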

lightning_module

PyTorch Lightning module for VQ compression training with advanced features. Supports multiple quantizers (VQ, FSQ, LFQ, RVQ), adapters, and loss functions.

Classes
VQSqueezeModule
VQSqueezeModule(
    backbone,
    quantizer=None,
    num_classes=21,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    loss_type="ce",
    class_weights=None,
    add_adapter=False,
    feature_dim=2048,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for VQ compression training.

Features:

- Multiple quantizer support (VQ, FSQ, LFQ, RVQ)
- Adapter layers for fine-tuning frozen backbones
- Advanced loss functions (CE, Dice, Focal, Combined)
- Embedding extraction and saving

Source code in embeddings_squeeze\models\lightning_module.py (lines 31-96)
def __init__(
    self,
    backbone: SegmentationBackbone,
    quantizer: Optional[nn.Module] = None,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    vq_loss_weight: float = 0.1,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    add_adapter: bool = False,
    feature_dim: int = 2048,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'quantizer', 'clearml_logger'])

    self.num_classes = num_classes
    self.learning_rate = learning_rate
    self.vq_loss_weight = vq_loss_weight
    self.loss_type = loss_type
    self.add_adapter = add_adapter
    self.feature_dim = feature_dim

    # Setup backbone with optional adapters
    self.backbone = backbone
    self._setup_backbone_with_adapters(feature_dim, add_adapter)

    # Quantizer (optional)
    self.quantizer = quantizer

    # Loss function
    self.criterion = self._init_loss(loss_type, class_weights)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

    # Embedding storage (per-epoch, first batch only)
    self.embedding_dir = "embeddings"
    os.makedirs(self.embedding_dir, exist_ok=True)
    self._first_val_batch_features = None

    # UMAP visualization storage
    self._val_backbone_embeddings = []
    self._val_quantized_embeddings = []
Functions
forward
forward(images)

Forward pass through backbone + optional quantizer + decoder.

Parameters:

    images: Input images [B, C, H, W] (required)

Returns:

    output: Segmentation logits [B, num_classes, H, W]
    quant_loss: Quantization loss (0 if no quantizer)
    original_features: Extracted features (before quantization)
    quantized_features: Features after quantization (same as original if no quantizer)

Source code in embeddings_squeeze\models\lightning_module.py (lines 137-171)
def forward(self, images):
    """
    Forward pass through backbone + optional quantizer + decoder.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
        quant_loss: Quantization loss (0 if no quantizer)
        original_features: Extracted features (before quantization)
        quantized_features: Features after quantization (same as original if no quantizer)
    """
    # Extract features
    features = self.backbone.extract_features(images, detach=self.feature_adapter is not None)

    # Apply adapter if present
    if self.feature_adapter is not None:
        features = features + self.feature_adapter(features)

    # Store original features for embedding extraction
    original_features = features

    # Quantize if quantizer is present
    quant_loss = torch.tensor(0.0, device=images.device)
    quantized_features = original_features  # Default to original if no quantizer
    if self.quantizer is not None:
        features, quant_loss = self.quantizer.quantize_spatial(features)
        quantized_features = features

    # Decode to segmentation logits
    output = self.backbone.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear', align_corners=False)

    return output, quant_loss, original_features, quantized_features
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\lightning_module.py (lines 196-228)
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, _, _ = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return loss
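
_compute_loss is not included in this excerpt; the sketch below shows the combination it most plausibly performs given the vq_loss_weight argument in __init__: segmentation loss plus a weighted quantization loss. Treat it as an assumption, not the actual method body.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=255)
vq_loss_weight = 0.1

output = torch.randn(2, 21, 64, 64)        # segmentation logits from forward()
masks = torch.randint(0, 21, (2, 64, 64))  # ground-truth class indices
quant_loss = torch.tensor(0.05)            # quantization loss from forward()

loss = criterion(output, masks) + vq_loss_weight * quant_loss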
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\lightning_module.py (lines 230-268)
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, backbone_features, quantized_features = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    # Accumulate embeddings for UMAP visualization
    self._val_backbone_embeddings.append(backbone_features.detach().cpu())
    self._val_quantized_embeddings.append(quantized_features.detach().cpu())

    # Save only first batch features for this epoch
    if batch_idx == 0:
        self._first_val_batch_features = backbone_features.detach().cpu()

    return loss
on_validation_epoch_start
on_validation_epoch_start()

Clear accumulated embeddings at the start of each validation epoch.

Source code in embeddings_squeeze\models\lightning_module.py (lines 270-273)
def on_validation_epoch_start(self):
    """Clear accumulated embeddings at the start of each validation epoch."""
    self._val_backbone_embeddings.clear()
    self._val_quantized_embeddings.clear()
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations and save embeddings.

Source code in embeddings_squeeze\models\lightning_module.py (lines 275-468)
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations and save embeddings."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    keys = [
        "train/loss", "val/loss", "train/iou", "val/iou",
        "train/precision", "val/precision", "train/recall", "val/recall",
        "train/f1", "val/f1"
    ]
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )

    # Generate UMAP visualizations on even epochs
    if self.current_epoch % 2 == 0:
        try:
            import umap.umap_ as umap_module

            # Only proceed if we have embeddings
            if len(self._val_backbone_embeddings) > 0 and len(self._val_quantized_embeddings) > 0:
                # Concatenate all accumulated embeddings
                backbone_emb_flat = torch.cat(self._val_backbone_embeddings, dim=0)
                quantized_emb_flat = torch.cat(self._val_quantized_embeddings, dim=0)

                # Flatten spatial dimensions: [B, C, H, W] -> [B*H*W, C]
                backbone_emb_flat = backbone_emb_flat.permute(0, 2, 3, 1).reshape(-1, backbone_emb_flat.shape[1])
                quantized_emb_flat = quantized_emb_flat.permute(0, 2, 3, 1).reshape(-1, quantized_emb_flat.shape[1])

                # Convert to numpy
                backbone_emb_np = backbone_emb_flat.numpy()
                quantized_emb_np = quantized_emb_flat.numpy()

                # Limit samples for performance (take subset if too large)
                max_samples = 10000
                if len(backbone_emb_np) > max_samples:
                    indices = np.random.choice(len(backbone_emb_np), max_samples, replace=False)
                    backbone_emb_np = backbone_emb_np[indices]
                    quantized_emb_np = quantized_emb_np[indices]

                # Generate 2D UMAP
                fig_2d, axs_2d = plt.subplots(1, 2, figsize=(12, 6))

                proj_2d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(backbone_emb_np)
                axs_2d[0].scatter(proj_2d_backbone[:, 0], proj_2d_backbone[:, 1], alpha=0.3)
                axs_2d[0].set_title('2D UMAP: Backbone Embeddings')

                proj_2d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(quantized_emb_np)
                axs_2d[1].scatter(proj_2d_quantized[:, 0], proj_2d_quantized[:, 1], alpha=0.3)
                axs_2d[1].set_title('2D UMAP: Quantized Embeddings')

                # Convert 2D plot to image and log
                fig_2d.canvas.draw()
                img_2d = np.frombuffer(fig_2d.canvas.tostring_rgb(), dtype=np.uint8)
                img_2d = img_2d.reshape(fig_2d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_2d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"2d_embeddings_epoch_{self.current_epoch}", 
                        img_2d, 
                        iteration=self.current_epoch
                    )

                # Generate 3D UMAP
                fig_3d = plt.figure(figsize=(12, 6))
                ax1 = fig_3d.add_subplot(121, projection='3d')
                ax2 = fig_3d.add_subplot(122, projection='3d')

                proj_3d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(backbone_emb_np)
                ax1.scatter(proj_3d_backbone[:, 0], proj_3d_backbone[:, 1], proj_3d_backbone[:, 2], alpha=0.3)
                ax1.set_title('3D UMAP: Backbone Embeddings')

                proj_3d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(quantized_emb_np)
                ax2.scatter(proj_3d_quantized[:, 0], proj_3d_quantized[:, 1], proj_3d_quantized[:, 2], alpha=0.3)
                ax2.set_title('3D UMAP: Quantized Embeddings')

                # Convert 3D plot to image and log
                fig_3d.canvas.draw()
                img_3d = np.frombuffer(fig_3d.canvas.tostring_rgb(), dtype=np.uint8)
                img_3d = img_3d.reshape(fig_3d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_3d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"3d_embeddings_epoch_{self.current_epoch}", 
                        img_3d, 
                        iteration=self.current_epoch
                    )

            # Clear accumulated embeddings after logging
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

        except Exception as e:
            if self.clearml_logger:
                self.clearml_logger.report_text(
                    f"UMAP visualization failed at epoch {self.current_epoch}: {e}"
                )
            # Clear embeddings even if visualization failed
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

    # Save per-epoch embedding (first validation batch only)
    try:
        if self._first_val_batch_features is not None:
            emb_path = os.path.join(
                self.embedding_dir,
                f"val_embedding_epoch{self.current_epoch}.pt"
            )
            torch.save(self._first_val_batch_features, emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Saved small embedding: {emb_path}")
            # Reset for next epoch
            self._first_val_batch_features = None
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed saving epoch embedding: {e}")
configure_optimizers
configure_optimizers()

Configure optimizer - only trainable parameters.

Source code in embeddings_squeeze\models\lightning_module.py (lines 470-492)
def configure_optimizers(self):
    """Configure optimizer - only trainable parameters."""
    params = []

    # Add adapter parameters if present
    if self.feature_adapter is not None:
        params += list(self.feature_adapter.parameters())

    # Add quantizer parameters if present
    if self.quantizer is not None:
        params += list(self.quantizer.parameters())

    # Add backbone parameters if not frozen
    if self.feature_adapter is None:
        params += [p for p in self.backbone.parameters() if p.requires_grad]

    # Remove duplicates
    params = list({id(p): p for p in params}.values())

    if not params:
        raise ValueError("No trainable parameters found!")

    return torch.optim.AdamW(params, lr=self.learning_rate)
on_train_start
on_train_start()

Ensure frozen backbone stays in eval mode.

Source code in embeddings_squeeze\models\lightning_module.py (lines 494-497)
def on_train_start(self):
    """Ensure frozen backbone stays in eval mode."""
    if self.feature_adapter is not None:
        self.backbone.eval()

losses

Loss functions for segmentation tasks. Includes: Cross Entropy, Dice Loss, Focal Loss, and Combined Loss.

Classes
DiceLoss
DiceLoss(smooth=1.0)

Bases: Module

Dice Loss for multi-class segmentation

Source code in embeddings_squeeze\models\losses.py (lines 13-15)
def __init__(self, smooth: float = 1.0):
    super().__init__()
    self.smooth = smooth
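
The forward body is not included in this excerpt; below is a hedged sketch of a standard multi-class soft Dice loss, consistent with the smooth argument above.

import torch
import torch.nn.functional as F

def soft_dice_loss(logits, targets, smooth=1.0):
    # logits: [B, C, H, W]; targets: [B, H, W] with class indices
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + smooth) / (cardinality + smooth)
    return 1.0 - dice.mean()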
FocalLoss
FocalLoss(alpha=1.0, gamma=2.0, reduction='mean')

Bases: Module

Focal Loss for handling class imbalance (multi-class, computed from per-pixel cross-entropy)

Source code in embeddings_squeeze\models\losses.py (lines 40-44)
def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction
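
As with DiceLoss, the forward body is not shown; this is a hedged sketch of a per-pixel focal loss built on top of cross-entropy, consistent with the alpha, gamma, and reduction arguments above.

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=1.0, gamma=2.0, reduction='mean'):
    # logits: [B, C, H, W]; targets: [B, H, W]
    ce = F.cross_entropy(logits, targets, reduction='none')  # per-pixel CE
    pt = torch.exp(-ce)                                       # probability of the true class
    loss = alpha * (1.0 - pt) ** gamma * ce
    return loss.mean() if reduction == 'mean' else loss.sum()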
CombinedLoss
CombinedLoss(
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.5,
    class_weights=None,
)

Bases: Module

Combined loss: CE + Dice + Focal. Returns (total, ce, dice, focal).

Source code in embeddings_squeeze\models\losses.py (lines 62-83)
def __init__(
    self, 
    ce_weight: float = 1.0, 
    dice_weight: float = 1.0, 
    focal_weight: float = 0.5, 
    class_weights=None
):
    super().__init__()
    # class_weights can be None or a tensor/list
    if class_weights is not None:
        # Leave tensor creation to forward (to place on correct device) but store raw
        self._class_weights = class_weights
    else:
        self._class_weights = None

    self.ce_weight = ce_weight
    self.dice_weight = dice_weight
    self.focal_weight = focal_weight

    # Instantiate component losses
    self.dice_loss = DiceLoss()
    self.focal_loss = FocalLoss()
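
The forward body is omitted from this excerpt; a hedged sketch of the weighted combination implied by the constructor and the (total, ce, dice, focal) return is shown below. It reuses the soft_dice_loss and focal_loss sketches above.

import torch.nn.functional as F

def combined_loss(logits, targets, ce_weight=1.0, dice_weight=1.0, focal_weight=0.5):
    ce = F.cross_entropy(logits, targets)
    dice = soft_dice_loss(logits, targets)   # sketch above
    focal = focal_loss(logits, targets)      # sketch above
    total = ce_weight * ce + dice_weight * dice + focal_weight * focal
    return total, ce, dice, focal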

quantizers

Vector Quantization implementations using the vector_quantize_pytorch library. Supports VQ-VAE, FSQ, LFQ, and Residual VQ.

Classes
BaseQuantizer
BaseQuantizer(input_dim)

Bases: Module

Base class for all quantizers

Source code in embeddings_squeeze\models\quantizers.py (lines 14-16)
def __init__(self, input_dim: int):
    super().__init__()
    self.input_dim = input_dim
Functions
quantize_spatial
quantize_spatial(features)

Quantize spatial features [B, C, H, W]

Parameters:

    features (Tensor): Tensor of shape [B, C, H, W] (required)

Returns:

    quantized: Quantized features [B, C, H, W]
    loss: Quantization loss (scalar)

Source code in embeddings_squeeze\models\quantizers.py (lines 21-46)
def quantize_spatial(self, features: torch.Tensor):
    """
    Quantize spatial features [B, C, H, W]

    Args:
        features: Tensor of shape [B, C, H, W]

    Returns:
        quantized: Quantized features [B, C, H, W]
        loss: Quantization loss (scalar)
    """
    B, C, H, W = features.shape
    # Transform [B, C, H, W] -> [B, H*W, C]
    seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Quantize
    quantized, indices, loss = self.forward(seq)

    # Transform back [B, H*W, C] -> [B, C, H, W]
    quantized = quantized.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Handle loss (may be tensor with multiple elements)
    if isinstance(loss, torch.Tensor) and loss.numel() > 1:
        loss = loss.mean()

    return quantized, loss
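
A hedged usage sketch of quantize_spatial with one of the concrete quantizers below. The import path is inferred from the file path shown above and requires the vector_quantize_pytorch dependency.

import torch
from embeddings_squeeze.models.quantizers import VQWithProjection  # assumed path

features = torch.randn(2, 2048, 16, 16)  # [B, C, H, W] backbone feature map
quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)

quantized, loss = quantizer.quantize_spatial(features)
print(quantized.shape)  # torch.Size([2, 2048, 16, 16])
print(float(loss))      # scalar quantization loss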
VQWithProjection
VQWithProjection(
    input_dim,
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Vector Quantization (VQ-VAE) with projections

Uses EMA for codebook updates (no gradients needed for the codebook). ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py (lines 56-79)
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection (e.g., 2048 -> 64)
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Vector Quantization
    self.vq = VectorQuantize(
        dim=bottleneck_dim,
        codebook_size=codebook_size,
        decay=decay,  # EMA decay for codebook
        commitment_weight=commitment_weight  # Commitment loss weight
    )

    # Up projection (64 -> 2048)
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
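
The class's forward is not included in this excerpt; below is a hedged sketch of what it presumably does to satisfy the (quantized, indices, loss) contract used by quantize_spatial: project down to the bottleneck, vector-quantize, project back up.

def forward(self, x):
    # Sketch only; the actual method may differ in details.
    # x: [B, N, input_dim] -- quantize_spatial passes features as a flattened sequence
    z = self.project_in(x)                     # [B, N, bottleneck_dim]
    quantized, indices, vq_loss = self.vq(z)   # codebook lookup + commitment loss (EMA updates)
    out = self.project_out(quantized)          # back to [B, N, input_dim]
    return out, indices, vq_loss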
FSQWithProjection
FSQWithProjection(input_dim, levels=None)

Bases: BaseQuantizer

Finite Scalar Quantization (FSQ)

Quantization without a codebook: each dimension is quantized independently. ~10 bits per vector at levels=[8,5,5,5].

Source code in embeddings_squeeze\models\quantizers.py (lines 95-109)
def __init__(self, input_dim: int, levels: list = None):
    super().__init__(input_dim)
    if levels is None:
        levels = [8, 5, 5, 5]  # 8*5*5*5 = 1000 codes ≈ 2^10

    self.num_levels = len(levels)

    # Projection to quantization space
    self.project_in = nn.Linear(input_dim, self.num_levels)

    # FSQ quantization
    self.fsq = FSQ(levels=levels, dim=self.num_levels)

    # Projection back
    self.project_out = nn.Linear(self.num_levels, input_dim)
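
As with VQWithProjection, the forward is not shown; the sketch below assumes FSQ's (quantized, indices) output is padded with a zero loss so the BaseQuantizer interface stays uniform.

import torch

def forward(self, x):
    # Sketch only; the actual method may differ in details.
    z = self.project_in(x)              # [B, N, num_levels]
    quantized, indices = self.fsq(z)    # round each dimension to its allowed levels
    out = self.project_out(quantized)   # back to [B, N, input_dim]
    return out, indices, torch.tensor(0.0, device=x.device)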
LFQWithProjection
LFQWithProjection(
    input_dim,
    codebook_size=512,
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
)

Bases: BaseQuantizer

Lookup-Free Quantization (LFQ)

Uses an entropy loss to encourage code diversity. ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py (lines 128-156)
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512,
    entropy_loss_weight: float = 0.1, 
    diversity_gamma: float = 0.1, 
    spherical: bool = False
):
    super().__init__(input_dim)
    # Quantization dimension = log2(codebook_size)
    self.quant_dim = int(math.log2(codebook_size))

    # Projection with normalization
    self.project_in = nn.Sequential(
        nn.Linear(input_dim, self.quant_dim),
        nn.LayerNorm(self.quant_dim)
    )

    # LFQ quantization
    self.lfq = LFQ(
        dim=self.quant_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        spherical=spherical
    )

    # Projection back
    self.project_out = nn.Linear(self.quant_dim, input_dim)
ResidualVQWithProjection
ResidualVQWithProjection(
    input_dim,
    num_quantizers=4,
    codebook_size=256,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Residual Vector Quantization (RVQ)

Multi-level quantization: each level quantizes the residual of the previous one. 32 bits per vector at num_quantizers=4, codebook_size=256 (4*8 bits).

Source code in embeddings_squeeze\models\quantizers.py (lines 172-197)
def __init__(
    self, 
    input_dim: int, 
    num_quantizers: int = 4,
    codebook_size: int = 256, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Residual VQ
    self.residual_vq = ResidualVQ(
        dim=bottleneck_dim,
        num_quantizers=num_quantizers,  # Number of levels
        codebook_size=codebook_size,
        decay=decay,
        commitment_weight=commitment_weight
    )

    # Up projection
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
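
A quick arithmetic check of the bit budgets quoted in the class descriptions above, using the default hyperparameters (the exact packing on disk may differ).

import math

vq_bits = math.log2(512)                        # VQ:  9.0 bits per vector
lfq_bits = math.log2(512)                       # LFQ: 9.0 bits per vector
fsq_bits = math.log2(math.prod([8, 5, 5, 5]))   # FSQ: log2(1000) ≈ 9.97 bits per vector
rvq_bits = 4 * math.log2(256)                   # RVQ: 4 levels * 8 bits = 32 bits per vector
print(vq_bits, lfq_bits, round(fsq_bits, 2), rvq_bits)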