Full package documentation

embeddings_squeeze: Vector Quantization for Segmentation Model Compression

Classes

VQWithProjection

VQWithProjection(
    input_dim,
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Vector Quantization (VQ-VAE) with projections

Uses EMA for codebook updates (no gradients needed for the codebook). Encodes ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection (e.g., 2048 -> 64)
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Vector Quantization
    self.vq = VectorQuantize(
        dim=bottleneck_dim,
        codebook_size=codebook_size,
        decay=decay,  # EMA decay for codebook
        commitment_weight=commitment_weight  # Commitment loss weight
    )

    # Up projection (64 -> 2048)
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
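
A minimal usage sketch (a hypothetical example, assuming the module path shown above and a 2048-dimensional backbone feature map); it quantizes a spatial feature tensor via quantize_spatial (documented under BaseQuantizer below) and returns the reconstructed features plus the quantization loss:

import torch
from embeddings_squeeze.models.quantizers import VQWithProjection

quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)

features = torch.randn(2, 2048, 7, 7)                  # [B, C, H, W] backbone features
quantized, vq_loss = quantizer.quantize_spatial(features)

print(quantized.shape)   # torch.Size([2, 2048, 7, 7]), same shape as the input
print(float(vq_loss))    # scalar quantization loss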

FSQWithProjection

FSQWithProjection(input_dim, levels=None)

Bases: BaseQuantizer

Finite Scalar Quantization (FSQ)

Quantization without a codebook: each dimension is quantized independently. Encodes ~10 bits per vector at levels=[8, 5, 5, 5].

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int, levels: list = None):
    super().__init__(input_dim)
    if levels is None:
        levels = [8, 5, 5, 5]  # 8*5*5*5 = 1000 codes ≈ 2^10

    self.num_levels = len(levels)

    # Projection to quantization space
    self.project_in = nn.Linear(input_dim, self.num_levels)

    # FSQ quantization
    self.fsq = FSQ(levels=levels, dim=self.num_levels)

    # Projection back
    self.project_out = nn.Linear(self.num_levels, input_dim)
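
A short sketch of the code-count arithmetic behind the "~10 bits" figure, plus a hypothetical instantiation:

import math
from embeddings_squeeze.models.quantizers import FSQWithProjection

levels = [8, 5, 5, 5]
num_codes = math.prod(levels)        # 8 * 5 * 5 * 5 = 1000 distinct codes
bits = math.log2(num_codes)          # ~9.97 bits per vector

fsq = FSQWithProjection(input_dim=2048, levels=levels)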

LFQWithProjection

LFQWithProjection(
    input_dim,
    codebook_size=512,
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
)

Bases: BaseQuantizer

Lookup-Free Quantization (LFQ)

Uses an entropy loss to encourage code diversity. Encodes ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512,
    entropy_loss_weight: float = 0.1, 
    diversity_gamma: float = 0.1, 
    spherical: bool = False
):
    super().__init__(input_dim)
    # Quantization dimension = log2(codebook_size)
    self.quant_dim = int(math.log2(codebook_size))

    # Projection with normalization
    self.project_in = nn.Sequential(
        nn.Linear(input_dim, self.quant_dim),
        nn.LayerNorm(self.quant_dim)
    )

    # LFQ quantization
    self.lfq = LFQ(
        dim=self.quant_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        spherical=spherical
    )

    # Projection back
    self.project_out = nn.Linear(self.quant_dim, input_dim)
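
The projection width is derived directly from the codebook size, so codebook_size is expected to be a power of two here; a quick check of that relationship:

import math

codebook_size = 512
quant_dim = int(math.log2(codebook_size))   # 9 binary dimensions -> 2**9 = 512 codes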

ResidualVQWithProjection

ResidualVQWithProjection(
    input_dim,
    num_quantizers=4,
    codebook_size=256,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Residual Vector Quantization (RVQ)

Multi-level quantization: each level quantizes the residual of the previous level. Encodes 32 bits per vector at num_quantizers=4, codebook_size=256 (4 × 8 bits).

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    num_quantizers: int = 4,
    codebook_size: int = 256, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Residual VQ
    self.residual_vq = ResidualVQ(
        dim=bottleneck_dim,
        num_quantizers=num_quantizers,  # Number of levels
        codebook_size=codebook_size,
        decay=decay,
        commitment_weight=commitment_weight
    )

    # Up projection
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
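
The bit budget scales linearly with the number of quantizer levels; a sketch of that accounting using the defaults above:

import math

num_quantizers = 4
codebook_size = 256
bits_per_vector = num_quantizers * math.log2(codebook_size)   # 4 * 8 = 32 bits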

BaseQuantizer

BaseQuantizer(input_dim)

Bases: Module

Base class for all quantizers

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int):
    super().__init__()
    self.input_dim = input_dim

Functions

quantize_spatial
quantize_spatial(features)

Quantize spatial features [B, C, H, W]

Parameters:

    features (Tensor): Tensor of shape [B, C, H, W]. Required.

Returns:

    quantized: Quantized features [B, C, H, W]
    loss: Quantization loss (scalar)

Source code in embeddings_squeeze\models\quantizers.py
def quantize_spatial(self, features: torch.Tensor):
    """
    Quantize spatial features [B, C, H, W]

    Args:
        features: Tensor of shape [B, C, H, W]

    Returns:
        quantized: Quantized features [B, C, H, W]
        loss: Quantization loss (scalar)
    """
    B, C, H, W = features.shape
    # Transform [B, C, H, W] -> [B, H*W, C]
    seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Quantize
    quantized, indices, loss = self.forward(seq)

    # Transform back [B, H*W, C] -> [B, C, H, W]
    quantized = quantized.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Handle loss (may be tensor with multiple elements)
    if isinstance(loss, torch.Tensor) and loss.numel() > 1:
        loss = loss.mean()

    return quantized, loss

SegmentationBackbone

SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py
def __init__(self):
    super().__init__()

Attributes

feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions

extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
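
A minimal, hypothetical subclass sketch (toy layers, not one of the package's real backbones), only to illustrate the contract that the abstract properties and methods define:

import torch.nn as nn
import torch.nn.functional as F
from embeddings_squeeze.models.backbones.base import SegmentationBackbone

class ToyBackbone(SegmentationBackbone):
    """Tiny illustrative backbone: one conv encoder plus a 1x1 classifier head."""

    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.encoder = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)
        self.classifier = nn.Conv2d(64, num_classes, kernel_size=1)
        self._num_classes = num_classes

    @property
    def feature_dim(self):
        return 64

    @property
    def num_classes(self):
        return self._num_classes

    def extract_features(self, images, detach=True):
        features = self.encoder(images)              # [B, 64, H/4, W/4]
        return features.detach() if detach else features

    def forward(self, images):
        features = self.extract_features(images, detach=False)
        logits = self.classifier(features)
        # Upsample back to the input resolution
        return F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)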

VQSqueezeModule

VQSqueezeModule(
    backbone,
    quantizer=None,
    num_classes=21,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    loss_type="ce",
    class_weights=None,
    add_adapter=False,
    feature_dim=2048,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for VQ compression training.

Features:

- Multiple quantizer support (VQ, FSQ, LFQ, RVQ)
- Adapter layers for fine-tuning frozen backbones
- Advanced loss functions (CE, Dice, Focal, Combined)
- Embedding extraction and saving

Source code in embeddings_squeeze\models\lightning_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    quantizer: Optional[nn.Module] = None,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    vq_loss_weight: float = 0.1,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    add_adapter: bool = False,
    feature_dim: int = 2048,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'quantizer', 'clearml_logger'])

    self.num_classes = num_classes
    self.learning_rate = learning_rate
    self.vq_loss_weight = vq_loss_weight
    self.loss_type = loss_type
    self.add_adapter = add_adapter
    self.feature_dim = feature_dim

    # Setup backbone with optional adapters
    self.backbone = backbone
    self._setup_backbone_with_adapters(feature_dim, add_adapter)

    # Quantizer (optional)
    self.quantizer = quantizer

    # Loss function
    self.criterion = self._init_loss(loss_type, class_weights)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

    # Embedding storage (per-epoch, first batch only)
    self.embedding_dir = "embeddings"
    os.makedirs(self.embedding_dir, exist_ok=True)
    self._first_val_batch_features = None

    # UMAP visualization storage
    self._val_backbone_embeddings = []
    self._val_quantized_embeddings = []

Functions

forward
forward(images)

Forward pass through backbone + optional quantizer + decoder.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]
    quant_loss: Quantization loss (0 if no quantizer)
    original_features: Extracted features (before quantization)
    quantized_features: Features after quantization (same as original if no quantizer)

Source code in embeddings_squeeze\models\lightning_module.py
def forward(self, images):
    """
    Forward pass through backbone + optional quantizer + decoder.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
        quant_loss: Quantization loss (0 if no quantizer)
        original_features: Extracted features (before quantization)
        quantized_features: Features after quantization (same as original if no quantizer)
    """
    # Extract features
    features = self.backbone.extract_features(images, detach=self.feature_adapter is not None)

    # Apply adapter if present
    if self.feature_adapter is not None:
        features = features + self.feature_adapter(features)

    # Store original features for embedding extraction
    original_features = features

    # Quantize if quantizer is present
    quant_loss = torch.tensor(0.0, device=images.device)
    quantized_features = original_features  # Default to original if no quantizer
    if self.quantizer is not None:
        features, quant_loss = self.quantizer.quantize_spatial(features)
        quantized_features = features

    # Decode to segmentation logits
    output = self.backbone.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear', align_corners=False)

    return output, quant_loss, original_features, quantized_features
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\lightning_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, _, _ = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\lightning_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, backbone_features, quantized_features = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    # Accumulate embeddings for UMAP visualization
    self._val_backbone_embeddings.append(backbone_features.detach().cpu())
    self._val_quantized_embeddings.append(quantized_features.detach().cpu())

    # Save only first batch features for this epoch
    if batch_idx == 0:
        self._first_val_batch_features = backbone_features.detach().cpu()

    return loss
on_validation_epoch_start
on_validation_epoch_start()

Clear accumulated embeddings at the start of each validation epoch.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_start(self):
    """Clear accumulated embeddings at the start of each validation epoch."""
    self._val_backbone_embeddings.clear()
    self._val_quantized_embeddings.clear()
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations and save embeddings.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations and save embeddings."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    keys = [
        "train/loss", "val/loss", "train/iou", "val/iou",
        "train/precision", "val/precision", "train/recall", "val/recall",
        "train/f1", "val/f1"
    ]
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )

    # Generate UMAP visualizations on even epochs
    if self.current_epoch % 2 == 0:
        try:
            import umap.umap_ as umap_module

            # Only proceed if we have embeddings
            if len(self._val_backbone_embeddings) > 0 and len(self._val_quantized_embeddings) > 0:
                # Concatenate all accumulated embeddings
                backbone_emb_flat = torch.cat(self._val_backbone_embeddings, dim=0)
                quantized_emb_flat = torch.cat(self._val_quantized_embeddings, dim=0)

                # Flatten spatial dimensions: [B, C, H, W] -> [B*H*W, C]
                backbone_emb_flat = backbone_emb_flat.permute(0, 2, 3, 1).reshape(-1, backbone_emb_flat.shape[1])
                quantized_emb_flat = quantized_emb_flat.permute(0, 2, 3, 1).reshape(-1, quantized_emb_flat.shape[1])

                # Convert to numpy
                backbone_emb_np = backbone_emb_flat.numpy()
                quantized_emb_np = quantized_emb_flat.numpy()

                # Limit samples for performance (take subset if too large)
                max_samples = 10000
                if len(backbone_emb_np) > max_samples:
                    indices = np.random.choice(len(backbone_emb_np), max_samples, replace=False)
                    backbone_emb_np = backbone_emb_np[indices]
                    quantized_emb_np = quantized_emb_np[indices]

                # Generate 2D UMAP
                fig_2d, axs_2d = plt.subplots(1, 2, figsize=(12, 6))

                proj_2d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(backbone_emb_np)
                axs_2d[0].scatter(proj_2d_backbone[:, 0], proj_2d_backbone[:, 1], alpha=0.3)
                axs_2d[0].set_title('2D UMAP: Backbone Embeddings')

                proj_2d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(quantized_emb_np)
                axs_2d[1].scatter(proj_2d_quantized[:, 0], proj_2d_quantized[:, 1], alpha=0.3)
                axs_2d[1].set_title('2D UMAP: Quantized Embeddings')

                # Convert 2D plot to image and log
                fig_2d.canvas.draw()
                img_2d = np.frombuffer(fig_2d.canvas.tostring_rgb(), dtype=np.uint8)
                img_2d = img_2d.reshape(fig_2d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_2d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"2d_embeddings_epoch_{self.current_epoch}", 
                        img_2d, 
                        iteration=self.current_epoch
                    )

                # Generate 3D UMAP
                fig_3d = plt.figure(figsize=(12, 6))
                ax1 = fig_3d.add_subplot(121, projection='3d')
                ax2 = fig_3d.add_subplot(122, projection='3d')

                proj_3d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(backbone_emb_np)
                ax1.scatter(proj_3d_backbone[:, 0], proj_3d_backbone[:, 1], proj_3d_backbone[:, 2], alpha=0.3)
                ax1.set_title('3D UMAP: Backbone Embeddings')

                proj_3d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(quantized_emb_np)
                ax2.scatter(proj_3d_quantized[:, 0], proj_3d_quantized[:, 1], proj_3d_quantized[:, 2], alpha=0.3)
                ax2.set_title('3D UMAP: Quantized Embeddings')

                # Convert 3D plot to image and log
                fig_3d.canvas.draw()
                img_3d = np.frombuffer(fig_3d.canvas.tostring_rgb(), dtype=np.uint8)
                img_3d = img_3d.reshape(fig_3d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_3d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"3d_embeddings_epoch_{self.current_epoch}", 
                        img_3d, 
                        iteration=self.current_epoch
                    )

            # Clear accumulated embeddings after logging
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

        except Exception as e:
            if self.clearml_logger:
                self.clearml_logger.report_text(
                    f"UMAP visualization failed at epoch {self.current_epoch}: {e}"
                )
            # Clear embeddings even if visualization failed
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

    # Save per-epoch embedding (first validation batch only)
    try:
        if self._first_val_batch_features is not None:
            emb_path = os.path.join(
                self.embedding_dir,
                f"val_embedding_epoch{self.current_epoch}.pt"
            )
            torch.save(self._first_val_batch_features, emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Saved small embedding: {emb_path}")
            # Reset for next epoch
            self._first_val_batch_features = None
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed saving epoch embedding: {e}")
configure_optimizers
configure_optimizers()

Configure optimizer - only trainable parameters.

Source code in embeddings_squeeze\models\lightning_module.py
def configure_optimizers(self):
    """Configure optimizer - only trainable parameters."""
    params = []

    # Add adapter parameters if present
    if self.feature_adapter is not None:
        params += list(self.feature_adapter.parameters())

    # Add quantizer parameters if present
    if self.quantizer is not None:
        params += list(self.quantizer.parameters())

    # Add backbone parameters if not frozen
    if self.feature_adapter is None:
        params += [p for p in self.backbone.parameters() if p.requires_grad]

    # Remove duplicates
    params = list({id(p): p for p in params}.values())

    if not params:
        raise ValueError("No trainable parameters found!")

    return torch.optim.AdamW(params, lr=self.learning_rate)
on_train_start
on_train_start()

Ensure frozen backbone stays in eval mode.

Source code in embeddings_squeeze\models\lightning_module.py
def on_train_start(self):
    """Ensure frozen backbone stays in eval mode."""
    if self.feature_adapter is not None:
        self.backbone.eval()

BaselineSegmentationModule

BaselineSegmentationModule(
    backbone,
    num_classes=21,
    learning_rate=0.0001,
    loss_type="ce",
    class_weights=None,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for baseline segmentation training.

Wraps a segmentation backbone without Vector Quantization, for comparison with the quantized models.

Source code in embeddings_squeeze\models\baseline_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'clearml_logger'])

    self.backbone = backbone
    self.num_classes = num_classes
    self.learning_rate = learning_rate

    # Segmentation loss
    if class_weights is not None:
        weight = torch.tensor(class_weights, dtype=torch.float32)
        self.seg_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    else:
        self.seg_criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

Functions

forward
forward(images)

Forward pass through backbone.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def forward(self, images):
    """
    Forward pass through backbone.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    output = self.backbone(images)
    # Handle both dict and tensor returns
    if isinstance(output, dict):
        return output['out']
    return output
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\baseline_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', seg_loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\baseline_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations.

Source code in embeddings_squeeze\models\baseline_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )
configure_optimizers
configure_optimizers()

Configure optimizer.

Source code in embeddings_squeeze\models\baseline_module.py
def configure_optimizers(self):
    """Configure optimizer."""
    # Optimize only trainable params
    params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=self.learning_rate)
predict
predict(images)

Predict segmentation masks.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    predictions: Segmentation predictions [B, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict(self, images):
    """
    Predict segmentation masks.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        predictions: Segmentation predictions [B, H, W]
    """
    self.eval()
    with torch.no_grad():
        output = self(images)
        predictions = output.argmax(dim=1)
    return predictions
predict_logits
predict_logits(images)

Predict segmentation logits.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    logits: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict_logits(self, images):
    """
    Predict segmentation logits.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        logits: Segmentation logits [B, num_classes, H, W]
    """
    self.eval()
    with torch.no_grad():
        logits = self(images)
    return logits
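
A small usage sketch (random input; `model` is assumed to be an already-constructed BaselineSegmentationModule):

import torch

images = torch.randn(2, 3, 224, 224)        # [B, C, H, W]
preds = model.predict(images)               # [B, H, W] class indices
logits = model.predict_logits(images)       # [B, num_classes, H, W]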

BaseDataModule

BaseDataModule(
    data_path,
    batch_size=4,
    num_workers=0,
    pin_memory=True,
    **kwargs
)

Bases: LightningDataModule, ABC

Abstract base class for data modules.

All dataset-specific data modules should inherit from this class.

Source code in embeddings_squeeze\data\base.py
def __init__(
    self,
    data_path: str,
    batch_size: int = 4,
    num_workers: int = 0,
    pin_memory: bool = True,
    **kwargs
):
    super().__init__()
    self.data_path = data_path
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.pin_memory = pin_memory

Functions

setup abstractmethod
setup(stage=None)

Setup datasets for training/validation/testing.

Parameters:

    stage (str): 'fit', 'validate', 'test', or None. Default: None.
Source code in embeddings_squeeze\data\base.py
@abstractmethod
def setup(self, stage: str = None):
    """
    Setup datasets for training/validation/testing.

    Args:
        stage: 'fit', 'validate', 'test', or None
    """
    pass
train_dataloader abstractmethod
train_dataloader(max_batches=None)

Return training dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def train_dataloader(self, max_batches: int = None):
    """Return training dataloader."""
    pass
val_dataloader abstractmethod
val_dataloader(max_batches=None)

Return validation dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def val_dataloader(self, max_batches: int = None):
    """Return validation dataloader."""
    pass
test_dataloader abstractmethod
test_dataloader(max_batches=None)

Return test dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def test_dataloader(self, max_batches: int = None):
    """Return test dataloader."""
    pass

ClearMLLogger

ClearMLLogger(task)

Wrapper for ClearML logging compatible with PyTorch Lightning. Supports scalar metrics, plots, images, and text logging.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def __init__(self, task: Task):
    self.task = task
    self.logger = task.get_logger() if task else None

Functions

log_metrics
log_metrics(metrics, step=None)

Log metrics to ClearML.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_metrics(self, metrics: dict, step: int = None):
    """Log metrics to ClearML."""
    if self.logger is None:
        return

    for key, value in metrics.items():
        # Split key into title and series (e.g., "train/loss" -> title="train", series="loss")
        if '/' in key:
            title, series = key.split('/', 1)
        else:
            title = 'metrics'
            series = key

        self.logger.report_scalar(
            title=title,
            series=series,
            value=value,
            iteration=step
        )
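
A short usage sketch of the key-splitting behaviour (assuming `logger` is a ClearMLLogger instance):

# "train/loss" is reported with title="train", series="loss"; keys without "/" go under title="metrics".
logger.log_metrics({"train/loss": 0.42, "val/iou": 0.71}, step=10)
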
log_scalar
log_scalar(title, series, value, iteration)

Log a single scalar value to ClearML.

Parameters:

    title (str): Graph title (e.g., "loss", "accuracy"). Required.
    series (str): Series name within the graph (e.g., "train", "val"). Required.
    value (float): Scalar value to log. Required.
    iteration (int): Iteration/step number. Required.

Example:

    logger.log_scalar("loss", "train", 0.5, iteration=100)
    logger.log_scalar("loss", "val", 0.3, iteration=100)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_scalar(self, title: str, series: str, value: float, iteration: int):
    """
    Log a single scalar value to ClearML.

    Args:
        title: Graph title (e.g., "loss", "accuracy")
        series: Series name within the graph (e.g., "train", "val")
        value: Scalar value to log
        iteration: Iteration/step number

    Example:
        logger.log_scalar("loss", "train", 0.5, iteration=100)
        logger.log_scalar("loss", "val", 0.3, iteration=100)
    """
    if self.logger is None:
        return

    self.logger.report_scalar(
        title=title,
        series=series,
        value=value,
        iteration=iteration
    )
log_image
log_image(title, series, image, iteration)

Log an image to ClearML.

Parameters:

    title (str): Image title/group. Required.
    series (str): Series name (e.g., "predictions", "ground_truth"). Required.
    image: Image as numpy array (H, W) or (H, W, C) for grayscale/RGB. Supports uint8 (0-255) or float (0-1). Required.
    iteration (int): Iteration/step number. Required.

Example:

    # Grayscale image
    img = np.eye(256, 256, dtype=np.uint8) * 255
    logger.log_image("predictions", "epoch_1", img, iteration=0)

    # RGB image
    img_rgb = np.zeros((256, 256, 3), dtype=np.uint8)
    img_rgb[:, :, 0] = 255  # Red channel
    logger.log_image("predictions", "epoch_1_rgb", img_rgb, iteration=0)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_image(self, title: str, series: str, image, iteration: int):
    """
    Log an image to ClearML.

    Args:
        title: Image title/group
        series: Series name (e.g., "predictions", "ground_truth")
        image: Image as numpy array (H, W) or (H, W, C) for grayscale/RGB
               Supports uint8 (0-255) or float (0-1)
        iteration: Iteration/step number

    Example:
        # Grayscale image
        img = np.eye(256, 256, dtype=np.uint8) * 255
        logger.log_image("predictions", "epoch_1", img, iteration=0)

        # RGB image
        img_rgb = np.zeros((256, 256, 3), dtype=np.uint8)
        img_rgb[:, :, 0] = 255  # Red channel
        logger.log_image("predictions", "epoch_1_rgb", img_rgb, iteration=0)
    """
    if self.logger is None:
        return

    self.logger.report_image(
        title=title,
        series=series,
        iteration=iteration,
        image=image
    )
log_images_batch
log_images_batch(title, series, images, iteration)

Log multiple images to ClearML.

Parameters:

    title (str): Image title/group. Required.
    series (str): Series name. Required.
    images (list): List of images (numpy arrays). Required.
    iteration (int): Iteration/step number. Required.

Example:

    images = [img1, img2, img3]
    logger.log_images_batch("samples", "batch_0", images, iteration=0)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_images_batch(self, title: str, series: str, images: list, iteration: int):
    """
    Log multiple images to ClearML.

    Args:
        title: Image title/group
        series: Series name
        images: List of images (numpy arrays)
        iteration: Iteration/step number

    Example:
        images = [img1, img2, img3]
        logger.log_images_batch("samples", "batch_0", images, iteration=0)
    """
    if self.logger is None:
        return

    for idx, image in enumerate(images):
        self.logger.report_image(
            title=title,
            series=f"{series}_img_{idx}",
            iteration=iteration,
            image=image
        )
log_text
log_text(text, title='Info')

Log text to ClearML.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_text(self, text: str, title: str = "Info"):
    """Log text to ClearML."""
    if self.logger is None:
        return
    self.logger.report_text(text, print_console=True)
report_text
report_text(text)

Report text to ClearML (alias for log_text with default title).

Source code in embeddings_squeeze\loggers\clearml_logger.py
def report_text(self, text: str):
    """Report text to ClearML (alias for log_text with default title)."""
    self.log_text(text)
report_plotly
report_plotly(title, series, iteration, figure)

Report a Plotly figure to ClearML.

Parameters:

    title (str): Plot title/group. Required.
    series (str): Series name. Required.
    iteration (int): Iteration/step number. Required.
    figure: Plotly figure object. Required.

Example:

    import plotly.graph_objects as go
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6], mode='lines', name='data'))
    fig.update_layout(title="My Plot", xaxis_title="x", yaxis_title="y")
    logger.report_plotly("metrics", "loss", iteration=0, figure=fig)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def report_plotly(self, title: str, series: str, iteration: int, figure):
    """
    Report a Plotly figure to ClearML.

    Args:
        title: Plot title/group
        series: Series name
        iteration: Iteration/step number
        figure: Plotly figure object

    Example:
        import plotly.graph_objects as go
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=[1,2,3], y=[4,5,6], mode='lines', name='data'))
        fig.update_layout(title="My Plot", xaxis_title="x", yaxis_title="y")
        logger.report_plotly("metrics", "loss", iteration=0, figure=fig)
    """
    if self.logger is None:
        return

    self.logger.report_plotly(
        title=title,
        series=series,
        iteration=iteration,
        figure=figure
    )
finalize
finalize()

Finalize logging and close task.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def finalize(self):
    """Finalize logging and close task."""
    if self.task:
        self.task.close()

Functions

setup_clearml

setup_clearml(project_name, task_name, auto_connect=True)

Setup ClearML with credentials from config file.

Parameters:

    project_name (str): ClearML project name. Required.
    task_name (str): ClearML task name. Required.
    auto_connect (bool): If True, automatically connect frameworks. Default: True.

Returns:

    Task object

Source code in embeddings_squeeze\loggers\clearml_logger.py
def setup_clearml(project_name: str, task_name: str, auto_connect: bool = True):
    """
    Setup ClearML with credentials from config file.

    Args:
        project_name: ClearML project name
        task_name: ClearML task name
        auto_connect: If True, automatically connect frameworks

    Returns:
        Task object
    """
    # Load credentials
    config_dir = Path(__file__).parent.parent / 'configs'

    try:
        creds = load_credentials(config_dir)

        # Set credentials
        clearml.Task.set_credentials(
            api_host=creds.get('api_host', 'https://api.clear.ml'),
            web_host=creds.get('web_host', 'https://app.clear.ml'),
            files_host=creds.get('files_host', 'https://files.clear.ml'),
            key=creds['api_key'],
            secret=creds['api_secret']
        )

        # Initialize task
        task = Task.init(
            project_name=project_name,
            task_name=task_name,
            auto_connect_frameworks=auto_connect
        )
        return task
    except FileNotFoundError as e:
        print(f"Warning: {e}")
        print("ClearML logging disabled. Using TensorBoard instead.")
        return None
    except Exception as e:
        print(f"Warning: Failed to setup ClearML: {e}")
        print("ClearML logging disabled. Using TensorBoard instead.")
        return None
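
A minimal wiring sketch combining setup_clearml with the ClearMLLogger wrapper documented above; if credentials are missing, setup_clearml returns None and the wrapper silently no-ops:

from embeddings_squeeze.loggers.clearml_logger import ClearMLLogger, setup_clearml

task = setup_clearml(project_name="embeddings_squeeze", task_name="vq_compression")
logger = ClearMLLogger(task)                 # task may be None; logging calls become no-ops

logger.log_scalar("loss", "train", 0.5, iteration=0)
logger.finalize()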

Modules

cli

CLI module for embeddings_squeeze package.

Functions

squeeze
squeeze()

Entry point for squeeze command.

Source code in embeddings_squeeze\cli.py
def squeeze():
    """Entry point for squeeze command."""
    squeeze_main()

configs

Configuration management for the package.

Classes

ModelConfig dataclass
ModelConfig(
    backbone="vit",
    num_classes=21,
    freeze_backbone=True,
    vit_weights="IMAGENET1K_V1",
    deeplab_weights="COCO_WITH_VOC_LABELS_V1",
    add_adapter=False,
    feature_dim=768,
    loss_type="ce",
    class_weights=None,
)

Model architecture configuration.

TrainingConfig dataclass
TrainingConfig(
    epochs=10,
    batch_size=4,
    max_batches=None,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    num_workers=4,
    pin_memory=True,
    optimizer="adam",
    weight_decay=0.0,
    log_every_n_steps=50,
    val_check_interval=1.0,
    save_top_k=3,
    monitor="val/loss",
    mode="min",
)

Training configuration.

DataConfig dataclass
DataConfig(
    dataset="oxford_pet",
    data_path="./data",
    image_size=224,
    subset_size=None,
    normalize_mean=(0.485, 0.456, 0.406),
    normalize_std=(0.229, 0.224, 0.225),
)

Data configuration.

ExperimentConfig dataclass
ExperimentConfig(
    model=ModelConfig(),
    training=TrainingConfig(),
    data=DataConfig(),
    quantizer=QuantizerConfig(),
    logger=LoggerConfig(),
    experiment_name="vq_squeeze",
    output_dir="./outputs",
    seed=42,
    initialize_codebook=True,
    max_init_samples=50000,
)

Complete experiment configuration.

Functions

get_default_config
get_default_config()

Get default configuration.

Source code in embeddings_squeeze\configs\default.py
def get_default_config() -> ExperimentConfig:
    """Get default configuration."""
    return ExperimentConfig()
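
A minimal sketch of reading and overriding the defaults (field names as listed in the dataclasses above):

from embeddings_squeeze.configs.default import get_default_config

config = get_default_config()
config.training.epochs = 20          # override a default in place
config.quantizer.type = "fsq"
print(config.experiment_name)        # "vq_squeeze"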

Modules

default

Default configuration classes and settings.

Classes
QuantizerConfig dataclass
QuantizerConfig(
    enabled=True,
    type="vq",
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
    levels=(lambda: [8, 5, 5, 5])(),
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
    num_quantizers=4,
)

Quantizer configuration.

LoggerConfig dataclass
LoggerConfig(
    use_clearml=True,
    use_tensorboard=False,
    project_name="embeddings_squeeze",
    task_name="vq_compression",
    credentials_file="clearml_credentials.yaml",
)

Logger configuration.

ModelConfig dataclass
ModelConfig(
    backbone="vit",
    num_classes=21,
    freeze_backbone=True,
    vit_weights="IMAGENET1K_V1",
    deeplab_weights="COCO_WITH_VOC_LABELS_V1",
    add_adapter=False,
    feature_dim=768,
    loss_type="ce",
    class_weights=None,
)

Model architecture configuration.

TrainingConfig dataclass
TrainingConfig(
    epochs=10,
    batch_size=4,
    max_batches=None,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    num_workers=4,
    pin_memory=True,
    optimizer="adam",
    weight_decay=0.0,
    log_every_n_steps=50,
    val_check_interval=1.0,
    save_top_k=3,
    monitor="val/loss",
    mode="min",
)

Training configuration.

DataConfig dataclass
DataConfig(
    dataset="oxford_pet",
    data_path="./data",
    image_size=224,
    subset_size=None,
    normalize_mean=(0.485, 0.456, 0.406),
    normalize_std=(0.229, 0.224, 0.225),
)

Data configuration.

ExperimentConfig dataclass
ExperimentConfig(
    model=ModelConfig(),
    training=TrainingConfig(),
    data=DataConfig(),
    quantizer=QuantizerConfig(),
    logger=LoggerConfig(),
    experiment_name="vq_squeeze",
    output_dir="./outputs",
    seed=42,
    initialize_codebook=True,
    max_init_samples=50000,
)

Complete experiment configuration.

Functions
get_default_config
get_default_config()

Get default configuration.

Source code in embeddings_squeeze\configs\default.py
def get_default_config() -> ExperimentConfig:
    """Get default configuration."""
    return ExperimentConfig()
update_config_from_args
update_config_from_args(config, args)

Update configuration from command line arguments.

Parameters:

    config (ExperimentConfig): Base configuration. Required.
    args (Dict[str, Any]): Command line arguments. Required.

Returns:

    ExperimentConfig: Updated configuration

Source code in embeddings_squeeze\configs\default.py
def update_config_from_args(config: ExperimentConfig, args: Dict[str, Any]) -> ExperimentConfig:
    """
    Update configuration from command line arguments.

    Args:
        config: Base configuration
        args: Command line arguments

    Returns:
        Updated configuration
    """
    # Model config
    if "model" in args:
        config.model.backbone = args["model"]
    if "num_classes" in args:
        config.model.num_classes = args["num_classes"]
    if "add_adapter" in args:
        config.model.add_adapter = args["add_adapter"]
    if "feature_dim" in args and args["feature_dim"] is not None:
        config.model.feature_dim = args["feature_dim"]
    if "loss_type" in args:
        config.model.loss_type = args["loss_type"]
    if "class_weights" in args:
        config.model.class_weights = args["class_weights"]

    # Quantizer config
    if "quantizer_type" in args:
        config.quantizer.type = args["quantizer_type"]
    if "quantizer_enabled" in args:
        config.quantizer.enabled = args["quantizer_enabled"]
    if "codebook_size" in args:
        config.quantizer.codebook_size = args["codebook_size"]
    if "bottleneck_dim" in args:
        config.quantizer.bottleneck_dim = args["bottleneck_dim"]
    if "num_quantizers" in args:
        config.quantizer.num_quantizers = args["num_quantizers"]

    # Logger config
    if "use_clearml" in args:
        config.logger.use_clearml = args["use_clearml"]
    if "project_name" in args:
        config.logger.project_name = args["project_name"]
    if "task_name" in args:
        config.logger.task_name = args["task_name"]

    # Training config
    if "epochs" in args:
        config.training.epochs = args["epochs"]
    if "batch_size" in args:
        config.training.batch_size = args["batch_size"]
    if "max_batches" in args:
        config.training.max_batches = args["max_batches"]
    if "lr" in args:
        config.training.learning_rate = args["lr"]
    if "vq_loss_weight" in args:
        config.training.vq_loss_weight = args["vq_loss_weight"]

    # Data config
    if "dataset" in args:
        config.data.dataset = args["dataset"]
    if "data_path" in args:
        config.data.data_path = args["data_path"]
    if "subset_size" in args:
        config.data.subset_size = args["subset_size"]

    # Experiment config
    if "output_dir" in args:
        config.output_dir = args["output_dir"]
    if "experiment_name" in args:
        config.experiment_name = args["experiment_name"]
    if "seed" in args:
        config.seed = args["seed"]
    if "initialize_codebook" in args:
        # argparse provides this key with a boolean even if the flag is not passed
        config.initialize_codebook = args["initialize_codebook"]
    if "max_init_samples" in args and args["max_init_samples"] is not None:
        config.max_init_samples = args["max_init_samples"]

    return config
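
A minimal usage sketch combining the two helpers above, assuming both are importable from embeddings_squeeze.configs.default (the source path shown above) and that the args dict comes from an argparse-style CLI; the specific argument values are illustrative only:

    from embeddings_squeeze.configs.default import get_default_config, update_config_from_args

    # hypothetical CLI arguments collected into a dict
    args = {"model": "resnet50", "epochs": 10, "lr": 3e-4, "quantizer_type": "fsq"}

    config = get_default_config()
    config = update_config_from_args(config, args)
    print(config.training.epochs)   # 10
    print(config.quantizer.type)    # "fsq"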

data

Data modules for different datasets.

Classes

BaseDataModule
BaseDataModule(
    data_path,
    batch_size=4,
    num_workers=0,
    pin_memory=True,
    **kwargs
)

Bases: LightningDataModule, ABC

Abstract base class for data modules.

All dataset-specific data modules should inherit from this class.

Source code in embeddings_squeeze\data\base.py
def __init__(
    self,
    data_path: str,
    batch_size: int = 4,
    num_workers: int = 0,
    pin_memory: bool = True,
    **kwargs
):
    super().__init__()
    self.data_path = data_path
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.pin_memory = pin_memory
Functions
setup abstractmethod
setup(stage=None)

Setup datasets for training/validation/testing.

Parameters:

    stage (str): 'fit', 'validate', 'test', or None. Default: None.
Source code in embeddings_squeeze\data\base.py
@abstractmethod
def setup(self, stage: str = None):
    """
    Setup datasets for training/validation/testing.

    Args:
        stage: 'fit', 'validate', 'test', or None
    """
    pass
train_dataloader abstractmethod
train_dataloader(max_batches=None)

Return training dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def train_dataloader(self, max_batches: int = None):
    """Return training dataloader."""
    pass
val_dataloader abstractmethod
val_dataloader(max_batches=None)

Return validation dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def val_dataloader(self, max_batches: int = None):
    """Return validation dataloader."""
    pass
test_dataloader abstractmethod
test_dataloader(max_batches=None)

Return test dataloader.

Source code in embeddings_squeeze\data\base.py
@abstractmethod
def test_dataloader(self, max_batches: int = None):
    """Return test dataloader."""
    pass
OxfordPetDataModule
OxfordPetDataModule(
    data_path="./data",
    batch_size=4,
    num_workers=4,
    pin_memory=True,
    image_size=224,
    subset_size=None,
    **kwargs
)

Bases: BaseDataModule

Data module for Oxford-IIIT Pet segmentation dataset.

Source code in embeddings_squeeze\data\oxford_pet.py
def __init__(
    self,
    data_path: str = './data',
    batch_size: int = 4,
    num_workers: int = 4,
    pin_memory: bool = True,
    image_size: int = 224,
    subset_size: int = None,
    **kwargs
):
    super().__init__(data_path, batch_size, num_workers, pin_memory, **kwargs)

    self.image_size = image_size
    self.subset_size = subset_size

    # Define transforms
    self.transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    self.transform_mask = transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.CenterCrop(image_size),
        transforms.PILToTensor()
    ])

    # Dataset attributes
    self.train_dataset = None
    self.val_dataset = None
    self.test_dataset = None
Functions
setup
setup(stage=None)

Setup datasets.

Source code in embeddings_squeeze\data\oxford_pet.py
def setup(self, stage: str = None):
    """Setup datasets."""
    if stage == 'fit' or stage is None:
        # Check if dataset exists
        pet_path = os.path.join(self.data_path, 'oxford-iiit-pet')
        need_download = not os.path.exists(pet_path)

        # Load full dataset
        pet_dataset = OxfordIIITPet(
            root=self.data_path,
            split='trainval',
            target_types='segmentation',
            download=need_download
        )

        # Wrap with transforms
        wrapped_dataset = PetSegmentationDataset(
            pet_dataset, self.transform_image, self.transform_mask
        )

        # Create subset if specified
        if self.subset_size is not None:
            wrapped_dataset = Subset(wrapped_dataset, range(self.subset_size))

        # Split into train/val (80/20)
        total_size = len(wrapped_dataset)
        train_size = int(0.8 * total_size)

        self.train_dataset = Subset(wrapped_dataset, range(train_size))
        self.val_dataset = Subset(wrapped_dataset, range(train_size, total_size))

    if stage == 'test' or stage is None:
        # Load test dataset
        pet_dataset = OxfordIIITPet(
            root=self.data_path,
            split='test',
            target_types='segmentation',
            download=False
        )

        wrapped_dataset = PetSegmentationDataset(
            pet_dataset, self.transform_image, self.transform_mask
        )

        if self.subset_size is not None:
            wrapped_dataset = Subset(wrapped_dataset, range(min(self.subset_size, len(wrapped_dataset))))

        self.test_dataset = wrapped_dataset
train_dataloader
train_dataloader(max_batches=None)

Return training dataloader.

Source code in embeddings_squeeze\data\oxford_pet.py
def train_dataloader(self, max_batches: int = None):
    """Return training dataloader."""
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=max_batches is not None
    )
val_dataloader
val_dataloader(max_batches=None)

Return validation dataloader.

Source code in embeddings_squeeze\data\oxford_pet.py
def val_dataloader(self, max_batches: int = None):
    """Return validation dataloader."""
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=max_batches is not None
    )
test_dataloader
test_dataloader(max_batches=None)

Return test dataloader.

Source code in embeddings_squeeze\data\oxford_pet.py
def test_dataloader(self, max_batches: int = None):
    """Return test dataloader."""
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=max_batches is not None
    )
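
A short sketch of typical usage (a sketch only: the import path follows the "Source code" location above, and the printed shapes assume the default transforms and batch size shown here):

    from embeddings_squeeze.data.oxford_pet import OxfordPetDataModule

    dm = OxfordPetDataModule(data_path="./data", batch_size=8, image_size=224, subset_size=200)
    dm.setup("fit")

    images, masks = next(iter(dm.train_dataloader()))
    print(images.shape)   # e.g. torch.Size([8, 3, 224, 224])
    print(masks.shape)    # e.g. torch.Size([8, 1, 224, 224]); PILToTensor keeps the channel dim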

Modules

base

Base data module for PyTorch Lightning.

Classes

BaseDataModule (documented above).
oxford_pet

Oxford-IIIT Pet dataset data module.

Classes
PetSegmentationDataset
PetSegmentationDataset(
    pet_dataset, transform_image, transform_mask
)

Wrapper for Oxford-IIIT Pet dataset with proper transforms.

Source code in embeddings_squeeze\data\oxford_pet.py
def __init__(self, pet_dataset, transform_image, transform_mask):
    self.dataset = pet_dataset
    self.transform_image = transform_image
    self.transform_mask = transform_mask
OxfordPetDataModule (documented above).

loggers

Logger integrations for embeddings_squeeze.

Classes

ClearMLLogger
ClearMLLogger(task)

Wrapper for ClearML logging compatible with PyTorch Lightning. Supports scalar metrics, plots, images, and text logging.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def __init__(self, task: Task):
    self.task = task
    self.logger = task.get_logger() if task else None
Functions
log_metrics
log_metrics(metrics, step=None)

Log metrics to ClearML.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_metrics(self, metrics: dict, step: int = None):
    """Log metrics to ClearML."""
    if self.logger is None:
        return

    for key, value in metrics.items():
        # Split key into title and series (e.g., "train/loss" -> title="train", series="loss")
        if '/' in key:
            title, series = key.split('/', 1)
        else:
            title = 'metrics'
            series = key

        self.logger.report_scalar(
            title=title,
            series=series,
            value=value,
            iteration=step
        )
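
A small illustration of the title/series split described in the comment above (the metric names and values are made up):

    logger.log_metrics({"train/loss": 0.42, "train/iou": 0.71, "lr": 1e-4}, step=100)
    # "train/loss" -> graph "train", series "loss"
    # "lr" has no "/" -> graph "metrics", series "lr"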
log_scalar
log_scalar(title, series, value, iteration)

Log a single scalar value to ClearML.

Parameters:

    title (str): Graph title (e.g., "loss", "accuracy"). Required.
    series (str): Series name within the graph (e.g., "train", "val"). Required.
    value (float): Scalar value to log. Required.
    iteration (int): Iteration/step number. Required.

Example

    logger.log_scalar("loss", "train", 0.5, iteration=100)
    logger.log_scalar("loss", "val", 0.3, iteration=100)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_scalar(self, title: str, series: str, value: float, iteration: int):
    """
    Log a single scalar value to ClearML.

    Args:
        title: Graph title (e.g., "loss", "accuracy")
        series: Series name within the graph (e.g., "train", "val")
        value: Scalar value to log
        iteration: Iteration/step number

    Example:
        logger.log_scalar("loss", "train", 0.5, iteration=100)
        logger.log_scalar("loss", "val", 0.3, iteration=100)
    """
    if self.logger is None:
        return

    self.logger.report_scalar(
        title=title,
        series=series,
        value=value,
        iteration=iteration
    )
log_image
log_image(title, series, image, iteration)

Log an image to ClearML.

Parameters:

    title (str): Image title/group. Required.
    series (str): Series name (e.g., "predictions", "ground_truth"). Required.
    image: Image as a numpy array, (H, W) for grayscale or (H, W, C) for RGB. Supports uint8 (0-255) or float (0-1). Required.
    iteration (int): Iteration/step number. Required.

Example

    # Grayscale image
    img = np.eye(256, 256, dtype=np.uint8) * 255
    logger.log_image("predictions", "epoch_1", img, iteration=0)

    # RGB image
    img_rgb = np.zeros((256, 256, 3), dtype=np.uint8)
    img_rgb[:, :, 0] = 255  # Red channel
    logger.log_image("predictions", "epoch_1_rgb", img_rgb, iteration=0)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_image(self, title: str, series: str, image, iteration: int):
    """
    Log an image to ClearML.

    Args:
        title: Image title/group
        series: Series name (e.g., "predictions", "ground_truth")
        image: Image as numpy array (H, W) or (H, W, C) for grayscale/RGB
               Supports uint8 (0-255) or float (0-1)
        iteration: Iteration/step number

    Example:
        # Grayscale image
        img = np.eye(256, 256, dtype=np.uint8) * 255
        logger.log_image("predictions", "epoch_1", img, iteration=0)

        # RGB image
        img_rgb = np.zeros((256, 256, 3), dtype=np.uint8)
        img_rgb[:, :, 0] = 255  # Red channel
        logger.log_image("predictions", "epoch_1_rgb", img_rgb, iteration=0)
    """
    if self.logger is None:
        return

    self.logger.report_image(
        title=title,
        series=series,
        iteration=iteration,
        image=image
    )
log_images_batch
log_images_batch(title, series, images, iteration)

Log multiple images to ClearML.

Parameters:

    title (str): Image title/group. Required.
    series (str): Series name. Required.
    images (list): List of images (numpy arrays). Required.
    iteration (int): Iteration/step number. Required.

Example

    images = [img1, img2, img3]
    logger.log_images_batch("samples", "batch_0", images, iteration=0)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_images_batch(self, title: str, series: str, images: list, iteration: int):
    """
    Log multiple images to ClearML.

    Args:
        title: Image title/group
        series: Series name
        images: List of images (numpy arrays)
        iteration: Iteration/step number

    Example:
        images = [img1, img2, img3]
        logger.log_images_batch("samples", "batch_0", images, iteration=0)
    """
    if self.logger is None:
        return

    for idx, image in enumerate(images):
        self.logger.report_image(
            title=title,
            series=f"{series}_img_{idx}",
            iteration=iteration,
            image=image
        )
log_text
log_text(text, title='Info')

Log text to ClearML.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def log_text(self, text: str, title: str = "Info"):
    """Log text to ClearML."""
    if self.logger is None:
        return
    self.logger.report_text(text, print_console=True)
report_text
report_text(text)

Report text to ClearML (alias for log_text with default title).

Source code in embeddings_squeeze\loggers\clearml_logger.py
def report_text(self, text: str):
    """Report text to ClearML (alias for log_text with default title)."""
    self.log_text(text)
report_plotly
report_plotly(title, series, iteration, figure)

Report a Plotly figure to ClearML.

Parameters:

    title (str): Plot title/group. Required.
    series (str): Series name. Required.
    iteration (int): Iteration/step number. Required.
    figure: Plotly figure object. Required.

Example

    import plotly.graph_objects as go
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=[1,2,3], y=[4,5,6], mode='lines', name='data'))
    fig.update_layout(title="My Plot", xaxis_title="x", yaxis_title="y")
    logger.report_plotly("metrics", "loss", iteration=0, figure=fig)

Source code in embeddings_squeeze\loggers\clearml_logger.py
def report_plotly(self, title: str, series: str, iteration: int, figure):
    """
    Report a Plotly figure to ClearML.

    Args:
        title: Plot title/group
        series: Series name
        iteration: Iteration/step number
        figure: Plotly figure object

    Example:
        import plotly.graph_objects as go
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=[1,2,3], y=[4,5,6], mode='lines', name='data'))
        fig.update_layout(title="My Plot", xaxis_title="x", yaxis_title="y")
        logger.report_plotly("metrics", "loss", iteration=0, figure=fig)
    """
    if self.logger is None:
        return

    self.logger.report_plotly(
        title=title,
        series=series,
        iteration=iteration,
        figure=figure
    )
finalize
finalize()

Finalize logging and close task.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def finalize(self):
    """Finalize logging and close task."""
    if self.task:
        self.task.close()
ClearMLUploadCallback
ClearMLUploadCallback(
    task,
    clearml_logger=None,
    checkpoint_dir="checkpoints",
    embedding_dir="embeddings",
)

Bases: Callback

PyTorch Lightning callback for logging checkpoint and embedding paths to ClearML.

Automatically logs local file paths for:
- Latest checkpoint after each validation epoch
- Per-epoch validation embeddings

Usage

    from pytorch_lightning import Trainer
    from embeddings_squeeze.loggers import ClearMLUploadCallback, setup_clearml

    task = setup_clearml(project_name="my_project", task_name="experiment_1")
    logger = ClearMLLogger(task) if task else None
    callback = ClearMLUploadCallback(task, logger, checkpoint_dir="checkpoints")

    trainer = Trainer(callbacks=[callback], ...)

Initialize ClearML path logging callback.

Parameters:

    task (Task): ClearML Task object. Required.
    clearml_logger (ClearMLLogger): ClearML logger for text reporting (optional). Default: None.
    checkpoint_dir (str): Directory containing checkpoints. Default: 'checkpoints'.
    embedding_dir (str): Directory containing embeddings. Default: 'embeddings'.
Source code in embeddings_squeeze\loggers\clearml_logger.py
def __init__(self, task: Task, clearml_logger: ClearMLLogger = None, 
             checkpoint_dir: str = "checkpoints", embedding_dir: str = "embeddings"):
    """
    Initialize ClearML path logging callback.

    Args:
        task: ClearML Task object
        clearml_logger: ClearML logger for text reporting (optional)
        checkpoint_dir: Directory containing checkpoints
        embedding_dir: Directory containing embeddings
    """
    super().__init__()
    self.task = task
    self.clearml_logger = clearml_logger
    self.checkpoint_dir = checkpoint_dir
    self.embedding_dir = embedding_dir
Functions
on_validation_epoch_end
on_validation_epoch_end(trainer, pl_module)

Called after validation epoch ends.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def on_validation_epoch_end(self, trainer, pl_module):
    """Called after validation epoch ends."""
    if self.task is None:
        return

    # Log checkpoint path
    try:
        ckpt_path = self._find_latest_checkpoint()
        if ckpt_path:
            abs_path = os.path.abspath(ckpt_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Checkpoint saved: {abs_path}")
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed finding checkpoint: {e}")

    # Log embedding path
    try:
        emb_path = os.path.join(
            self.embedding_dir, 
            f"val_embedding_epoch{pl_module.current_epoch}.pt"
        )
        if os.path.exists(emb_path):
            abs_path = os.path.abspath(emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Embedding saved: {abs_path}")
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed logging embedding path: {e}")

Functions

setup_clearml
setup_clearml(project_name, task_name, auto_connect=True)

Setup ClearML with credentials from config file.

Parameters:

    project_name (str): ClearML project name. Required.
    task_name (str): ClearML task name. Required.
    auto_connect (bool): If True, automatically connect frameworks. Default: True.

Returns:

    Task object (or None if ClearML setup fails; see the source below).

Source code in embeddings_squeeze\loggers\clearml_logger.py
def setup_clearml(project_name: str, task_name: str, auto_connect: bool = True):
    """
    Setup ClearML with credentials from config file.

    Args:
        project_name: ClearML project name
        task_name: ClearML task name
        auto_connect: If True, automatically connect frameworks

    Returns:
        Task object
    """
    # Load credentials
    config_dir = Path(__file__).parent.parent / 'configs'

    try:
        creds = load_credentials(config_dir)

        # Set credentials
        clearml.Task.set_credentials(
            api_host=creds.get('api_host', 'https://api.clear.ml'),
            web_host=creds.get('web_host', 'https://app.clear.ml'),
            files_host=creds.get('files_host', 'https://files.clear.ml'),
            key=creds['api_key'],
            secret=creds['api_secret']
        )

        # Initialize task
        task = Task.init(
            project_name=project_name,
            task_name=task_name,
            auto_connect_frameworks=auto_connect
        )
        return task
    except FileNotFoundError as e:
        print(f"Warning: {e}")
        print("ClearML logging disabled. Using TensorBoard instead.")
        return None
    except Exception as e:
        print(f"Warning: Failed to setup ClearML: {e}")
        print("ClearML logging disabled. Using TensorBoard instead.")
        return None

Modules

clearml_logger

ClearML logger integration with credentials management.

Usage Examples

Setup ClearML

    task = setup_clearml(project_name="my_project", task_name="experiment_1")
    logger = ClearMLLogger(task)

Log scalar metrics (creates unified graphs)

    for i in range(100):
        logger.log_scalar("loss", "train", 1.0/(i+1), iteration=i)
        logger.log_scalar("loss", "val", 0.5/(i+1), iteration=i)

Log images (grayscale)

    import numpy as np
    img = np.eye(256, 256, dtype=np.uint8) * 255
    logger.log_image("predictions", "sample_1", img, iteration=0)

Log RGB images

    img_rgb = np.zeros((256, 256, 3), dtype=np.uint8)
    img_rgb[:, :, 0] = 255  # Red channel
    logger.log_image("predictions", "sample_rgb", img_rgb, iteration=0)

Log multiple images at once

    images = [img1, img2, img3]
    logger.log_images_batch("batch_samples", "epoch_1", images, iteration=0)

Log text

    logger.log_text("Training started successfully!")

Finalize

    logger.finalize()

Classes

ClearMLLogger (documented above).

ClearMLUploadCallback (documented above).
Functions
load_credentials
load_credentials(config_dir=None)

Load ClearML credentials from YAML file.

Parameters:

    config_dir (str): Directory containing clearml_credentials.yaml. If None, uses the configs directory in embeddings_squeeze. Default: None.

Returns:

    dict: Credentials dictionary.

Source code in embeddings_squeeze\loggers\clearml_logger.py
def load_credentials(config_dir: str = None):
    """
    Load ClearML credentials from YAML file.

    Args:
        config_dir: Directory containing clearml_credentials.yaml
                   If None, uses the configs directory in embeddings_squeeze

    Returns:
        dict: Credentials dictionary
    """
    if config_dir is None:
        # Default to embeddings_squeeze/configs
        current_file = Path(__file__)
        config_dir = current_file.parent.parent / 'configs'
    else:
        config_dir = Path(config_dir)

    creds_file = config_dir / 'clearml_credentials.yaml'

    if not creds_file.exists():
        raise FileNotFoundError(
            f"ClearML credentials file not found: {creds_file}\n"
            f"Please copy clearml_credentials.yaml.example to clearml_credentials.yaml "
            f"and fill in your credentials."
        )

    creds = OmegaConf.load(creds_file)
    return OmegaConf.to_container(creds, resolve=True)
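
A minimal call sketch. The expected YAML keys are assumptions inferred from setup_clearml above (api_key and api_secret are accessed there; the host keys fall back to the clear.ml defaults):

    # embeddings_squeeze/configs/clearml_credentials.yaml is expected to define
    # at least api_key and api_secret; api_host/web_host/files_host are optional.
    creds = load_credentials()            # defaults to embeddings_squeeze/configs
    print(sorted(creds.keys()))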
setup_clearml (documented above).

models

Model architectures and components.

Classes

VQWithProjection
VQWithProjection(
    input_dim,
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Vector Quantization (VQ-VAE) with projections

Uses EMA for codebook updates (no gradients needed for codebook). ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection (e.g., 2048 -> 64)
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Vector Quantization
    self.vq = VectorQuantize(
        dim=bottleneck_dim,
        codebook_size=codebook_size,
        decay=decay,  # EMA decay for codebook
        commitment_weight=commitment_weight  # Commitment loss weight
    )

    # Up projection (64 -> 2048)
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
FSQWithProjection
FSQWithProjection(input_dim, levels=None)

Bases: BaseQuantizer

Finite Scalar Quantization (FSQ)

Quantization without a codebook: each dimension is quantized independently. ~10 bits per vector at levels=[8,5,5,5].

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int, levels: list = None):
    super().__init__(input_dim)
    if levels is None:
        levels = [8, 5, 5, 5]  # 8*5*5*5 = 1000 codes ≈ 2^10

    self.num_levels = len(levels)

    # Projection to quantization space
    self.project_in = nn.Linear(input_dim, self.num_levels)

    # FSQ quantization
    self.fsq = FSQ(levels=levels, dim=self.num_levels)

    # Projection back
    self.project_out = nn.Linear(self.num_levels, input_dim)
LFQWithProjection
LFQWithProjection(
    input_dim,
    codebook_size=512,
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
)

Bases: BaseQuantizer

Lookup-Free Quantization (LFQ)

Uses entropy loss for code diversity. ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512,
    entropy_loss_weight: float = 0.1, 
    diversity_gamma: float = 0.1, 
    spherical: bool = False
):
    super().__init__(input_dim)
    # Quantization dimension = log2(codebook_size)
    self.quant_dim = int(math.log2(codebook_size))

    # Projection with normalization
    self.project_in = nn.Sequential(
        nn.Linear(input_dim, self.quant_dim),
        nn.LayerNorm(self.quant_dim)
    )

    # LFQ quantization
    self.lfq = LFQ(
        dim=self.quant_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        spherical=spherical
    )

    # Projection back
    self.project_out = nn.Linear(self.quant_dim, input_dim)
ResidualVQWithProjection
ResidualVQWithProjection(
    input_dim,
    num_quantizers=4,
    codebook_size=256,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Residual Vector Quantization (RVQ)

Multi-level quantization: each level quantizes the residual of the previous one. 32 bits per vector at num_quantizers=4, codebook_size=256 (4*8 bits).

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    num_quantizers: int = 4,
    codebook_size: int = 256, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Residual VQ
    self.residual_vq = ResidualVQ(
        dim=bottleneck_dim,
        num_quantizers=num_quantizers,  # Number of levels
        codebook_size=codebook_size,
        decay=decay,
        commitment_weight=commitment_weight
    )

    # Up projection
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
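
The bit counts quoted for the four quantizers above follow directly from their defaults; a quick check using only the standard library:

    import math

    print(math.log2(512))            # VQ / LFQ with codebook_size=512 -> 9.0 bits per vector
    print(math.log2(8 * 5 * 5 * 5))  # FSQ with levels=[8, 5, 5, 5]    -> ~9.97 bits per vector
    print(4 * math.log2(256))        # RVQ, 4 quantizers of 256 codes  -> 32.0 bits per vector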
BaseQuantizer
BaseQuantizer(input_dim)

Bases: Module

Base class for all quantizers

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int):
    super().__init__()
    self.input_dim = input_dim
Functions
quantize_spatial
quantize_spatial(features)

Quantize spatial features [B, C, H, W]

Parameters:

    features (Tensor): Tensor of shape [B, C, H, W]. Required.

Returns:

    quantized: Quantized features [B, C, H, W]
    loss: Quantization loss (scalar)

Source code in embeddings_squeeze\models\quantizers.py
def quantize_spatial(self, features: torch.Tensor):
    """
    Quantize spatial features [B, C, H, W]

    Args:
        features: Tensor of shape [B, C, H, W]

    Returns:
        quantized: Quantized features [B, C, H, W]
        loss: Quantization loss (scalar)
    """
    B, C, H, W = features.shape
    # Transform [B, C, H, W] -> [B, H*W, C]
    seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Quantize
    quantized, indices, loss = self.forward(seq)

    # Transform back [B, H*W, C] -> [B, C, H, W]
    quantized = quantized.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Handle loss (may be tensor with multiple elements)
    if isinstance(loss, torch.Tensor) and loss.numel() > 1:
        loss = loss.mean()

    return quantized, loss
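
A usage sketch for quantize_spatial with one of the quantizers above. This is a sketch only: it assumes VQWithProjection is importable from embeddings_squeeze.models.quantizers (the source path shown above), that its forward follows the (quantized, indices, loss) convention quantize_spatial relies on, and the feature shape is just an example:

    import torch
    from embeddings_squeeze.models.quantizers import VQWithProjection

    quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)
    features = torch.randn(2, 2048, 14, 14)        # e.g. backbone feature maps [B, C, H, W]
    quantized, vq_loss = quantizer.quantize_spatial(features)
    print(quantized.shape)                          # torch.Size([2, 2048, 14, 14])
    print(vq_loss)                                  # scalar quantization loss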
DiceLoss
DiceLoss(smooth=1.0)

Bases: Module

Dice Loss for multi-class segmentation

Source code in embeddings_squeeze\models\losses.py
def __init__(self, smooth: float = 1.0):
    super().__init__()
    self.smooth = smooth
FocalLoss
FocalLoss(alpha=1.0, gamma=2.0, reduction='mean')

Bases: Module

Focal Loss for handling class imbalance (multi-class via CE per-pixel)

Source code in embeddings_squeeze\models\losses.py
def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction
CombinedLoss
CombinedLoss(
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.5,
    class_weights=None,
)

Bases: Module

Combined loss: CE + Dice + Focal. Returns (total, ce, dice, focal).

Source code in embeddings_squeeze\models\losses.py
def __init__(
    self, 
    ce_weight: float = 1.0, 
    dice_weight: float = 1.0, 
    focal_weight: float = 0.5, 
    class_weights=None
):
    super().__init__()
    # class_weights can be None or a tensor/list
    if class_weights is not None:
        # Leave tensor creation to forward (to place on correct device) but store raw
        self._class_weights = class_weights
    else:
        self._class_weights = None

    self.ce_weight = ce_weight
    self.dice_weight = dice_weight
    self.focal_weight = focal_weight

    # Instantiate component losses
    self.dice_loss = DiceLoss()
    self.focal_loss = FocalLoss()
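
A hedged sketch of how CombinedLoss might be called. Its forward is not shown on this page, so the (logits, targets) call signature and the shapes below are assumptions, chosen to be consistent with the (total, ce, dice, focal) return described above:

    import torch

    criterion = CombinedLoss(ce_weight=1.0, dice_weight=1.0, focal_weight=0.5)
    logits = torch.randn(4, 3, 64, 64)             # [B, num_classes, H, W]
    targets = torch.randint(0, 3, (4, 64, 64))     # [B, H, W] class indices
    total, ce, dice, focal = criterion(logits, targets)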
SegmentationBackbone
SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

    images: Input images [B, C, H, W]. Required.
    detach: Whether to detach gradients from backbone. Default: True.

Returns:

    features: Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
VQSqueezeModule
VQSqueezeModule(
    backbone,
    quantizer=None,
    num_classes=21,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    loss_type="ce",
    class_weights=None,
    add_adapter=False,
    feature_dim=2048,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for VQ compression training.

Features:
- Multiple quantizer support (VQ, FSQ, LFQ, RVQ)
- Adapter layers for fine-tuning frozen backbones
- Advanced loss functions (CE, Dice, Focal, Combined)
- Embedding extraction and saving

Source code in embeddings_squeeze\models\lightning_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    quantizer: Optional[nn.Module] = None,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    vq_loss_weight: float = 0.1,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    add_adapter: bool = False,
    feature_dim: int = 2048,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'quantizer', 'clearml_logger'])

    self.num_classes = num_classes
    self.learning_rate = learning_rate
    self.vq_loss_weight = vq_loss_weight
    self.loss_type = loss_type
    self.add_adapter = add_adapter
    self.feature_dim = feature_dim

    # Setup backbone with optional adapters
    self.backbone = backbone
    self._setup_backbone_with_adapters(feature_dim, add_adapter)

    # Quantizer (optional)
    self.quantizer = quantizer

    # Loss function
    self.criterion = self._init_loss(loss_type, class_weights)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

    # Embedding storage (per-epoch, first batch only)
    self.embedding_dir = "embeddings"
    os.makedirs(self.embedding_dir, exist_ok=True)
    self._first_val_batch_features = None

    # UMAP visualization storage
    self._val_backbone_embeddings = []
    self._val_quantized_embeddings = []
Functions
forward
forward(images)

Forward pass through backbone + optional quantizer + decoder.

Parameters:

    images: Input images [B, C, H, W]. Required.

Returns:

    output: Segmentation logits [B, num_classes, H, W]
    quant_loss: Quantization loss (0 if no quantizer)
    original_features: Extracted features (before quantization)
    quantized_features: Features after quantization (same as original if no quantizer)

Source code in embeddings_squeeze\models\lightning_module.py
def forward(self, images):
    """
    Forward pass through backbone + optional quantizer + decoder.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
        quant_loss: Quantization loss (0 if no quantizer)
        original_features: Extracted features (before quantization)
        quantized_features: Features after quantization (same as original if no quantizer)
    """
    # Extract features
    features = self.backbone.extract_features(images, detach=self.feature_adapter is not None)

    # Apply adapter if present
    if self.feature_adapter is not None:
        features = features + self.feature_adapter(features)

    # Store original features for embedding extraction
    original_features = features

    # Quantize if quantizer is present
    quant_loss = torch.tensor(0.0, device=images.device)
    quantized_features = original_features  # Default to original if no quantizer
    if self.quantizer is not None:
        features, quant_loss = self.quantizer.quantize_spatial(features)
        quantized_features = features

    # Decode to segmentation logits
    output = self.backbone.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear', align_corners=False)

    return output, quant_loss, original_features, quantized_features
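
An illustrative forward call showing the four returned values (shapes follow the docstring above; module is the VQSqueezeModule from the construction sketch earlier, built on a DeepLabV3 backbone):

import torch

images = torch.randn(2, 3, 256, 256)                 # [B, C, H, W]
logits, quant_loss, feats, q_feats = module(images)
# logits:     [2, num_classes, 256, 256]
# quant_loss: scalar tensor (0 if no quantizer)
# feats / q_feats: feature maps before / after quantization
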
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\lightning_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, _, _ = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\lightning_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, backbone_features, quantized_features = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    # Accumulate embeddings for UMAP visualization
    self._val_backbone_embeddings.append(backbone_features.detach().cpu())
    self._val_quantized_embeddings.append(quantized_features.detach().cpu())

    # Save only first batch features for this epoch
    if batch_idx == 0:
        self._first_val_batch_features = backbone_features.detach().cpu()

    return loss
on_validation_epoch_start
on_validation_epoch_start()

Clear accumulated embeddings at the start of each validation epoch.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_start(self):
    """Clear accumulated embeddings at the start of each validation epoch."""
    self._val_backbone_embeddings.clear()
    self._val_quantized_embeddings.clear()
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations and save embeddings.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations and save embeddings."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )

    # Generate UMAP visualizations on even epochs
    if self.current_epoch % 2 == 0:
        try:
            import umap.umap_ as umap_module

            # Only proceed if we have embeddings
            if len(self._val_backbone_embeddings) > 0 and len(self._val_quantized_embeddings) > 0:
                # Concatenate all accumulated embeddings
                backbone_emb_flat = torch.cat(self._val_backbone_embeddings, dim=0)
                quantized_emb_flat = torch.cat(self._val_quantized_embeddings, dim=0)

                # Flatten spatial dimensions: [B, C, H, W] -> [B*H*W, C]
                backbone_emb_flat = backbone_emb_flat.permute(0, 2, 3, 1).reshape(-1, backbone_emb_flat.shape[1])
                quantized_emb_flat = quantized_emb_flat.permute(0, 2, 3, 1).reshape(-1, quantized_emb_flat.shape[1])

                # Convert to numpy
                backbone_emb_np = backbone_emb_flat.numpy()
                quantized_emb_np = quantized_emb_flat.numpy()

                # Limit samples for performance (take subset if too large)
                max_samples = 10000
                if len(backbone_emb_np) > max_samples:
                    indices = np.random.choice(len(backbone_emb_np), max_samples, replace=False)
                    backbone_emb_np = backbone_emb_np[indices]
                    quantized_emb_np = quantized_emb_np[indices]

                # Generate 2D UMAP
                fig_2d, axs_2d = plt.subplots(1, 2, figsize=(12, 6))

                proj_2d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(backbone_emb_np)
                axs_2d[0].scatter(proj_2d_backbone[:, 0], proj_2d_backbone[:, 1], alpha=0.3)
                axs_2d[0].set_title('2D UMAP: Backbone Embeddings')

                proj_2d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(quantized_emb_np)
                axs_2d[1].scatter(proj_2d_quantized[:, 0], proj_2d_quantized[:, 1], alpha=0.3)
                axs_2d[1].set_title('2D UMAP: Quantized Embeddings')

                # Convert 2D plot to image and log
                fig_2d.canvas.draw()
                img_2d = np.frombuffer(fig_2d.canvas.tostring_rgb(), dtype=np.uint8)
                img_2d = img_2d.reshape(fig_2d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_2d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"2d_embeddings_epoch_{self.current_epoch}", 
                        img_2d, 
                        iteration=self.current_epoch
                    )

                # Generate 3D UMAP
                fig_3d = plt.figure(figsize=(12, 6))
                ax1 = fig_3d.add_subplot(121, projection='3d')
                ax2 = fig_3d.add_subplot(122, projection='3d')

                proj_3d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(backbone_emb_np)
                ax1.scatter(proj_3d_backbone[:, 0], proj_3d_backbone[:, 1], proj_3d_backbone[:, 2], alpha=0.3)
                ax1.set_title('3D UMAP: Backbone Embeddings')

                proj_3d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(quantized_emb_np)
                ax2.scatter(proj_3d_quantized[:, 0], proj_3d_quantized[:, 1], proj_3d_quantized[:, 2], alpha=0.3)
                ax2.set_title('3D UMAP: Quantized Embeddings')

                # Convert 3D plot to image and log
                fig_3d.canvas.draw()
                img_3d = np.frombuffer(fig_3d.canvas.tostring_rgb(), dtype=np.uint8)
                img_3d = img_3d.reshape(fig_3d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_3d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"3d_embeddings_epoch_{self.current_epoch}", 
                        img_3d, 
                        iteration=self.current_epoch
                    )

            # Clear accumulated embeddings after logging
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

        except Exception as e:
            if self.clearml_logger:
                self.clearml_logger.report_text(
                    f"UMAP visualization failed at epoch {self.current_epoch}: {e}"
                )
            # Clear embeddings even if visualization failed
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

    # Save per-epoch embedding (first validation batch only)
    try:
        if self._first_val_batch_features is not None:
            emb_path = os.path.join(
                self.embedding_dir,
                f"val_embedding_epoch{self.current_epoch}.pt"
            )
            torch.save(self._first_val_batch_features, emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Saved small embedding: {emb_path}")
            # Reset for next epoch
            self._first_val_batch_features = None
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed saving epoch embedding: {e}")
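
The per-epoch file written above is a plain tensor, so it can be reloaded for offline analysis; an illustrative snippet (the path pattern matches the code above):

import torch

emb = torch.load("embeddings/val_embedding_epoch0.pt")   # first val-batch backbone features
print(emb.shape)                                          # [B, feature_dim, H', W']
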
configure_optimizers
configure_optimizers()

Configure optimizer - only trainable parameters.

Source code in embeddings_squeeze\models\lightning_module.py
def configure_optimizers(self):
    """Configure optimizer - only trainable parameters."""
    params = []

    # Add adapter parameters if present
    if self.feature_adapter is not None:
        params += list(self.feature_adapter.parameters())

    # Add quantizer parameters if present
    if self.quantizer is not None:
        params += list(self.quantizer.parameters())

    # Add backbone parameters if not frozen
    if self.feature_adapter is None:
        params += [p for p in self.backbone.parameters() if p.requires_grad]

    # Remove duplicates
    params = list({id(p): p for p in params}.values())

    if not params:
        raise ValueError("No trainable parameters found!")

    return torch.optim.AdamW(params, lr=self.learning_rate)
on_train_start
on_train_start()

Ensure frozen backbone stays in eval mode.

Source code in embeddings_squeeze\models\lightning_module.py
def on_train_start(self):
    """Ensure frozen backbone stays in eval mode."""
    if self.feature_adapter is not None:
        self.backbone.eval()
BaselineSegmentationModule
BaselineSegmentationModule(
    backbone,
    num_classes=21,
    learning_rate=0.0001,
    loss_type="ce",
    class_weights=None,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for baseline segmentation training.

Wraps segmentation backbone without Vector Quantization for comparison.

Source code in embeddings_squeeze\models\baseline_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'clearml_logger'])

    self.backbone = backbone
    self.num_classes = num_classes
    self.learning_rate = learning_rate

    # Segmentation loss
    if class_weights is not None:
        weight = torch.tensor(class_weights, dtype=torch.float32)
        self.seg_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    else:
        self.seg_criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger
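
A minimal training sketch for the baseline; the Trainer settings, the pytorch_lightning import name, and the DataLoaders (assumed to yield (images, masks) pairs) are illustrative:

import pytorch_lightning as pl

from embeddings_squeeze.models.backbones.deeplab import DeepLabV3SegmentationBackbone
from embeddings_squeeze.models.baseline_module import BaselineSegmentationModule

baseline = BaselineSegmentationModule(
    backbone=DeepLabV3SegmentationBackbone(num_classes=21),
    num_classes=21,
    learning_rate=1e-4,
)

# train_loader / val_loader: DataLoaders yielding (images, masks) batches (not shown)
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(baseline, train_dataloaders=train_loader, val_dataloaders=val_loader)
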
Functions
forward
forward(images)

Forward pass through backbone.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def forward(self, images):
    """
    Forward pass through backbone.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    output = self.backbone(images)
    # Handle both dict and tensor returns
    if isinstance(output, dict):
        return output['out']
    return output
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\baseline_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', seg_loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\baseline_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations.

Source code in embeddings_squeeze\models\baseline_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )
configure_optimizers
configure_optimizers()

Configure optimizer.

Source code in embeddings_squeeze\models\baseline_module.py
def configure_optimizers(self):
    """Configure optimizer."""
    # Optimize only trainable params
    params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=self.learning_rate)
predict
predict(images)

Predict segmentation masks.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
predictions

Segmentation predictions [B, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict(self, images):
    """
    Predict segmentation masks.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        predictions: Segmentation predictions [B, H, W]
    """
    self.eval()
    with torch.no_grad():
        output = self(images)
        predictions = output.argmax(dim=1)
    return predictions
predict_logits
predict_logits(images)

Predict segmentation logits.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
logits

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict_logits(self, images):
    """
    Predict segmentation logits.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        logits: Segmentation logits [B, num_classes, H, W]
    """
    self.eval()
    with torch.no_grad():
        logits = self(images)
    return logits
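
Illustrative inference calls, continuing the baseline sketch above (both methods switch the module to eval mode and run under no_grad):

import torch

images = torch.randn(4, 3, 256, 256)
masks = baseline.predict(images)            # [4, 256, 256] class indices
logits = baseline.predict_logits(images)    # [4, num_classes, 256, 256]
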

Modules

backbones

Segmentation backbone implementations.

Classes
SegmentationBackbone
SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
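
A hedged sketch of a custom backbone that satisfies this interface (purely illustrative, not part of the package; the import path is assumed from the source file location shown above):

import torch.nn as nn
import torch.nn.functional as F

from embeddings_squeeze.models.backbones.base import SegmentationBackbone


class ToyBackbone(SegmentationBackbone):
    """Tiny convolutional backbone implementing the abstract interface."""

    def __init__(self, num_classes=21, dim=64):
        super().__init__()
        self.encoder = nn.Conv2d(3, dim, kernel_size=3, stride=8, padding=1)
        self.classifier = nn.Conv2d(dim, num_classes, kernel_size=1)
        self._dim = dim
        self._num_classes = num_classes

    @property
    def feature_dim(self):
        return self._dim

    @property
    def num_classes(self):
        return self._num_classes

    def extract_features(self, images, detach=True):
        feats = self.encoder(images)
        return feats.detach() if detach else feats

    def forward(self, images):
        feats = self.extract_features(images, detach=False)
        logits = self.classifier(feats)
        logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
        # dict return mirrors the concrete backbones documented below
        return {'out': logits}
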
ViTSegmentationBackbone
ViTSegmentationBackbone(
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

ViT-based segmentation backbone.

Uses ViT-B/32 as backbone with custom segmentation head.

Source code in embeddings_squeeze\models\backbones\vit.py
def __init__(
    self,
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()
    base_vit = model_fn(weights=weights)
    self.backbone = _ViTBackboneWrapper(base_vit, freeze=freeze_backbone)
    self.classifier = _ViTSegmentationHead(self.backbone.hidden_dim, num_classes)

    self._num_classes = num_classes
Attributes
feature_dim property
feature_dim

Return ViT hidden dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract ViT backbone feature maps.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, hidden_dim, H/patch, W/patch]

Source code in embeddings_squeeze\models\backbones\vit.py
def extract_features(self, images, detach=True):
    """
    Extract ViT backbone feature maps.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, hidden_dim, H/patch, W/patch]
    """
    feats = self.backbone(images)['out']
    return feats.detach() if detach else feats
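
For the default ViT-B/32 weights (hidden dimension 768, patch size 32, 224x224 input resolution) a quick shape check might look like this (illustrative; import path assumed from the source file location shown above):

import torch

from embeddings_squeeze.models.backbones.vit import ViTSegmentationBackbone

vit_backbone = ViTSegmentationBackbone(num_classes=21)
feats = vit_backbone.extract_features(torch.randn(1, 3, 224, 224))
print(feats.shape)   # expected: [1, 768, 7, 7], i.e. [B, hidden_dim, H/patch, W/patch]
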
forward
forward(images)

Full ViT segmentation forward pass.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\vit.py
def forward(self, images):
    """
    Full ViT segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    logits = self.classifier(features)
    logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
    return {'out': logits}
DeepLabV3SegmentationBackbone
DeepLabV3SegmentationBackbone(
    weights_name="COCO_WITH_VOC_LABELS_V1",
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

DeepLabV3-ResNet50 segmentation backbone.

Uses pre-trained DeepLabV3-ResNet50 for segmentation.

Source code in embeddings_squeeze\models\backbones\deeplab.py
def __init__(
    self,
    weights_name='COCO_WITH_VOC_LABELS_V1',
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()

    weights = getattr(DeepLabV3_ResNet50_Weights, weights_name)
    model = deeplabv3_resnet50(weights=weights)

    self.backbone = model.backbone
    self.classifier = model.classifier

    self._num_classes = num_classes

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
Attributes
feature_dim property
feature_dim

Return DeepLabV3 feature dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract DeepLabV3 backbone features.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, 2048, H/8, W/8]

Source code in embeddings_squeeze\models\backbones\deeplab.py
def extract_features(self, images, detach=True):
    """
    Extract DeepLabV3 backbone features.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, 2048, H/8, W/8]
    """
    with torch.set_grad_enabled(not detach):
        features = self.backbone(images)['out']
    return features
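
An illustrative shape check consistent with the docstring above (DeepLabV3-ResNet50 uses output stride 8; import path assumed from the source file location shown above):

import torch

from embeddings_squeeze.models.backbones.deeplab import DeepLabV3SegmentationBackbone

deeplab_backbone = DeepLabV3SegmentationBackbone()
feats = deeplab_backbone.extract_features(torch.randn(1, 3, 256, 256))  # detach=True by default
print(feats.shape)   # expected: [1, 2048, 32, 32], i.e. [B, 2048, H/8, W/8]
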
forward
forward(images)

Full DeepLabV3 segmentation forward pass.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\deeplab.py
def forward(self, images):
    """
    Full DeepLabV3 segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    output = self.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear')
    return {'out': output}
Modules
base

Abstract base class for segmentation backbones.

Classes
SegmentationBackbone
SegmentationBackbone()

Bases: Module, ABC

Abstract base class for segmentation backbones.

All segmentation backbones should inherit from this class and implement the required methods for feature extraction and full segmentation.

Source code in embeddings_squeeze\models\backbones\base.py
def __init__(self):
    super().__init__()
Attributes
feature_dim abstractmethod property
feature_dim

Return the feature dimension.

num_classes abstractmethod property
num_classes

Return the number of output classes.

Functions
extract_features abstractmethod
extract_features(images, detach=True)

Extract features from input images.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, feature_dim, H', W']

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def extract_features(self, images, detach=True):
    """
    Extract features from input images.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, feature_dim, H', W']
    """
    pass
forward abstractmethod
forward(images)

Full forward pass for segmentation.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\base.py
@abstractmethod
def forward(self, images):
    """
    Full forward pass for segmentation.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    pass
deeplab

DeepLabV3-ResNet50 segmentation backbone implementation.

Classes
DeepLabV3SegmentationBackbone
DeepLabV3SegmentationBackbone(
    weights_name="COCO_WITH_VOC_LABELS_V1",
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

DeepLabV3-ResNet50 segmentation backbone.

Uses pre-trained DeepLabV3-ResNet50 for segmentation.

Source code in embeddings_squeeze\models\backbones\deeplab.py
def __init__(
    self,
    weights_name='COCO_WITH_VOC_LABELS_V1',
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()

    weights = getattr(DeepLabV3_ResNet50_Weights, weights_name)
    model = deeplabv3_resnet50(weights=weights)

    self.backbone = model.backbone
    self.classifier = model.classifier

    self._num_classes = num_classes

    if freeze_backbone:
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
Attributes
feature_dim property
feature_dim

Return DeepLabV3 feature dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract DeepLabV3 backbone features.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, 2048, H/8, W/8]

Source code in embeddings_squeeze\models\backbones\deeplab.py
def extract_features(self, images, detach=True):
    """
    Extract DeepLabV3 backbone features.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, 2048, H/8, W/8]
    """
    with torch.set_grad_enabled(not detach):
        features = self.backbone(images)['out']
    return features
forward
forward(images)

Full DeepLabV3 segmentation forward pass.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\deeplab.py
def forward(self, images):
    """
    Full DeepLabV3 segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    output = self.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear')
    return {'out': output}
vit

ViT-based segmentation backbone implementation.

Classes
ViTSegmentationBackbone
ViTSegmentationBackbone(
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone=True,
)

Bases: SegmentationBackbone

ViT-based segmentation backbone.

Uses ViT-B/32 as backbone with custom segmentation head.

Source code in embeddings_squeeze\models\backbones\vit.py
def __init__(
    self,
    model_fn=vit_b_32,
    weights=ViT_B_32_Weights.IMAGENET1K_V1,
    num_classes=21,
    freeze_backbone: bool = True,
):
    super().__init__()
    base_vit = model_fn(weights=weights)
    self.backbone = _ViTBackboneWrapper(base_vit, freeze=freeze_backbone)
    self.classifier = _ViTSegmentationHead(self.backbone.hidden_dim, num_classes)

    self._num_classes = num_classes
Attributes
feature_dim property
feature_dim

Return ViT hidden dimension.

num_classes property
num_classes

Return number of segmentation classes.

Functions
extract_features
extract_features(images, detach=True)

Extract ViT backbone feature maps.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required
detach

Whether to detach gradients from backbone

True

Returns:

Name Type Description
features

Feature maps [B, hidden_dim, H/patch, W/patch]

Source code in embeddings_squeeze\models\backbones\vit.py
def extract_features(self, images, detach=True):
    """
    Extract ViT backbone feature maps.

    Args:
        images: Input images [B, C, H, W]
        detach: Whether to detach gradients from backbone

    Returns:
        features: Feature maps [B, hidden_dim, H/patch, W/patch]
    """
    feats = self.backbone(images)['out']
    return feats.detach() if detach else feats
forward
forward(images)

Full ViT segmentation forward pass.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\backbones\vit.py
def forward(self, images):
    """
    Full ViT segmentation forward pass.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    features = self.backbone(images)['out']
    logits = self.classifier(features)
    logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
    return {'out': logits}
baseline_module

PyTorch Lightning module for baseline segmentation training without VQ.

Classes
BaselineSegmentationModule
BaselineSegmentationModule(
    backbone,
    num_classes=21,
    learning_rate=0.0001,
    loss_type="ce",
    class_weights=None,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for baseline segmentation training.

Wraps segmentation backbone without Vector Quantization for comparison.

Source code in embeddings_squeeze\models\baseline_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'clearml_logger'])

    self.backbone = backbone
    self.num_classes = num_classes
    self.learning_rate = learning_rate

    # Segmentation loss
    if class_weights is not None:
        weight = torch.tensor(class_weights, dtype=torch.float32)
        self.seg_criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    else:
        self.seg_criterion = nn.CrossEntropyLoss(ignore_index=255)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger
Functions
forward
forward(images)

Forward pass through backbone.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def forward(self, images):
    """
    Forward pass through backbone.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
    """
    output = self.backbone(images)
    # Handle both dict and tensor returns
    if isinstance(output, dict):
        return output['out']
    return output
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\baseline_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', seg_loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\baseline_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch
    masks = masks.squeeze(1).long()

    # Forward pass
    output = self(images)

    # Compute loss
    seg_loss = self.seg_criterion(output, masks)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', seg_loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    return seg_loss
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations.

Source code in embeddings_squeeze\models\baseline_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )
configure_optimizers
configure_optimizers()

Configure optimizer.

Source code in embeddings_squeeze\models\baseline_module.py
def configure_optimizers(self):
    """Configure optimizer."""
    # Optimize only trainable params
    params = [p for p in self.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=self.learning_rate)
predict
predict(images)

Predict segmentation masks.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
predictions

Segmentation predictions [B, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict(self, images):
    """
    Predict segmentation masks.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        predictions: Segmentation predictions [B, H, W]
    """
    self.eval()
    with torch.no_grad():
        output = self(images)
        predictions = output.argmax(dim=1)
    return predictions
predict_logits
predict_logits(images)

Predict segmentation logits.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
logits

Segmentation logits [B, num_classes, H, W]

Source code in embeddings_squeeze\models\baseline_module.py
def predict_logits(self, images):
    """
    Predict segmentation logits.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        logits: Segmentation logits [B, num_classes, H, W]
    """
    self.eval()
    with torch.no_grad():
        logits = self(images)
    return logits
lightning_module

PyTorch Lightning module for VQ compression training with advanced features. Supports multiple quantizers (VQ, FSQ, LFQ, RVQ), adapters, and loss functions.

Classes
VQSqueezeModule
VQSqueezeModule(
    backbone,
    quantizer=None,
    num_classes=21,
    learning_rate=0.0001,
    vq_loss_weight=0.1,
    loss_type="ce",
    class_weights=None,
    add_adapter=False,
    feature_dim=2048,
    clearml_logger=None,
    **kwargs
)

Bases: LightningModule

PyTorch Lightning module for VQ compression training.

Features:

- Multiple quantizer support (VQ, FSQ, LFQ, RVQ)
- Adapter layers for fine-tuning frozen backbones
- Advanced loss functions (CE, Dice, Focal, Combined)
- Embedding extraction and saving

Source code in embeddings_squeeze\models\lightning_module.py
def __init__(
    self,
    backbone: SegmentationBackbone,
    quantizer: Optional[nn.Module] = None,
    num_classes: int = 21,
    learning_rate: float = 1e-4,
    vq_loss_weight: float = 0.1,
    loss_type: str = 'ce',
    class_weights: Optional[list] = None,
    add_adapter: bool = False,
    feature_dim: int = 2048,
    clearml_logger: Optional[Any] = None,
    **kwargs
):
    super().__init__()
    self.save_hyperparameters(ignore=['backbone', 'quantizer', 'clearml_logger'])

    self.num_classes = num_classes
    self.learning_rate = learning_rate
    self.vq_loss_weight = vq_loss_weight
    self.loss_type = loss_type
    self.add_adapter = add_adapter
    self.feature_dim = feature_dim

    # Setup backbone with optional adapters
    self.backbone = backbone
    self._setup_backbone_with_adapters(feature_dim, add_adapter)

    # Quantizer (optional)
    self.quantizer = quantizer

    # Loss function
    self.criterion = self._init_loss(loss_type, class_weights)

    # Metrics
    self.train_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.val_iou = JaccardIndex(task="multiclass", num_classes=num_classes)
    self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
    self.train_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.val_prec = Precision(task="multiclass", num_classes=num_classes, average="macro")
    self.train_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.val_rec = Recall(task="multiclass", num_classes=num_classes, average="macro")
    self.train_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="macro")

    # Epoch-wise stats tracking for Plotly
    self.epoch_stats: Dict[str, list] = {
        "train_loss": [], "val_loss": [], 
        "train_iou": [], "val_iou": [],
        "train_precision": [], "val_precision": [], 
        "train_recall": [], "val_recall": [],
        "train_f1": [], "val_f1": []
    }

    # ClearML logger
    self.clearml_logger = clearml_logger

    # Embedding storage (per-epoch, first batch only)
    self.embedding_dir = "embeddings"
    os.makedirs(self.embedding_dir, exist_ok=True)
    self._first_val_batch_features = None

    # UMAP visualization storage
    self._val_backbone_embeddings = []
    self._val_quantized_embeddings = []
Functions
forward
forward(images)

Forward pass through backbone + optional quantizer + decoder.

Parameters:

Name Type Description Default
images

Input images [B, C, H, W]

required

Returns:

Name Type Description
output

Segmentation logits [B, num_classes, H, W]

quant_loss

Quantization loss (0 if no quantizer)

original_features

Extracted features (before quantization)

quantized_features

Features after quantization (same as original if no quantizer)

Source code in embeddings_squeeze\models\lightning_module.py
def forward(self, images):
    """
    Forward pass through backbone + optional quantizer + decoder.

    Args:
        images: Input images [B, C, H, W]

    Returns:
        output: Segmentation logits [B, num_classes, H, W]
        quant_loss: Quantization loss (0 if no quantizer)
        original_features: Extracted features (before quantization)
        quantized_features: Features after quantization (same as original if no quantizer)
    """
    # Extract features
    features = self.backbone.extract_features(images, detach=self.feature_adapter is not None)

    # Apply adapter if present
    if self.feature_adapter is not None:
        features = features + self.feature_adapter(features)

    # Store original features for embedding extraction
    original_features = features

    # Quantize if quantizer is present
    quant_loss = torch.tensor(0.0, device=images.device)
    quantized_features = original_features  # Default to original if no quantizer
    if self.quantizer is not None:
        features, quant_loss = self.quantizer.quantize_spatial(features)
        quantized_features = features

    # Decode to segmentation logits
    output = self.backbone.classifier(features)
    output = F.interpolate(output, size=images.shape[-2:], mode='bilinear', align_corners=False)

    return output, quant_loss, original_features, quantized_features
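A minimal usage sketch of this forward pass (illustrative only: it assumes a constructed VQSqueezeModule bound to the name module, and the batch shape is arbitrary):

import torch

# Dummy batch of 2 RGB images at 224x224 (shape chosen for illustration)
images = torch.randn(2, 3, 224, 224)

module.eval()
with torch.no_grad():
    logits, quant_loss, feats, quant_feats = module(images)

# logits has shape [2, num_classes, 224, 224]; quant_loss is 0.0 when no quantizer is configured
preds = logits.argmax(dim=1)  # per-pixel class predictions, [2, 224, 224]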
training_step
training_step(batch, batch_idx)

Training step.

Source code in embeddings_squeeze\models\lightning_module.py
def training_step(self, batch, batch_idx):
    """Training step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, _, _ = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.train_iou(output, masks)
    acc = self.train_acc(output, masks)
    prec = self.train_prec(output, masks)
    rec = self.train_rec(output, masks)
    f1 = self.train_f1(output, masks)

    # Log metrics
    self.log('train_step/loss', loss, on_step=True, on_epoch=False, prog_bar=False)

    self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('train/precision', prec, on_step=False, on_epoch=True)
    self.log('train/recall', rec, on_step=False, on_epoch=True)
    self.log('train/f1', f1, on_step=False, on_epoch=True)

    return loss
validation_step
validation_step(batch, batch_idx)

Validation step.

Source code in embeddings_squeeze\models\lightning_module.py
def validation_step(self, batch, batch_idx):
    """Validation step."""
    images, masks = batch

    # Handle mask dimensions
    if masks.dim() == 4:
        masks = masks.squeeze(1)
    masks = masks.long()

    # Forward pass
    output, quant_loss, backbone_features, quantized_features = self(images)

    # Compute loss
    loss = self._compute_loss(output, masks, quant_loss)

    # Compute metrics
    iou = self.val_iou(output, masks)
    acc = self.val_acc(output, masks)
    prec = self.val_prec(output, masks)
    rec = self.val_rec(output, masks)
    f1 = self.val_f1(output, masks)

    # Log metrics
    self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/iou', iou, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    self.log('val/precision', prec, on_step=False, on_epoch=True)
    self.log('val/recall', rec, on_step=False, on_epoch=True)
    self.log('val/f1', f1, on_step=False, on_epoch=True)

    # Accumulate embeddings for UMAP visualization
    self._val_backbone_embeddings.append(backbone_features.detach().cpu())
    self._val_quantized_embeddings.append(quantized_features.detach().cpu())

    # Save only first batch features for this epoch
    if batch_idx == 0:
        self._first_val_batch_features = backbone_features.detach().cpu()

    return loss
on_validation_epoch_start
on_validation_epoch_start()

Clear accumulated embeddings at the start of each validation epoch.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_start(self):
    """Clear accumulated embeddings at the start of each validation epoch."""
    self._val_backbone_embeddings.clear()
    self._val_quantized_embeddings.clear()
on_validation_epoch_end
on_validation_epoch_end()

Called after validation epoch ends - log Plotly visualizations and save embeddings.

Source code in embeddings_squeeze\models\lightning_module.py
def on_validation_epoch_end(self):
    """Called after validation epoch ends - log Plotly visualizations and save embeddings."""
    # Collect epoch stats from trainer callback metrics
    cm = self.trainer.callback_metrics

    def push_if_exists(k_from, k_to):
        """Helper to extract metrics from callback_metrics."""
        if k_from in cm:
            val = cm[k_from]
            try:
                v = float(val)
            except Exception:
                v = val.item()
            self.epoch_stats[k_to].append(v)

    # Push metrics to epoch_stats
    keys = [
        "train/loss", "val/loss", "train/iou", "val/iou",
        "train/precision", "val/precision", "train/recall", "val/recall",
        "train/f1", "val/f1"
    ]
    key_mapping = {
        "train/loss": "train_loss", "val/loss": "val_loss",
        "train/iou": "train_iou", "val/iou": "val_iou",
        "train/precision": "train_precision", "val/precision": "val_precision",
        "train/recall": "train_recall", "val/recall": "val_recall",
        "train/f1": "train_f1", "val/f1": "val_f1"
    }
    for k_from, k_to in key_mapping.items():
        push_if_exists(k_from, k_to)

    # Generate Plotly visualizations
    try:
        import plotly.graph_objects as go

        epoch = self.current_epoch
        epochs = list(range(len(self.epoch_stats["val_loss"])))

        # Loss plot
        fig_loss = go.Figure()
        if len(self.epoch_stats["train_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["train_loss"],
                mode="lines+markers", name="train_loss"
            ))
        if len(self.epoch_stats["val_loss"]) > 0:
            fig_loss.add_trace(go.Scatter(
                x=epochs, y=self.epoch_stats["val_loss"],
                mode="lines+markers", name="val_loss"
            ))
        fig_loss.update_layout(title="Loss", xaxis_title="epoch", yaxis_title="loss")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Loss", series="loss", iteration=epoch, figure=fig_loss
            )

        # Metrics plot
        fig_m = go.Figure()
        metrics_to_plot = [
            ("train_iou", "val_iou"),
            ("train_precision", "val_precision"),
            ("train_recall", "val_recall"),
            ("train_f1", "val_f1")
        ]
        for train_k, val_k in metrics_to_plot:
            if len(self.epoch_stats[train_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[train_k],
                    mode="lines+markers", name=train_k
                ))
            if len(self.epoch_stats[val_k]) > 0:
                fig_m.add_trace(go.Scatter(
                    x=epochs, y=self.epoch_stats[val_k],
                    mode="lines+markers", name=val_k
                ))
        fig_m.update_layout(title="Metrics", xaxis_title="epoch", yaxis_title="value")

        if self.clearml_logger:
            self.clearml_logger.report_plotly(
                title="Metrics", series="metrics", iteration=epoch, figure=fig_m
            )
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(
                f"Plotly reporting failed at epoch {self.current_epoch}: {e}"
            )

    # Generate UMAP visualizations on even epochs
    if self.current_epoch % 2 == 0:
        try:
            import umap.umap_ as umap_module

            # Only proceed if we have embeddings
            if len(self._val_backbone_embeddings) > 0 and len(self._val_quantized_embeddings) > 0:
                # Concatenate all accumulated embeddings
                backbone_emb_flat = torch.cat(self._val_backbone_embeddings, dim=0)
                quantized_emb_flat = torch.cat(self._val_quantized_embeddings, dim=0)

                # Flatten spatial dimensions: [B, C, H, W] -> [B*H*W, C]
                backbone_emb_flat = backbone_emb_flat.permute(0, 2, 3, 1).reshape(-1, backbone_emb_flat.shape[1])
                quantized_emb_flat = quantized_emb_flat.permute(0, 2, 3, 1).reshape(-1, quantized_emb_flat.shape[1])

                # Convert to numpy
                backbone_emb_np = backbone_emb_flat.numpy()
                quantized_emb_np = quantized_emb_flat.numpy()

                # Limit samples for performance (take subset if too large)
                max_samples = 10000
                if len(backbone_emb_np) > max_samples:
                    indices = np.random.choice(len(backbone_emb_np), max_samples, replace=False)
                    backbone_emb_np = backbone_emb_np[indices]
                    quantized_emb_np = quantized_emb_np[indices]

                # Generate 2D UMAP
                fig_2d, axs_2d = plt.subplots(1, 2, figsize=(12, 6))

                proj_2d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(backbone_emb_np)
                axs_2d[0].scatter(proj_2d_backbone[:, 0], proj_2d_backbone[:, 1], alpha=0.3)
                axs_2d[0].set_title('2D UMAP: Backbone Embeddings')

                proj_2d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine').fit_transform(quantized_emb_np)
                axs_2d[1].scatter(proj_2d_quantized[:, 0], proj_2d_quantized[:, 1], alpha=0.3)
                axs_2d[1].set_title('2D UMAP: Quantized Embeddings')

                # Convert 2D plot to image and log
                fig_2d.canvas.draw()
                img_2d = np.frombuffer(fig_2d.canvas.tostring_rgb(), dtype=np.uint8)
                img_2d = img_2d.reshape(fig_2d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_2d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"2d_embeddings_epoch_{self.current_epoch}", 
                        img_2d, 
                        iteration=self.current_epoch
                    )

                # Generate 3D UMAP
                fig_3d = plt.figure(figsize=(12, 6))
                ax1 = fig_3d.add_subplot(121, projection='3d')
                ax2 = fig_3d.add_subplot(122, projection='3d')

                proj_3d_backbone = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(backbone_emb_np)
                ax1.scatter(proj_3d_backbone[:, 0], proj_3d_backbone[:, 1], proj_3d_backbone[:, 2], alpha=0.3)
                ax1.set_title('3D UMAP: Backbone Embeddings')

                proj_3d_quantized = umap_module.UMAP(n_neighbors=3, min_dist=0.1, metric='cosine', n_components=3).fit_transform(quantized_emb_np)
                ax2.scatter(proj_3d_quantized[:, 0], proj_3d_quantized[:, 1], proj_3d_quantized[:, 2], alpha=0.3)
                ax2.set_title('3D UMAP: Quantized Embeddings')

                # Convert 3D plot to image and log
                fig_3d.canvas.draw()
                img_3d = np.frombuffer(fig_3d.canvas.tostring_rgb(), dtype=np.uint8)
                img_3d = img_3d.reshape(fig_3d.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig_3d)

                if self.clearml_logger:
                    self.clearml_logger.log_image(
                        "umap_visualizations", 
                        f"3d_embeddings_epoch_{self.current_epoch}", 
                        img_3d, 
                        iteration=self.current_epoch
                    )

            # Clear accumulated embeddings after logging
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

        except Exception as e:
            if self.clearml_logger:
                self.clearml_logger.report_text(
                    f"UMAP visualization failed at epoch {self.current_epoch}: {e}"
                )
            # Clear embeddings even if visualization failed
            self._val_backbone_embeddings.clear()
            self._val_quantized_embeddings.clear()

    # Save per-epoch embedding (first validation batch only)
    try:
        if self._first_val_batch_features is not None:
            emb_path = os.path.join(
                self.embedding_dir,
                f"val_embedding_epoch{self.current_epoch}.pt"
            )
            torch.save(self._first_val_batch_features, emb_path)
            if self.clearml_logger:
                self.clearml_logger.report_text(f"Saved small embedding: {emb_path}")
            # Reset for next epoch
            self._first_val_batch_features = None
    except Exception as e:
        if self.clearml_logger:
            self.clearml_logger.report_text(f"Failed saving epoch embedding: {e}")
configure_optimizers
configure_optimizers()

Configure optimizer - only trainable parameters.

Source code in embeddings_squeeze\models\lightning_module.py
def configure_optimizers(self):
    """Configure optimizer - only trainable parameters."""
    params = []

    # Add adapter parameters if present
    if self.feature_adapter is not None:
        params += list(self.feature_adapter.parameters())

    # Add quantizer parameters if present
    if self.quantizer is not None:
        params += list(self.quantizer.parameters())

    # Add backbone parameters if not frozen
    if self.feature_adapter is None:
        params += [p for p in self.backbone.parameters() if p.requires_grad]

    # Remove duplicates
    params = list({id(p): p for p in params}.values())

    if not params:
        raise ValueError("No trainable parameters found!")

    return torch.optim.AdamW(params, lr=self.learning_rate)
on_train_start
on_train_start()

Ensure frozen backbone stays in eval mode.

Source code in embeddings_squeeze\models\lightning_module.py
def on_train_start(self):
    """Ensure frozen backbone stays in eval mode."""
    if self.feature_adapter is not None:
        self.backbone.eval()
losses

Loss functions for segmentation tasks. Includes: Cross Entropy, Dice Loss, Focal Loss, and Combined Loss.

Classes
DiceLoss
DiceLoss(smooth=1.0)

Bases: Module

Dice Loss for multi-class segmentation

Source code in embeddings_squeeze\models\losses.py
def __init__(self, smooth: float = 1.0):
    super().__init__()
    self.smooth = smooth
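For reference, the smoothing term corresponds to the standard soft-Dice objective; the forward body is not reproduced in this excerpt, so the formula below is the conventional formulation rather than a verbatim transcription:

$$
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i t_i + s}{\sum_i p_i + \sum_i t_i + s}
$$

where p_i are the predicted class probabilities, t_i the one-hot targets, and s the smooth term (1.0 by default), averaged over classes in the multi-class case.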
FocalLoss
FocalLoss(alpha=1.0, gamma=2.0, reduction='mean')

Bases: Module

Focal Loss for handling class imbalance (multi-class via CE per-pixel)

Source code in embeddings_squeeze\models\losses.py
def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.reduction = reduction
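The per-pixel focal term follows the usual definition (again the conventional form, since the forward pass is not shown here):

$$
\mathrm{FL}(p_t) = -\alpha \, (1 - p_t)^{\gamma} \, \log p_t
$$

with alpha=1.0 and gamma=2.0 by default; larger gamma down-weights pixels that are already classified with high confidence.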
CombinedLoss
CombinedLoss(
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.5,
    class_weights=None,
)

Bases: Module

Combined loss: CE + Dice + Focal. Returns (total, ce, dice, focal).

Source code in embeddings_squeeze\models\losses.py
def __init__(
    self, 
    ce_weight: float = 1.0, 
    dice_weight: float = 1.0, 
    focal_weight: float = 0.5, 
    class_weights=None
):
    super().__init__()
    # class_weights can be None or a tensor/list
    if class_weights is not None:
        # Leave tensor creation to forward (to place on correct device) but store raw
        self._class_weights = class_weights
    else:
        self._class_weights = None

    self.ce_weight = ce_weight
    self.dice_weight = dice_weight
    self.focal_weight = focal_weight

    # Instantiate component losses
    self.dice_loss = DiceLoss()
    self.focal_loss = FocalLoss()
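A minimal usage sketch, assuming the module's forward accepts (logits, targets) and returns the (total, ce, dice, focal) tuple described in the docstring (the forward signature is not shown in this excerpt):

import torch

criterion = CombinedLoss(ce_weight=1.0, dice_weight=1.0, focal_weight=0.5)

logits = torch.randn(2, 21, 224, 224, requires_grad=True)  # [B, num_classes, H, W]
targets = torch.randint(0, 21, (2, 224, 224))              # [B, H, W] integer labels

total, ce, dice, focal = criterion(logits, targets)
total.backward()  # only the weighted total is typically optimized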
quantizers

Vector Quantization implementations using vector_quantize_pytorch library. Supports: VQ-VAE, FSQ, LFQ, and Residual VQ.

Classes
BaseQuantizer
BaseQuantizer(input_dim)

Bases: Module

Base class for all quantizers

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int):
    super().__init__()
    self.input_dim = input_dim
Functions
quantize_spatial
quantize_spatial(features)

Quantize spatial features [B, C, H, W]

Parameters:

    features (Tensor): Tensor of shape [B, C, H, W]. Required.

Returns:

    quantized: Quantized features [B, C, H, W]
    loss: Quantization loss (scalar)

Source code in embeddings_squeeze\models\quantizers.py
def quantize_spatial(self, features: torch.Tensor):
    """
    Quantize spatial features [B, C, H, W]

    Args:
        features: Tensor of shape [B, C, H, W]

    Returns:
        quantized: Quantized features [B, C, H, W]
        loss: Quantization loss (scalar)
    """
    B, C, H, W = features.shape
    # Transform [B, C, H, W] -> [B, H*W, C]
    seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Quantize
    quantized, indices, loss = self.forward(seq)

    # Transform back [B, H*W, C] -> [B, C, H, W]
    quantized = quantized.reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Handle loss (may be tensor with multiple elements)
    if isinstance(loss, torch.Tensor) and loss.numel() > 1:
        loss = loss.mean()

    return quantized, loss
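An illustrative round trip through quantize_spatial using VQWithProjection from this module (shapes are examples; it assumes the subclass forward follows the (quantized, indices, loss) convention used above):

import torch

quantizer = VQWithProjection(input_dim=2048, codebook_size=512, bottleneck_dim=64)

features = torch.randn(2, 2048, 14, 14)               # [B, C, H, W] backbone features
quantized, vq_loss = quantizer.quantize_spatial(features)

assert quantized.shape == features.shape              # spatial layout is preserved
print(float(vq_loss))                                 # scalar commitment/quantization loss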
VQWithProjection
VQWithProjection(
    input_dim,
    codebook_size=512,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Vector Quantization (VQ-VAE) with projections

Uses EMA for codebook updates (no gradients needed for codebook). ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection (e.g., 2048 -> 64)
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Vector Quantization
    self.vq = VectorQuantize(
        dim=bottleneck_dim,
        codebook_size=codebook_size,
        decay=decay,  # EMA decay for codebook
        commitment_weight=commitment_weight  # Commitment loss weight
    )

    # Up projection (64 -> 2048)
    self.project_out = nn.Linear(bottleneck_dim, input_dim)
FSQWithProjection
FSQWithProjection(input_dim, levels=None)

Bases: BaseQuantizer

Finite Scalar Quantization (FSQ)

Quantization without a codebook - each dimension is quantized independently. ~10 bits per vector at levels=[8,5,5,5].

Source code in embeddings_squeeze\models\quantizers.py
def __init__(self, input_dim: int, levels: list = None):
    super().__init__(input_dim)
    if levels is None:
        levels = [8, 5, 5, 5]  # 8*5*5*5 = 1000 codes ≈ 2^10

    self.num_levels = len(levels)

    # Projection to quantization space
    self.project_in = nn.Linear(input_dim, self.num_levels)

    # FSQ quantization
    self.fsq = FSQ(levels=levels, dim=self.num_levels)

    # Projection back
    self.project_out = nn.Linear(self.num_levels, input_dim)
LFQWithProjection
LFQWithProjection(
    input_dim,
    codebook_size=512,
    entropy_loss_weight=0.1,
    diversity_gamma=0.1,
    spherical=False,
)

Bases: BaseQuantizer

Lookup-Free Quantization (LFQ)

Uses entropy loss for code diversity. ~9 bits per vector at codebook_size=512.

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    codebook_size: int = 512,
    entropy_loss_weight: float = 0.1, 
    diversity_gamma: float = 0.1, 
    spherical: bool = False
):
    super().__init__(input_dim)
    # Quantization dimension = log2(codebook_size)
    self.quant_dim = int(math.log2(codebook_size))

    # Projection with normalization
    self.project_in = nn.Sequential(
        nn.Linear(input_dim, self.quant_dim),
        nn.LayerNorm(self.quant_dim)
    )

    # LFQ quantization
    self.lfq = LFQ(
        dim=self.quant_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        spherical=spherical
    )

    # Projection back
    self.project_out = nn.Linear(self.quant_dim, input_dim)
ResidualVQWithProjection
ResidualVQWithProjection(
    input_dim,
    num_quantizers=4,
    codebook_size=256,
    bottleneck_dim=64,
    decay=0.99,
    commitment_weight=0.25,
)

Bases: BaseQuantizer

Residual Vector Quantization (RVQ)

Multi-level quantization - each level quantizes the residual of the previous one. 32 bits per vector at num_quantizers=4, codebook_size=256 (4*8 bits).

Source code in embeddings_squeeze\models\quantizers.py
def __init__(
    self, 
    input_dim: int, 
    num_quantizers: int = 4,
    codebook_size: int = 256, 
    bottleneck_dim: int = 64,
    decay: float = 0.99, 
    commitment_weight: float = 0.25
):
    super().__init__(input_dim)
    self.bottleneck_dim = bottleneck_dim

    # Down projection
    self.project_in = nn.Linear(input_dim, bottleneck_dim)

    # Residual VQ
    self.residual_vq = ResidualVQ(
        dim=bottleneck_dim,
        num_quantizers=num_quantizers,  # Number of levels
        codebook_size=codebook_size,
        decay=decay,
        commitment_weight=commitment_weight
    )

    # Up projection
    self.project_out = nn.Linear(bottleneck_dim, input_dim)

squeeze

CLI script for VQ compression training.

Usage

python squeeze.py --model vit --dataset oxford_pet --num_vectors 128 --epochs 3

Functions

create_quantizer
create_quantizer(config)

Create quantizer based on config.

Source code in embeddings_squeeze\squeeze.py
def create_quantizer(config):
    """Create quantizer based on config."""
    if not config.quantizer.enabled:
        return None

    qtype = config.quantizer.type.lower()
    feature_dim = config.model.feature_dim

    if qtype == 'vq':
        return VQWithProjection(
            input_dim=feature_dim,
            codebook_size=config.quantizer.codebook_size,
            bottleneck_dim=config.quantizer.bottleneck_dim,
            decay=config.quantizer.decay,
            commitment_weight=config.quantizer.commitment_weight
        )
    elif qtype == 'fsq':
        return FSQWithProjection(
            input_dim=feature_dim,
            levels=config.quantizer.levels
        )
    elif qtype == 'lfq':
        return LFQWithProjection(
            input_dim=feature_dim,
            codebook_size=config.quantizer.codebook_size,
            entropy_loss_weight=config.quantizer.entropy_loss_weight,
            diversity_gamma=config.quantizer.diversity_gamma,
            spherical=config.quantizer.spherical
        )
    elif qtype == 'rvq':
        return ResidualVQWithProjection(
            input_dim=feature_dim,
            num_quantizers=config.quantizer.num_quantizers,
            codebook_size=config.quantizer.codebook_size,
            bottleneck_dim=config.quantizer.bottleneck_dim,
            decay=config.quantizer.decay,
            commitment_weight=config.quantizer.commitment_weight
        )
    elif qtype == 'none':
        return None
    else:
        raise ValueError(f"Unknown quantizer type: {qtype}")
create_backbone
create_backbone(config)

Create segmentation backbone based on config.

Source code in embeddings_squeeze\squeeze.py
def create_backbone(config):
    """Create segmentation backbone based on config."""
    # Auto-detect feature_dim based on backbone if not set or invalid
    if config.model.backbone.lower() == "vit":
        # ViT uses 768-dim features
        if config.model.feature_dim is None or config.model.feature_dim == 2048:
            config.model.feature_dim = 768
        backbone = ViTSegmentationBackbone(
            num_classes=config.model.num_classes,
            freeze_backbone=config.model.freeze_backbone
        )
    elif config.model.backbone.lower() == "deeplab":
        # DeepLab uses 2048-dim features
        if config.model.feature_dim is None:
            config.model.feature_dim = 2048
        backbone = DeepLabV3SegmentationBackbone(
            weights_name=config.model.deeplab_weights,
            num_classes=config.model.num_classes,
            freeze_backbone=config.model.freeze_backbone
        )
    else:
        raise ValueError(f"Unknown backbone: {config.model.backbone}")

    return backbone
create_data_module
create_data_module(config)

Create data module based on config.

Source code in embeddings_squeeze\squeeze.py
def create_data_module(config):
    """Create data module based on config."""
    if config.data.dataset.lower() == "oxford_pet":
        data_module = OxfordPetDataModule(
            data_path=config.data.data_path,
            batch_size=config.training.batch_size,
            num_workers=config.training.num_workers,
            pin_memory=config.training.pin_memory,
            image_size=config.data.image_size,
            subset_size=config.data.subset_size
        )
    else:
        raise ValueError(f"Unknown dataset: {config.data.dataset}")

    return data_module
setup_logging_and_callbacks
setup_logging_and_callbacks(config)

Setup logging and callbacks.

Source code in embeddings_squeeze\squeeze.py
def setup_logging_and_callbacks(config):
    """Setup logging and callbacks."""
    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Setup ClearML
    clearml_task = setup_clearml(
        project_name=config.logger.project_name,
        task_name=config.logger.task_name
    )

    # Create ClearML logger wrapper
    clearml_logger = ClearMLLogger(clearml_task) if clearml_task else None

    # TensorBoard logger for ClearML auto-logging
    pl_logger = TensorBoardLogger(
        save_dir=config.output_dir,
        name=config.experiment_name,
        version=None
    )

    # Callbacks
    checkpoint_dir = os.path.join(config.output_dir, config.experiment_name)
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch:02d}-{val/loss:.2f}',
        monitor=config.training.monitor,
        mode=config.training.mode,
        save_top_k=config.training.save_top_k,
        save_last=True
    )

    early_stop_callback = EarlyStopping(
        monitor=config.training.monitor,
        mode=config.training.mode,
        patience=5,
        verbose=True
    )

    callbacks = [checkpoint_callback, early_stop_callback]

    # Add ClearML upload callback if task exists
    if clearml_task:
        clearml_upload_callback = ClearMLUploadCallback(
            task=clearml_task,
            clearml_logger=clearml_logger,
            checkpoint_dir=checkpoint_dir,
            embedding_dir="embeddings"
        )
        callbacks.append(clearml_upload_callback)
        print("ClearML logging and upload enabled")

    return pl_logger, clearml_logger, callbacks
main
main()

Main training function.

Source code in embeddings_squeeze\squeeze.py
def main():
    """Main training function."""
    parser = argparse.ArgumentParser(description="VQ Compression Training")

    # Model arguments
    parser.add_argument("--model", type=str, default="vit", 
                       choices=["vit", "deeplab"], help="Backbone model")
    parser.add_argument("--num_classes", type=int, default=21,
                       help="Number of classes")
    parser.add_argument("--add_adapter", action="store_true",
                       help="Add adapter layers to frozen backbone")
    parser.add_argument("--feature_dim", type=int, default=None,
                       help="Feature dimension (auto-detected if not set)")
    parser.add_argument("--loss_type", type=str, default="ce",
                       choices=["ce", "dice", "focal", "combined"], help="Loss function type")

    # Quantizer arguments
    parser.add_argument("--quantizer_type", type=str, default="vq",
                       choices=["vq", "fsq", "lfq", "rvq", "none"], help="Quantizer type")
    parser.add_argument("--quantizer_enabled", action="store_true", default=True,
                       help="Enable quantization")
    parser.add_argument("--codebook_size", type=int, default=512,
                       help="Codebook size for VQ/LFQ/RVQ")
    parser.add_argument("--bottleneck_dim", type=int, default=64,
                       help="Bottleneck dimension for VQ/RVQ")
    parser.add_argument("--num_quantizers", type=int, default=4,
                       help="Number of quantizers for RVQ")

    # Logger arguments
    parser.add_argument("--use_clearml", action="store_true",
                       help="Use ClearML for logging")
    parser.add_argument("--project_name", type=str, default="embeddings_squeeze",
                       help="Project name for logging")
    parser.add_argument("--task_name", type=str, default=None,
                       help="Task name for logging (defaults to experiment_name)")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--vq_loss_weight", type=float, default=0.1,
                       help="VQ loss weight")
    parser.add_argument("--max_batches", type=int, default=None,
                       help="Limit number of batches per epoch for train/val/test")

    # Data arguments
    parser.add_argument("--dataset", type=str, default="oxford_pet",
                       choices=["oxford_pet"], help="Dataset name")
    parser.add_argument("--data_path", type=str, default="./data",
                       help="Path to dataset")
    parser.add_argument("--subset_size", type=int, default=None,
                       help="Subset size for quick testing")

    # Experiment arguments
    parser.add_argument("--output_dir", type=str, default="./outputs",
                       help="Output directory")
    parser.add_argument("--experiment_name", type=str, default="vq_squeeze",
                       help="Experiment name")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Other arguments
    parser.add_argument("--initialize_codebook", action="store_true",
                       help="Initialize codebook with k-means")
    parser.add_argument("--max_init_samples", type=int, default=50000,
                       help="Max samples for codebook initialization")

    args = parser.parse_args()

    # Set random seeds
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    # Create configuration
    config = get_default_config()

    # Set task_name from experiment_name if not provided
    args_dict = vars(args)
    if args_dict.get('task_name') is None:
        args_dict['task_name'] = args_dict.get('experiment_name', 'vq_squeeze')

    config = update_config_from_args(config, args_dict)

    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model: {config.model.backbone}")
    print(f"Dataset: {config.data.dataset}")
    print(f"Quantizer: {config.quantizer.type if config.quantizer.enabled else 'None'}")
    print(f"Loss type: {config.model.loss_type}")
    print(f"Epochs: {config.training.epochs}")

    # Create components
    # IMPORTANT: Create backbone first to auto-detect feature_dim
    backbone = create_backbone(config)

    # Now create quantizer with correct feature_dim
    quantizer = create_quantizer(config)

    data_module = create_data_module(config)

    # Setup logging and callbacks (do this before creating model to get clearml_logger)
    pl_logger, clearml_logger, callbacks = setup_logging_and_callbacks(config)

    model = VQSqueezeModule(
        backbone=backbone,
        quantizer=quantizer,
        num_classes=config.model.num_classes,
        learning_rate=config.training.learning_rate,
        vq_loss_weight=config.training.vq_loss_weight,
        loss_type=config.model.loss_type,
        class_weights=config.model.class_weights,
        add_adapter=config.model.add_adapter,
        feature_dim=config.model.feature_dim,
        clearml_logger=clearml_logger
    )

    # Setup data
    data_module.setup('fit')

    # Initialize codebook if requested (only for VQ-based quantizers)
    # Note: Codebook initialization is currently disabled in this version
    # if config.initialize_codebook and quantizer is not None:
    #     print("Initializing codebook with k-means...")
    #     initialize_codebook_from_data(
    #         quantizer,
    #         backbone,
    #         data_module.train_dataloader(max_batches=config.training.max_batches),
    #         model.device,
    #         max_samples=config.max_init_samples
    #     )

    # Create trainer
    trainer = pl.Trainer(
        max_epochs=config.training.epochs,
        logger=pl_logger,
        callbacks=callbacks,
        log_every_n_steps=config.training.log_every_n_steps,
        val_check_interval=config.training.val_check_interval,
        accelerator='auto', devices='auto',
        precision=16 if torch.cuda.is_available() else 32,
    )

    # Train
    print("Starting training...")
    trainer.fit(model, data_module)

    # Finalize ClearML logging
    if clearml_logger:
        print("Finalizing ClearML task...")
        clearml_logger.finalize()

test_integration

Test script to verify the complete VQ visualization workflow.

This script tests the integration of all components without running full training.

Functions

test_model_creation
test_model_creation()

Test that models can be created successfully.

Source code in embeddings_squeeze\test_integration.py
def test_model_creation():
    """Test that models can be created successfully."""
    print("Testing model creation...")

    # Create backbone
    backbone = ViTSegmentationBackbone(num_classes=21, freeze_backbone=True)
    print(f"✓ Backbone created: {type(backbone).__name__}")

    # Create VQ model
    vq_model = VQSqueezeModule(
        backbone=backbone,
        num_vectors=128,
        commitment_cost=0.25,
        learning_rate=1e-4,
        vq_loss_weight=0.1
    )
    print(f"✓ VQ model created: {type(vq_model).__name__}")

    # Create baseline model
    baseline_model = BaselineSegmentationModule(
        backbone=backbone,
        learning_rate=1e-4
    )
    print(f"✓ Baseline model created: {type(baseline_model).__name__}")

    return vq_model, baseline_model
test_data_loading
test_data_loading()

Test that data can be loaded successfully.

Source code in embeddings_squeeze\test_integration.py
def test_data_loading():
    """Test that data can be loaded successfully."""
    print("\nTesting data loading...")

    # Create data module with small subset
    data_module = OxfordPetDataModule(
        data_path="./data",
        batch_size=2,
        num_workers=0,
        pin_memory=False,
        image_size=224,
        subset_size=10  # Small subset for testing
    )

    try:
        data_module.setup('test')
        test_loader = data_module.test_dataloader()
        print(f"✓ Data module created with {len(test_loader.dataset)} samples")

        # Test loading a batch
        batch = next(iter(test_loader))
        images, masks = batch
        print(f"✓ Batch loaded: images {images.shape}, masks {masks.shape}")

        return test_loader
    except Exception as e:
        print(f"✗ Data loading failed: {e}")
        print("Note: This is expected if Oxford-IIIT Pet dataset is not downloaded")
        return None
test_model_inference
test_model_inference(vq_model, baseline_model, test_loader)

Test that models can run inference.

Source code in embeddings_squeeze\test_integration.py
def test_model_inference(vq_model, baseline_model, test_loader):
    """Test that models can run inference."""
    print("\nTesting model inference...")

    if test_loader is None:
        print("Skipping inference test - no data available")
        return False

    device = torch.device("cpu")  # Use CPU for testing
    vq_model.to(device)
    baseline_model.to(device)

    try:
        # Get a batch
        images, masks = next(iter(test_loader))
        images = images.to(device)

        # Test VQ model inference
        with torch.no_grad():
            vq_output, vq_loss = vq_model(images)
            vq_preds = vq_model.predict_with_vq(images)
        print(f"✓ VQ model inference: output {vq_output.shape}, preds {vq_preds.shape}")

        # Test baseline model inference
        with torch.no_grad():
            baseline_output = baseline_model(images)
            baseline_preds = baseline_model.predict(images)
        print(f"✓ Baseline model inference: output {baseline_output.shape}, preds {baseline_preds.shape}")

        return True
    except Exception as e:
        print(f"✗ Model inference failed: {e}")
        return False
test_visualization_utilities
test_visualization_utilities()

Test visualization utilities.

Source code in embeddings_squeeze\test_integration.py
def test_visualization_utilities():
    """Test visualization utilities."""
    print("\nTesting visualization utilities...")

    try:
        from utils.comparison import compute_sample_iou, find_best_worst_samples

        # Create dummy data
        pred = torch.randint(0, 21, (224, 224))
        target = torch.randint(0, 21, (224, 224))

        # Test IoU computation
        iou = compute_sample_iou(pred, target, num_classes=21)
        print(f"✓ IoU computation: {iou:.3f}")

        # Test sample ranking
        dummy_results = [
            (0, 0.8, torch.randn(3, 224, 224), torch.randint(0, 21, (224, 224)), torch.randint(0, 21, (224, 224))),
            (1, 0.3, torch.randn(3, 224, 224), torch.randint(0, 21, (224, 224)), torch.randint(0, 21, (224, 224))),
            (2, 0.9, torch.randn(3, 224, 224), torch.randint(0, 21, (224, 224)), torch.randint(0, 21, (224, 224))),
        ]

        best, worst = find_best_worst_samples(dummy_results, n_best=2, n_worst=1)
        print(f"✓ Sample ranking: {len(best)} best, {len(worst)} worst")

        return True
    except Exception as e:
        print(f"✗ Visualization utilities failed: {e}")
        return False
main
main()

Run all tests.

Source code in embeddings_squeeze\test_integration.py
def main():
    """Run all tests."""
    print("="*60)
    print("VQ VISUALIZATION SYSTEM - INTEGRATION TEST")
    print("="*60)

    # Test model creation
    vq_model, baseline_model = test_model_creation()

    # Test data loading
    test_loader = test_data_loading()

    # Test model inference
    inference_ok = test_model_inference(vq_model, baseline_model, test_loader)

    # Test visualization utilities
    utils_ok = test_visualization_utilities()

    # Summary
    print("\n" + "="*60)
    print("TEST SUMMARY")
    print("="*60)
    print("✓ Model creation: PASSED")
    print("✓ Data loading: PASSED" if test_loader is not None else "⚠ Data loading: SKIPPED (no dataset)")
    print("✓ Model inference: PASSED" if inference_ok else "✗ Model inference: FAILED")
    print("✓ Visualization utilities: PASSED" if utils_ok else "✗ Visualization utilities: FAILED")

    if inference_ok and utils_ok:
        print("\n🎉 All core components are working correctly!")
        print("\nNext steps:")
        print("1. Train baseline model: python train_baseline.py --epochs 3 --subset_size 100")
        print("2. Train VQ model: python squeeze.py --epochs 3 --subset_size 100")
        print("3. Run visualization: python visualize.py --vq_checkpoint <path> --baseline_checkpoint <path>")
    else:
        print("\n❌ Some components failed. Please check the errors above.")

    print("="*60)

test_simplified_baseline

Test script to verify the simplified baseline training approach.

Functions

test_backbone_classifier_training
test_backbone_classifier_training()

Test that backbone classifiers are trainable when backbone is frozen.

Source code in embeddings_squeeze\test_simplified_baseline.py
def test_backbone_classifier_training():
    """Test that backbone classifiers are trainable when backbone is frozen."""
    print("Testing backbone classifier training...")

    # Test ViT backbone
    print("\n--- ViT Backbone Test ---")
    vit_backbone = ViTSegmentationBackbone(num_classes=21, freeze_backbone=True)
    vit_model = BaselineSegmentationModule(vit_backbone)

    # Test DeepLab backbone
    print("\n--- DeepLab Backbone Test ---")
    deeplab_backbone = DeepLabV3SegmentationBackbone(num_classes=21, freeze_backbone=True)
    deeplab_model = BaselineSegmentationModule(deeplab_backbone)

    # Test inference
    print("\n--- Inference Test ---")
    dummy_input = torch.randn(1, 3, 224, 224)

    try:
        # Test ViT
        vit_output = vit_model(dummy_input)
        print(f"✓ ViT inference: {vit_output['out'].shape if isinstance(vit_output, dict) else vit_output.shape}")

        # Test DeepLab
        deeplab_output = deeplab_model(dummy_input)
        print(f"✓ DeepLab inference: {deeplab_output['out'].shape if isinstance(deeplab_output, dict) else deeplab_output.shape}")

        return True
    except Exception as e:
        print(f"✗ Inference failed: {e}")
        return False
test_optimizer_creation
test_optimizer_creation()

Test that optimizers can be created.

Source code in embeddings_squeeze\test_simplified_baseline.py
def test_optimizer_creation():
    """Test that optimizers can be created."""
    print("\n--- Optimizer Test ---")

    try:
        vit_backbone = ViTSegmentationBackbone(num_classes=21, freeze_backbone=True)
        vit_model = BaselineSegmentationModule(vit_backbone)

        optimizer = vit_model.configure_optimizers()
        print(f"✓ Optimizer created: {type(optimizer).__name__}")

        return True
    except Exception as e:
        print(f"✗ Optimizer creation failed: {e}")
        return False
main
main()

Run all tests.

Source code in embeddings_squeeze\test_simplified_baseline.py
def main():
    """Run all tests."""
    print("="*60)
    print("SIMPLIFIED BASELINE TRAINING TEST")
    print("="*60)

    # Test backbone classifier training
    backbone_ok = test_backbone_classifier_training()

    # Test optimizer creation
    optimizer_ok = test_optimizer_creation()

    # Summary
    print("\n" + "="*60)
    print("TEST SUMMARY")
    print("="*60)
    print("✓ Backbone classifier training: PASSED" if backbone_ok else "✗ Backbone classifier training: FAILED")
    print("✓ Optimizer creation: PASSED" if optimizer_ok else "✗ Optimizer creation: FAILED")

    if backbone_ok and optimizer_ok:
        print("\n🎉 Simplified baseline training approach is working!")
        print("The backbone classifiers are trainable while backbones remain frozen.")
    else:
        print("\n❌ Some tests failed. Check the errors above.")

    print("="*60)

train_baseline

CLI script for baseline segmentation training without VQ.

This script trains the backbone's existing classifier head directly while keeping the backbone itself frozen.

Usage

python train_baseline.py --model vit --dataset oxford_pet --epochs 3

Functions

create_backbone
create_backbone(config)

Create segmentation backbone based on config.

Source code in embeddings_squeeze\train_baseline.py
def create_backbone(config):
    """Create segmentation backbone based on config."""
    # Auto-detect feature_dim based on backbone if not set or invalid
    if config.model.backbone.lower() == "vit":
        # ViT uses 768-dim features
        if config.model.feature_dim is None or config.model.feature_dim == 2048:
            config.model.feature_dim = 768
        backbone = ViTSegmentationBackbone(
            num_classes=config.model.num_classes,
            freeze_backbone=True  # Always freeze backbone, train only classifier
        )
    elif config.model.backbone.lower() == "deeplab":
        # DeepLab uses 2048-dim features
        if config.model.feature_dim is None:
            config.model.feature_dim = 2048
        backbone = DeepLabV3SegmentationBackbone(
            weights_name=config.model.deeplab_weights,
            num_classes=config.model.num_classes,
            freeze_backbone=True  # Always freeze backbone, train only classifier
        )
    else:
        raise ValueError(f"Unknown backbone: {config.model.backbone}")

    return backbone
create_data_module
create_data_module(config)

Create data module based on config.

Source code in embeddings_squeeze\train_baseline.py
def create_data_module(config):
    """Create data module based on config."""
    if config.data.dataset.lower() == "oxford_pet":
        data_module = OxfordPetDataModule(
            data_path=config.data.data_path,
            batch_size=config.training.batch_size,
            num_workers=config.training.num_workers,
            pin_memory=config.training.pin_memory,
            image_size=config.data.image_size,
            subset_size=config.data.subset_size
        )
    else:
        raise ValueError(f"Unknown dataset: {config.data.dataset}")

    return data_module
setup_logging_and_callbacks
setup_logging_and_callbacks(config)

Setup logging and callbacks.

Source code in embeddings_squeeze\train_baseline.py
def setup_logging_and_callbacks(config):
    """Setup logging and callbacks."""
    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Setup ClearML
    if config.logger.use_clearml:
        clearml_task = setup_clearml(
            project_name=config.logger.project_name,
            task_name=config.logger.task_name
        )
    else:
        clearml_task = None

    # Create ClearML logger wrapper
    clearml_logger = ClearMLLogger(clearml_task) if clearml_task else None

    # PyTorch Lightning logger (None for ClearML auto-logging, TensorBoard otherwise)
    if clearml_task:
        pl_logger = None
        print("Using ClearML for logging")
    else:
        pl_logger = TensorBoardLogger(
            save_dir=config.output_dir,
            name=config.experiment_name,
            version=None
        )
        print("Using TensorBoard for logging")

    # Callbacks
    monitor_metric = 'val/loss'  # Unified metric name
    checkpoint_dir = os.path.join(config.output_dir, config.experiment_name)

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch:02d}-{val/loss:.2f}',
        monitor=monitor_metric,
        mode='min',  # Minimize loss
        save_top_k=config.training.save_top_k,
        save_last=True
    )

    early_stop_callback = EarlyStopping(
        monitor=monitor_metric,
        mode='min',  # Minimize loss
        patience=5,
        verbose=True
    )

    callbacks = [checkpoint_callback, early_stop_callback]

    # Add ClearML upload callback if task exists
    if clearml_task:
        clearml_upload_callback = ClearMLUploadCallback(
            task=clearml_task,
            clearml_logger=clearml_logger,
            checkpoint_dir=checkpoint_dir,
            embedding_dir="embeddings"
        )
        callbacks.append(clearml_upload_callback)
        print("ClearML upload callback enabled")

    return pl_logger, clearml_logger, callbacks
main
main()

Main training function.

Source code in embeddings_squeeze\train_baseline.py
def main():
    """Main training function."""
    parser = argparse.ArgumentParser(description="Baseline Segmentation Training")

    # Model arguments
    parser.add_argument("--model", type=str, default="vit", 
                       choices=["vit", "deeplab"], help="Backbone model")
    parser.add_argument("--num_classes", type=int, default=21,
                       help="Number of classes")
    parser.add_argument("--loss_type", type=str, default="ce",
                       choices=["ce", "dice", "focal", "combined"], help="Loss function type")

    # Logger arguments
    parser.add_argument("--use_clearml", action="store_true",
                       help="Use ClearML for logging")
    parser.add_argument("--project_name", type=str, default="embeddings_squeeze",
                       help="Project name for logging")
    parser.add_argument("--task_name", type=str, default=None,
                       help="Task name for logging (defaults to experiment_name)")

    # Training arguments
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--max_batches", type=int, default=None,
                       help="Limit number of batches per epoch for train/val/test")

    # Data arguments
    parser.add_argument("--dataset", type=str, default="oxford_pet",
                       choices=["oxford_pet"], help="Dataset name")
    parser.add_argument("--data_path", type=str, default="./data",
                       help="Path to dataset")
    parser.add_argument("--subset_size", type=int, default=None,
                       help="Subset size for quick testing")

    # Experiment arguments
    parser.add_argument("--output_dir", type=str, default="./outputs",
                       help="Output directory")
    parser.add_argument("--experiment_name", type=str, default="segmentation_baseline",
                       help="Experiment name")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Set random seeds
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    pl.seed_everything(args.seed)

    # Create configuration
    config = get_default_config()

    # Set task_name from experiment_name if not provided
    args_dict = vars(args)
    if args_dict.get('task_name') is None:
        args_dict['task_name'] = args_dict.get('experiment_name', 'segmentation_baseline')

    config = update_config_from_args(config, args_dict)

    # Override experiment name to indicate baseline
    config.experiment_name = f"{config.experiment_name}"

    print(f"Starting baseline experiment: {config.experiment_name}")
    print(f"Model: {config.model.backbone}")
    print(f"Dataset: {config.data.dataset}")
    print(f"Loss type: {config.model.loss_type}")
    print(f"Epochs: {config.training.epochs}")
    print("Training strategy: Frozen backbone + trainable classifier")

    # Create components
    backbone = create_backbone(config)
    data_module = create_data_module(config)

    # Setup logging and callbacks (do this before creating model to get clearml_logger)
    pl_logger, clearml_logger, callbacks = setup_logging_and_callbacks(config)

    model = BaselineSegmentationModule(
        backbone=backbone,
        num_classes=config.model.num_classes,
        learning_rate=config.training.learning_rate,
        loss_type=config.model.loss_type,
        class_weights=config.model.class_weights,
        clearml_logger=clearml_logger
    )

    # Setup data
    data_module.setup('fit')

    # Create trainer
    trainer = pl.Trainer(
        max_epochs=config.training.epochs,
        logger=pl_logger,
        callbacks=callbacks,
        log_every_n_steps=config.training.log_every_n_steps,
        val_check_interval=config.training.val_check_interval,
        accelerator='auto', devices='auto',
        precision=16 if torch.cuda.is_available() else 32,
    )

    # Train
    print("Starting training...")
    trainer.fit(model, data_module)

    # Finalize ClearML logging
    if clearml_logger:
        print("Finalizing ClearML task...")
        clearml_logger.finalize()

    print(f"Baseline training completed!")
    print(f"Results saved to: {config.output_dir}/{config.experiment_name}")

utils

Utility functions for VQ compression.

Functions

measure_compression
measure_compression(
    vq_model, backbone, test_loader, device
)

Measure compression ratio achieved by VQ.

Parameters:

    vq_model: VectorQuantizer model. Required.
    backbone: Segmentation backbone. Required.
    test_loader: Test data loader. Required.
    device: Device to run on. Required.

Returns:

    compression_ratio: Compression ratio achieved

Source code in embeddings_squeeze\utils\compression.py
def measure_compression(vq_model, backbone, test_loader, device):
    """
    Measure compression ratio achieved by VQ.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        test_loader: Test data loader
        device: Device to run on

    Returns:
        compression_ratio: Compression ratio achieved
    """
    vq_model.eval()

    total_original_bits = 0
    total_compressed_bits = 0
    num_features = 0

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            features = backbone.extract_features(images)

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)

            # Original size (float32 = 32 bits)
            original_bits = feat_flat.numel() * 32

            # Compressed size (only indices)
            num_codes = vq_model.num_vectors
            bits_per_index = np.ceil(np.log2(num_codes))
            compressed_bits = feat_flat.shape[0] * bits_per_index

            total_original_bits += original_bits
            total_compressed_bits += compressed_bits
            num_features += feat_flat.shape[0]

    compression_ratio = total_original_bits / total_compressed_bits

    print("="*70)
    print("COMPRESSION ANALYSIS")
    print("="*70)
    print(f"Total features processed: {num_features}")
    print(f"Feature dimension: {vq_model.vector_dim}")
    print(f"Codebook size: {vq_model.num_vectors}")
    print()
    print(f"Original storage:")
    print(f"  Per feature: {vq_model.vector_dim} × 32 bits = {vq_model.vector_dim * 32} bits = {vq_model.vector_dim * 4} bytes")
    print(f"  Total: {total_original_bits / 8 / 1024 / 1024:.2f} MB")
    print()
    print(f"Compressed storage:")
    print(f"  Per feature: {bits_per_index:.0f} bits (index)")
    print(f"  Total: {total_compressed_bits / 8 / 1024:.2f} KB")
    print(f"  + Codebook: {vq_model.num_vectors * vq_model.vector_dim * 4 / 1024:.2f} KB")
    print()
    print(f"Compression ratio: {compression_ratio:.1f}x")
    print(f"Space savings: {(1 - 1/compression_ratio)*100:.1f}%")
    print("="*70)

    return compression_ratio
compute_iou_metrics
compute_iou_metrics(
    predictions, targets, num_classes, ignore_index=255
)

Compute IoU metrics for segmentation.

Parameters:

Name Type Description Default
predictions

Predicted masks [B, H, W]

required
targets

Ground truth masks [B, H, W]

required
num_classes

Number of classes

required
ignore_index

Index to ignore in computation

255

Returns:

Name Type Description
metrics

Dictionary with IoU metrics

Source code in embeddings_squeeze\utils\compression.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def compute_iou_metrics(predictions, targets, num_classes, ignore_index=255):
    """
    Compute IoU metrics for segmentation.

    Args:
        predictions: Predicted masks [B, H, W]
        targets: Ground truth masks [B, H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        metrics: Dictionary with IoU metrics
    """
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()

    # Flatten arrays
    pred_flat = predictions.flatten()
    target_flat = targets.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Compute mean IoU
    mean_iou = np.mean(ious)

    # Compute pixel accuracy
    pixel_acc = (pred_flat == target_flat).mean()

    metrics = {
        'mean_iou': mean_iou,
        'pixel_accuracy': pixel_acc,
        'class_ious': ious
    }

    return metrics
initialize_codebook_from_data
initialize_codebook_from_data(
    vq_model,
    backbone,
    train_loader,
    device,
    max_samples=50000,
)

Initialize codebook using k-means clustering on real data.

Parameters:

Name Type Description Default
vq_model

VectorQuantizer model

required
backbone

Segmentation backbone

required
train_loader

Training data loader

required
device

Device to run on

required
max_samples int

Maximum number of samples for k-means

50000
Source code in embeddings_squeeze\utils\initialization.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def initialize_codebook_from_data(
    vq_model, 
    backbone, 
    train_loader, 
    device, 
    max_samples: int = 50_000
):
    """
    Initialize codebook using k-means clustering on real data.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        train_loader: Training data loader
        device: Device to run on
        max_samples: Maximum number of samples for k-means
    """
    print("Collecting features for k-means initialization...")
    all_features = []

    backbone.eval()
    with torch.no_grad():
        i = 0
        for images, _ in train_loader:
            print(f"Processing batch {i}")
            features = backbone.extract_features(images.to(device))
            i += 1

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)
            all_features.append(feat_flat.cpu())

            if len(all_features) * feat_flat.shape[0] > max_samples:
                break

    all_features = torch.cat(all_features).numpy()

    print(f"Running k-means on {all_features.shape[0]} features...")
    kmeans = MiniBatchKMeans(
        n_clusters=vq_model.num_vectors,
        random_state=0,
        batch_size=1000,
        max_iter=100
    )
    kmeans.fit(all_features)

    # Update codebook with cluster centers
    vq_model.codebook.embeddings.data = torch.tensor(kmeans.cluster_centers_).to(device).float()

    print(f"Codebook initialized:")
    print(f"  Mean: {vq_model.codebook.embeddings.mean():.2f}")
    print(f"  Std: {vq_model.codebook.embeddings.std():.2f}")
    print(f"  Norm: {vq_model.codebook.embeddings.norm(dim=1).mean():.2f}")
compute_sample_iou
compute_sample_iou(
    prediction, target, num_classes, ignore_index=255
)

Compute IoU for a single sample.

Parameters:

Name Type Description Default
prediction Tensor

Predicted mask [H, W]

required
target Tensor

Ground truth mask [H, W]

required
num_classes int

Number of classes

required
ignore_index int

Index to ignore in computation

255

Returns:

Name Type Description
iou float

Mean IoU across all classes

Source code in embeddings_squeeze\utils\comparison.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def compute_sample_iou(prediction: torch.Tensor, target: torch.Tensor, num_classes: int, ignore_index: int = 255) -> float:
    """
    Compute IoU for a single sample.

    Args:
        prediction: Predicted mask [H, W]
        target: Ground truth mask [H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        iou: Mean IoU across all classes
    """
    prediction = prediction.cpu().numpy()
    target = target.cpu().numpy()

    # Flatten arrays
    pred_flat = prediction.flatten()
    target_flat = target.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Return mean IoU
    return np.mean(ious)
evaluate_model
evaluate_model(model, dataloader, device, num_classes=21)

Evaluate model on dataset and collect results.

Parameters:

Name Type Description Default
model

Model to evaluate

required
dataloader

Data loader

required
device

Device to run on

required
num_classes int

Number of classes

21

Returns:

Name Type Description
results List[Tuple[int, float, Tensor, Tensor, Tensor]]

List of (sample_idx, iou, image, mask, prediction) tuples

Source code in embeddings_squeeze\utils\comparison.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def evaluate_model(model, dataloader, device, num_classes: int = 21) -> List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Evaluate model on dataset and collect results.

    Args:
        model: Model to evaluate
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes

    Returns:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.squeeze(1).long().to(device)

            # Get predictions
            if hasattr(model, 'predict'):
                predictions = model.predict(images)
            else:
                # For VQ model
                predictions = model.predict_with_vq(images)

            # Process each sample in batch
            for i in range(images.shape[0]):
                image = images[i]
                mask = masks[i]
                pred = predictions[i]

                # Compute IoU
                iou = compute_sample_iou(pred, mask, num_classes)

                # Store results
                sample_idx = batch_idx * dataloader.batch_size + i
                results.append((sample_idx, iou, image, mask, pred))

    return results
find_best_worst_samples
find_best_worst_samples(results, n_best=5, n_worst=5)

Find best and worst samples based on IoU.

Parameters:

Name Type Description Default
results List[Tuple[int, float, Tensor, Tensor, Tensor]]

List of (sample_idx, iou, image, mask, prediction) tuples

required
n_best int

Number of best samples to return

5
n_worst int

Number of worst samples to return

5

Returns:

Name Type Description
best_samples List

List of best sample tuples

worst_samples List

List of worst sample tuples

Source code in embeddings_squeeze\utils\comparison.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def find_best_worst_samples(results: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]], 
                           n_best: int = 5, n_worst: int = 5) -> Tuple[List, List]:
    """
    Find best and worst samples based on IoU.

    Args:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
        n_best: Number of best samples to return
        n_worst: Number of worst samples to return

    Returns:
        best_samples: List of best sample tuples
        worst_samples: List of worst sample tuples
    """
    # Sort by IoU (descending)
    sorted_results = sorted(results, key=lambda x: x[1], reverse=True)

    # Get best and worst
    best_samples = sorted_results[:n_best]
    worst_samples = sorted_results[-n_worst:]

    return best_samples, worst_samples
prepare_visualization_data
prepare_visualization_data(
    vq_model,
    baseline_model,
    dataloader,
    device,
    num_classes=21,
    n_best=5,
    n_worst=5,
)

Prepare data for visualization by running both models and ranking results.

Parameters:

Name Type Description Default
vq_model

VQ model

required
baseline_model

Baseline model

required
dataloader

Data loader

required
device

Device to run on

required
num_classes int

Number of classes

21
n_best int

Number of best samples

5
n_worst int

Number of worst samples

5

Returns:

Name Type Description
best_samples

List of best sample tuples with both predictions

worst_samples

List of worst sample tuples with both predictions

Source code in embeddings_squeeze\utils\comparison.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def prepare_visualization_data(vq_model, baseline_model, dataloader, device, 
                              num_classes: int = 21, n_best: int = 5, n_worst: int = 5):
    """
    Prepare data for visualization by running both models and ranking results.

    Args:
        vq_model: VQ model
        baseline_model: Baseline model
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes
        n_best: Number of best samples
        n_worst: Number of worst samples

    Returns:
        best_samples: List of best sample tuples with both predictions
        worst_samples: List of worst sample tuples with both predictions
    """
    # Evaluate VQ model
    print("Evaluating VQ model...")
    vq_results = evaluate_model(vq_model, dataloader, device, num_classes)

    # Evaluate baseline model
    print("Evaluating baseline model...")
    baseline_results = evaluate_model(baseline_model, dataloader, device, num_classes)

    # Combine results (assuming same order)
    combined_results = []
    for (idx1, iou1, img1, mask1, pred_vq), (idx2, iou2, img2, mask2, pred_baseline) in zip(vq_results, baseline_results):
        assert idx1 == idx2, "Sample indices don't match"
        assert torch.equal(img1, img2), "Images don't match"
        assert torch.equal(mask1, mask2), "Masks don't match"

        combined_results.append((idx1, iou1, img1, mask1, pred_baseline, pred_vq))

    # Find best and worst based on VQ IoU
    best_samples, worst_samples = find_best_worst_samples(combined_results, n_best, n_worst)

    return best_samples, worst_samples
visualize_comparison
visualize_comparison(
    samples, title, output_path, num_classes=21
)

Create visualization comparing baseline and VQ predictions.

Parameters:

Name Type Description Default
samples List[Tuple[int, float, Tensor, Tensor, Tensor, Tensor]]

List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples

required
title str

Figure title

required
output_path str

Path to save figure

required
num_classes int

Number of classes

21
Source code in embeddings_squeeze\utils\comparison.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def visualize_comparison(samples: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], 
                        title: str, output_path: str, num_classes: int = 21):
    """
    Create visualization comparing baseline and VQ predictions.

    Args:
        samples: List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples
        title: Figure title
        output_path: Path to save figure
        num_classes: Number of classes
    """
    n_samples = len(samples)
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples))

    if n_samples == 1:
        axes = axes.reshape(1, -1)

    colormap = create_segmentation_colormap(num_classes)

    for i, (idx, iou, image, mask, pred_baseline, pred_vq) in enumerate(samples):
        # Denormalize image
        image_vis = denormalize_image(image)

        # Convert to numpy for plotting
        image_np = image_vis.permute(1, 2, 0).cpu().numpy()
        mask_np = mask.cpu().numpy()
        pred_baseline_np = pred_baseline.cpu().numpy()
        pred_vq_np = pred_vq.cpu().numpy()

        # Plot original image
        axes[i, 0].imshow(image_np)
        axes[i, 0].set_title(f"Original Image\nSample {idx}")
        axes[i, 0].axis('off')

        # Plot ground truth
        axes[i, 1].imshow(mask_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # Plot baseline prediction
        axes[i, 2].imshow(pred_baseline_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 2].set_title("Baseline Prediction")
        axes[i, 2].axis('off')

        # Plot VQ prediction
        axes[i, 3].imshow(pred_vq_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 3].set_title(f"VQ Prediction\nIoU: {iou:.3f}")
        axes[i, 3].axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Visualization saved to: {output_path}")

Modules

comparison

Comparison utilities for VQ vs baseline segmentation evaluation.

Functions
compute_sample_iou
compute_sample_iou(
    prediction, target, num_classes, ignore_index=255
)

Compute IoU for a single sample.

Parameters:

    prediction (Tensor): Predicted mask [H, W]. Required.
    target (Tensor): Ground truth mask [H, W]. Required.
    num_classes (int): Number of classes. Required.
    ignore_index (int): Index to ignore in computation. Default: 255.

Returns:

    iou (float): Mean IoU across all classes.

Source code in embeddings_squeeze\utils\comparison.py
def compute_sample_iou(prediction: torch.Tensor, target: torch.Tensor, num_classes: int, ignore_index: int = 255) -> float:
    """
    Compute IoU for a single sample.

    Args:
        prediction: Predicted mask [H, W]
        target: Ground truth mask [H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        iou: Mean IoU across all classes
    """
    prediction = prediction.cpu().numpy()
    target = target.cpu().numpy()

    # Flatten arrays
    pred_flat = prediction.flatten()
    target_flat = target.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Return mean IoU
    return np.mean(ious)
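
A minimal illustrative call (toy tensors; the import path follows the source location shown above):

import torch
from embeddings_squeeze.utils.comparison import compute_sample_iou

# 2x2 toy masks with two classes; one pixel differs between prediction and target
pred = torch.tensor([[0, 1], [1, 1]])
target = torch.tensor([[0, 1], [0, 1]])

iou = compute_sample_iou(pred, target, num_classes=2)  # per-class IoUs of 0.5 and ~0.67, averaged
print(f"mean IoU: {iou:.3f}")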
evaluate_model
evaluate_model(model, dataloader, device, num_classes=21)

Evaluate model on dataset and collect results.

Parameters:

    model: Model to evaluate. Required.
    dataloader: Data loader. Required.
    device: Device to run on. Required.
    num_classes (int): Number of classes. Default: 21.

Returns:

    results (List[Tuple[int, float, Tensor, Tensor, Tensor]]): List of (sample_idx, iou, image, mask, prediction) tuples.

Source code in embeddings_squeeze\utils\comparison.py
def evaluate_model(model, dataloader, device, num_classes: int = 21) -> List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Evaluate model on dataset and collect results.

    Args:
        model: Model to evaluate
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes

    Returns:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.squeeze(1).long().to(device)

            # Get predictions
            if hasattr(model, 'predict'):
                predictions = model.predict(images)
            else:
                # For VQ model
                predictions = model.predict_with_vq(images)

            # Process each sample in batch
            for i in range(images.shape[0]):
                image = images[i]
                mask = masks[i]
                pred = predictions[i]

                # Compute IoU
                iou = compute_sample_iou(pred, mask, num_classes)

                # Store results
                sample_idx = batch_idx * dataloader.batch_size + i
                results.append((sample_idx, iou, image, mask, pred))

    return results
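
A short sketch of consuming the returned tuples; model, dataloader and device are assumed to be constructed elsewhere (e.g. as in visualize.main further below):

from embeddings_squeeze.utils.comparison import evaluate_model

results = evaluate_model(model, dataloader, device, num_classes=21)

# Each entry is (sample_idx, iou, image, mask, prediction)
ious = [iou for _, iou, _, _, _ in results]
print(f"{len(results)} samples, mean IoU = {sum(ious) / len(ious):.3f}")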
find_best_worst_samples
find_best_worst_samples(results, n_best=5, n_worst=5)

Find best and worst samples based on IoU.

Parameters:

    results (List[Tuple[int, float, Tensor, Tensor, Tensor]]): List of (sample_idx, iou, image, mask, prediction) tuples. Required.
    n_best (int): Number of best samples to return. Default: 5.
    n_worst (int): Number of worst samples to return. Default: 5.

Returns:

    best_samples (List): List of best sample tuples.
    worst_samples (List): List of worst sample tuples.

Source code in embeddings_squeeze\utils\comparison.py
def find_best_worst_samples(results: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]], 
                           n_best: int = 5, n_worst: int = 5) -> Tuple[List, List]:
    """
    Find best and worst samples based on IoU.

    Args:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
        n_best: Number of best samples to return
        n_worst: Number of worst samples to return

    Returns:
        best_samples: List of best sample tuples
        worst_samples: List of worst sample tuples
    """
    # Sort by IoU (descending)
    sorted_results = sorted(results, key=lambda x: x[1], reverse=True)

    # Get best and worst
    best_samples = sorted_results[:n_best]
    worst_samples = sorted_results[-n_worst:]

    return best_samples, worst_samples
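
An illustrative call with dummy tensors standing in for the image, mask and prediction entries:

import torch
from embeddings_squeeze.utils.comparison import find_best_worst_samples

dummy = torch.zeros(1)
results = [(i, iou, dummy, dummy, dummy) for i, iou in enumerate([0.91, 0.42, 0.77, 0.15, 0.63])]

best, worst = find_best_worst_samples(results, n_best=2, n_worst=2)
print([r[1] for r in best])   # [0.91, 0.77]
print([r[1] for r in worst])  # [0.42, 0.15]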
prepare_visualization_data
prepare_visualization_data(
    vq_model,
    baseline_model,
    dataloader,
    device,
    num_classes=21,
    n_best=5,
    n_worst=5,
)

Prepare data for visualization by running both models and ranking results.

Parameters:

    vq_model: VQ model. Required.
    baseline_model: Baseline model. Required.
    dataloader: Data loader. Required.
    device: Device to run on. Required.
    num_classes (int): Number of classes. Default: 21.
    n_best (int): Number of best samples. Default: 5.
    n_worst (int): Number of worst samples. Default: 5.

Returns:

    best_samples: List of best sample tuples with both predictions.
    worst_samples: List of worst sample tuples with both predictions.

Source code in embeddings_squeeze\utils\comparison.py
def prepare_visualization_data(vq_model, baseline_model, dataloader, device, 
                              num_classes: int = 21, n_best: int = 5, n_worst: int = 5):
    """
    Prepare data for visualization by running both models and ranking results.

    Args:
        vq_model: VQ model
        baseline_model: Baseline model
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes
        n_best: Number of best samples
        n_worst: Number of worst samples

    Returns:
        best_samples: List of best sample tuples with both predictions
        worst_samples: List of worst sample tuples with both predictions
    """
    # Evaluate VQ model
    print("Evaluating VQ model...")
    vq_results = evaluate_model(vq_model, dataloader, device, num_classes)

    # Evaluate baseline model
    print("Evaluating baseline model...")
    baseline_results = evaluate_model(baseline_model, dataloader, device, num_classes)

    # Combine results (assuming same order)
    combined_results = []
    for (idx1, iou1, img1, mask1, pred_vq), (idx2, iou2, img2, mask2, pred_baseline) in zip(vq_results, baseline_results):
        assert idx1 == idx2, "Sample indices don't match"
        assert torch.equal(img1, img2), "Images don't match"
        assert torch.equal(mask1, mask2), "Masks don't match"

        combined_results.append((idx1, iou1, img1, mask1, pred_baseline, pred_vq))

    # Find best and worst based on VQ IoU
    best_samples, worst_samples = find_best_worst_samples(combined_results, n_best, n_worst)

    return best_samples, worst_samples
denormalize_image
denormalize_image(
    image,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

Denormalize image for visualization.

Parameters:

    image (Tensor): Normalized image tensor [C, H, W]. Required.
    mean (Tuple[float, float, float]): Normalization mean. Default: (0.485, 0.456, 0.406).
    std (Tuple[float, float, float]): Normalization std. Default: (0.229, 0.224, 0.225).

Returns:

    denormalized (Tensor): Denormalized image tensor.

Source code in embeddings_squeeze\utils\comparison.py
def denormalize_image(image: torch.Tensor, mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), 
                     std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> torch.Tensor:
    """
    Denormalize image for visualization.

    Args:
        image: Normalized image tensor [C, H, W]
        mean: Normalization mean
        std: Normalization std

    Returns:
        denormalized: Denormalized image tensor
    """
    image = image.clone()
    for t, m, s in zip(image, mean, std):
        t.mul_(s).add_(m)
    return torch.clamp(image, 0, 1)
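
A round-trip sanity check (illustrative; normalization is done by hand here with the same ImageNet statistics that the defaults encode):

import torch
from embeddings_squeeze.utils.comparison import denormalize_image

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

image = torch.rand(3, 64, 64)          # values already in [0, 1]
normalized = (image - mean) / std      # what the data pipeline applies
restored = denormalize_image(normalized)

print(torch.allclose(image, restored, atol=1e-5))  # True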
create_segmentation_colormap
create_segmentation_colormap(num_classes=21)

Create a colormap for segmentation visualization.

Parameters:

    num_classes (int): Number of classes. Default: 21.

Returns:

    colormap (ListedColormap): Matplotlib colormap.

Source code in embeddings_squeeze\utils\comparison.py
def create_segmentation_colormap(num_classes: int = 21) -> ListedColormap:
    """
    Create a colormap for segmentation visualization.

    Args:
        num_classes: Number of classes

    Returns:
        colormap: Matplotlib colormap
    """
    # Generate distinct colors
    colors = sns.color_palette("husl", num_classes)
    colors = [(0, 0, 0)] + colors  # Add black for background
    return ListedColormap(colors)
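
An illustrative use of the colormap for rendering a toy label map (mirrors how visualize_comparison below plots masks):

import numpy as np
import matplotlib.pyplot as plt
from embeddings_squeeze.utils.comparison import create_segmentation_colormap

cmap = create_segmentation_colormap(num_classes=21)

labels = np.random.randint(0, 21, size=(64, 64))   # toy label map
plt.imshow(labels, cmap=cmap, vmin=0, vmax=20)
plt.axis('off')
plt.savefig("colormap_demo.png", dpi=150, bbox_inches='tight')
plt.close()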
visualize_comparison
visualize_comparison(
    samples, title, output_path, num_classes=21
)

Create visualization comparing baseline and VQ predictions.

Parameters:

    samples (List[Tuple[int, float, Tensor, Tensor, Tensor, Tensor]]): List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples. Required.
    title (str): Figure title. Required.
    output_path (str): Path to save figure. Required.
    num_classes (int): Number of classes. Default: 21.
Source code in embeddings_squeeze\utils\comparison.py
def visualize_comparison(samples: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], 
                        title: str, output_path: str, num_classes: int = 21):
    """
    Create visualization comparing baseline and VQ predictions.

    Args:
        samples: List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples
        title: Figure title
        output_path: Path to save figure
        num_classes: Number of classes
    """
    n_samples = len(samples)
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples))

    if n_samples == 1:
        axes = axes.reshape(1, -1)

    colormap = create_segmentation_colormap(num_classes)

    for i, (idx, iou, image, mask, pred_baseline, pred_vq) in enumerate(samples):
        # Denormalize image
        image_vis = denormalize_image(image)

        # Convert to numpy for plotting
        image_np = image_vis.permute(1, 2, 0).cpu().numpy()
        mask_np = mask.cpu().numpy()
        pred_baseline_np = pred_baseline.cpu().numpy()
        pred_vq_np = pred_vq.cpu().numpy()

        # Plot original image
        axes[i, 0].imshow(image_np)
        axes[i, 0].set_title(f"Original Image\nSample {idx}")
        axes[i, 0].axis('off')

        # Plot ground truth
        axes[i, 1].imshow(mask_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # Plot baseline prediction
        axes[i, 2].imshow(pred_baseline_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 2].set_title("Baseline Prediction")
        axes[i, 2].axis('off')

        # Plot VQ prediction
        axes[i, 3].imshow(pred_vq_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 3].set_title(f"VQ Prediction\nIoU: {iou:.3f}")
        axes[i, 3].axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Visualization saved to: {output_path}")
compression

Compression analysis and metrics utilities.

Functions
measure_compression
measure_compression(
    vq_model, backbone, test_loader, device
)

Measure compression ratio achieved by VQ.

Parameters:

    vq_model: VectorQuantizer model. Required.
    backbone: Segmentation backbone. Required.
    test_loader: Test data loader. Required.
    device: Device to run on. Required.

Returns:

    compression_ratio: Compression ratio achieved.

Source code in embeddings_squeeze\utils\compression.py
def measure_compression(vq_model, backbone, test_loader, device):
    """
    Measure compression ratio achieved by VQ.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        test_loader: Test data loader
        device: Device to run on

    Returns:
        compression_ratio: Compression ratio achieved
    """
    vq_model.eval()

    total_original_bits = 0
    total_compressed_bits = 0
    num_features = 0

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            features = backbone.extract_features(images)

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)

            # Original size (float32 = 32 bits)
            original_bits = feat_flat.numel() * 32

            # Compressed size (only indices)
            num_codes = vq_model.num_vectors
            bits_per_index = np.ceil(np.log2(num_codes))
            compressed_bits = feat_flat.shape[0] * bits_per_index

            total_original_bits += original_bits
            total_compressed_bits += compressed_bits
            num_features += feat_flat.shape[0]

    compression_ratio = total_original_bits / total_compressed_bits

    print("="*70)
    print("COMPRESSION ANALYSIS")
    print("="*70)
    print(f"Total features processed: {num_features}")
    print(f"Feature dimension: {vq_model.vector_dim}")
    print(f"Codebook size: {vq_model.num_vectors}")
    print()
    print(f"Original storage:")
    print(f"  Per feature: {vq_model.vector_dim} × 32 bits = {vq_model.vector_dim * 32} bits = {vq_model.vector_dim * 4} bytes")
    print(f"  Total: {total_original_bits / 8 / 1024 / 1024:.2f} MB")
    print()
    print(f"Compressed storage:")
    print(f"  Per feature: {bits_per_index:.0f} bits (index)")
    print(f"  Total: {total_compressed_bits / 8 / 1024:.2f} KB")
    print(f"  + Codebook: {vq_model.num_vectors * vq_model.vector_dim * 4 / 1024:.2f} KB")
    print()
    print(f"Compression ratio: {compression_ratio:.1f}x")
    print(f"Space savings: {(1 - 1/compression_ratio)*100:.1f}%")
    print("="*70)

    return compression_ratio
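
A back-of-the-envelope check of the numbers this routine prints, assuming (for illustration only) 2048-dimensional float32 features and a 512-entry codebook:

import numpy as np

feature_dim = 2048                                      # per-vector feature channels (illustrative)
codebook_size = 512

original_bits = feature_dim * 32                        # 65536 bits per float32 feature vector
bits_per_index = int(np.ceil(np.log2(codebook_size)))   # 9 bits per stored codebook index

print(original_bits / bits_per_index)                   # ~7281x, before adding the codebook itself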
compute_iou_metrics
compute_iou_metrics(
    predictions, targets, num_classes, ignore_index=255
)

Compute IoU metrics for segmentation.

Parameters:

    predictions: Predicted masks [B, H, W]. Required.
    targets: Ground truth masks [B, H, W]. Required.
    num_classes: Number of classes. Required.
    ignore_index: Index to ignore in computation. Default: 255.

Returns:

    metrics: Dictionary with IoU metrics.

Source code in embeddings_squeeze\utils\compression.py
def compute_iou_metrics(predictions, targets, num_classes, ignore_index=255):
    """
    Compute IoU metrics for segmentation.

    Args:
        predictions: Predicted masks [B, H, W]
        targets: Ground truth masks [B, H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        metrics: Dictionary with IoU metrics
    """
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()

    # Flatten arrays
    pred_flat = predictions.flatten()
    target_flat = targets.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Compute mean IoU
    mean_iou = np.mean(ious)

    # Compute pixel accuracy
    pixel_acc = (pred_flat == target_flat).mean()

    metrics = {
        'mean_iou': mean_iou,
        'pixel_accuracy': pixel_acc,
        'class_ious': ious
    }

    return metrics
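
A minimal illustrative call on a toy batch:

import torch
from embeddings_squeeze.utils.compression import compute_iou_metrics

preds = torch.tensor([[[0, 1], [1, 1]]])     # [B, H, W] predicted masks
targets = torch.tensor([[[0, 1], [0, 1]]])   # [B, H, W] ground truth

metrics = compute_iou_metrics(preds, targets, num_classes=2)
print(metrics['mean_iou'], metrics['pixel_accuracy'], metrics['class_ious'])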
initialization

Codebook initialization utilities.

Functions
initialize_codebook_from_data
initialize_codebook_from_data(
    vq_model,
    backbone,
    train_loader,
    device,
    max_samples=50000,
)

Initialize codebook using k-means clustering on real data.

Parameters:

    vq_model: VectorQuantizer model. Required.
    backbone: Segmentation backbone. Required.
    train_loader: Training data loader. Required.
    device: Device to run on. Required.
    max_samples (int): Maximum number of samples for k-means. Default: 50000.
Source code in embeddings_squeeze\utils\initialization.py
def initialize_codebook_from_data(
    vq_model, 
    backbone, 
    train_loader, 
    device, 
    max_samples: int = 50_000
):
    """
    Initialize codebook using k-means clustering on real data.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        train_loader: Training data loader
        device: Device to run on
        max_samples: Maximum number of samples for k-means
    """
    print("Collecting features for k-means initialization...")
    all_features = []

    backbone.eval()
    with torch.no_grad():
        i = 0
        for images, _ in train_loader:
            print(f"Processing batch {i}")
            features = backbone.extract_features(images.to(device))
            i += 1

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)
            all_features.append(feat_flat.cpu())

            if len(all_features) * feat_flat.shape[0] > max_samples:
                break

    all_features = torch.cat(all_features).numpy()

    print(f"Running k-means on {all_features.shape[0]} features...")
    kmeans = MiniBatchKMeans(
        n_clusters=vq_model.num_vectors,
        random_state=0,
        batch_size=1000,
        max_iter=100
    )
    kmeans.fit(all_features)

    # Update codebook with cluster centers
    vq_model.codebook.embeddings.data = torch.tensor(kmeans.cluster_centers_).to(device).float()

    print(f"Codebook initialized:")
    print(f"  Mean: {vq_model.codebook.embeddings.mean():.2f}")
    print(f"  Std: {vq_model.codebook.embeddings.std():.2f}")
    print(f"  Norm: {vq_model.codebook.embeddings.norm(dim=1).mean():.2f}")
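
A typical call, sketched under the assumption that vq_model, backbone, train_loader and device are already constructed elsewhere. Note that feature collection stops on the first batch that pushes the running total past max_samples, so slightly more than max_samples vectors may end up being clustered.

from embeddings_squeeze.utils.initialization import initialize_codebook_from_data

# Runs MiniBatchKMeans on the collected feature vectors and copies the
# cluster centers into vq_model.codebook.embeddings before training starts.
initialize_codebook_from_data(vq_model, backbone, train_loader, device, max_samples=50_000)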

visualize

Visualization script for comparing VQ vs baseline segmentation results.

Usage

python visualize.py --vq_checkpoint ./outputs/vq_squeeze/version_0/last.ckpt --baseline_checkpoint ./outputs/baseline_segmentation_baseline/version_0/last.ckpt

Functions

create_backbone
create_backbone(config)

Create segmentation backbone based on config.

Source code in embeddings_squeeze\visualize.py
def create_backbone(config):
    """Create segmentation backbone based on config."""
    if config.model.backbone.lower() == "vit":
        backbone = ViTSegmentationBackbone(
            num_classes=config.model.num_classes,
            freeze_backbone=config.model.freeze_backbone
        )
    elif config.model.backbone.lower() == "deeplab":
        backbone = DeepLabV3SegmentationBackbone(
            weights_name=config.model.deeplab_weights,
            num_classes=config.model.num_classes,
            freeze_backbone=config.model.freeze_backbone
        )
    else:
        raise ValueError(f"Unknown backbone: {config.model.backbone}")

    return backbone
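
A minimal sketch of driving this factory from the default config; the config field names come from the function body above, while the import path for get_default_config is an assumption:

from embeddings_squeeze.visualize import create_backbone
from embeddings_squeeze.config import get_default_config   # import path assumed

config = get_default_config()
config.model.backbone = "vit"        # "deeplab" selects DeepLabV3SegmentationBackbone instead
config.model.num_classes = 21
config.model.freeze_backbone = True

backbone = create_backbone(config)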
create_data_module
create_data_module(config)

Create data module based on config.

Source code in embeddings_squeeze\visualize.py
def create_data_module(config):
    """Create data module based on config."""
    if config.data.dataset.lower() == "oxford_pet":
        data_module = OxfordPetDataModule(
            data_path=config.data.data_path,
            batch_size=config.training.batch_size,
            num_workers=config.training.num_workers,
            pin_memory=config.training.pin_memory,
            image_size=config.data.image_size,
            subset_size=config.data.subset_size
        )
    else:
        raise ValueError(f"Unknown dataset: {config.data.dataset}")

    return data_module
load_models
load_models(
    vq_checkpoint_path,
    baseline_checkpoint_path,
    config,
    device,
)

Load VQ and baseline models from checkpoints.

Parameters:

    vq_checkpoint_path (str): Path to VQ model checkpoint. Required.
    baseline_checkpoint_path (str): Path to baseline model checkpoint. Required.
    config: Configuration object. Required.
    device: Device to load models on. Required.

Returns:

    vq_model: Loaded VQ model.
    baseline_model: Loaded baseline model.

Source code in embeddings_squeeze\visualize.py
def load_models(vq_checkpoint_path: str, baseline_checkpoint_path: str, config, device):
    """
    Load VQ and baseline models from checkpoints.

    Args:
        vq_checkpoint_path: Path to VQ model checkpoint
        baseline_checkpoint_path: Path to baseline model checkpoint
        config: Configuration object
        device: Device to load models on

    Returns:
        vq_model: Loaded VQ model
        baseline_model: Loaded baseline model
    """
    print(f"Loading VQ model from: {vq_checkpoint_path}")
    vq_model = VQSqueezeModule.load_from_checkpoint(vq_checkpoint_path)
    vq_model.to(device)
    vq_model.eval()

    print(f"Loading baseline model from: {baseline_checkpoint_path}")
    baseline_model = BaselineSegmentationModule.load_from_checkpoint(baseline_checkpoint_path)
    baseline_model.to(device)
    baseline_model.eval()

    return vq_model, baseline_model
main
main()

Main visualization function.

Source code in embeddings_squeeze\visualize.py
def main():
    """Main visualization function."""
    parser = argparse.ArgumentParser(description="VQ vs Baseline Segmentation Visualization")

    # Required checkpoint paths
    parser.add_argument("--vq_checkpoint", type=str, required=True,
                       help="Path to VQ model checkpoint")
    parser.add_argument("--baseline_checkpoint", type=str, required=True,
                       help="Path to baseline model checkpoint")

    # Dataset arguments
    parser.add_argument("--dataset_split", type=str, default="test",
                       choices=["test", "val"], help="Dataset split to use")
    parser.add_argument("--dataset", type=str, default="oxford_pet",
                       choices=["oxford_pet"], help="Dataset name")
    parser.add_argument("--data_path", type=str, default="./data",
                       help="Path to dataset")
    parser.add_argument("--subset_size", type=int, default=None,
                       help="Subset size for quick testing")

    # Visualization arguments
    parser.add_argument("--n_best", type=int, default=5,
                       help="Number of best results to show")
    parser.add_argument("--n_worst", type=int, default=5,
                       help="Number of worst results to show")
    parser.add_argument("--output_dir", type=str, default="./visualizations",
                       help="Output directory for visualization figures")

    # Model arguments (should match training config)
    parser.add_argument("--model", type=str, default="vit",
                       choices=["vit", "deeplab"], help="Backbone model")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")

    args = parser.parse_args()

    # Create configuration
    config = get_default_config()
    config = update_config_from_args(config, vars(args))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load models
    vq_model, baseline_model = load_models(
        args.vq_checkpoint, 
        args.baseline_checkpoint, 
        config, 
        device
    )

    # Create data module
    data_module = create_data_module(config)
    data_module.setup(args.dataset_split)

    # Get appropriate dataloader
    if args.dataset_split == "test":
        dataloader = data_module.test_dataloader()
    else:
        dataloader = data_module.val_dataloader()

    print(f"Evaluating on {args.dataset_split} split with {len(dataloader.dataset)} samples")

    # Prepare visualization data
    best_samples, worst_samples = prepare_visualization_data(
        vq_model=vq_model,
        baseline_model=baseline_model,
        dataloader=dataloader,
        device=device,
        num_classes=config.model.num_classes,
        n_best=args.n_best,
        n_worst=args.n_worst
    )

    # Create visualizations
    print(f"Creating visualization for {len(best_samples)} best samples...")
    best_output_path = os.path.join(args.output_dir, f"best_results_{args.dataset_split}.png")
    visualize_comparison(
        samples=best_samples,
        title=f"Best VQ Results ({args.dataset_split} split)",
        output_path=best_output_path,
        num_classes=config.model.num_classes
    )

    print(f"Creating visualization for {len(worst_samples)} worst samples...")
    worst_output_path = os.path.join(args.output_dir, f"worst_results_{args.dataset_split}.png")
    visualize_comparison(
        samples=worst_samples,
        title=f"Worst VQ Results ({args.dataset_split} split)",
        output_path=worst_output_path,
        num_classes=config.model.num_classes
    )

    # Print summary statistics
    best_ious = [sample[1] for sample in best_samples]
    worst_ious = [sample[1] for sample in worst_samples]

    print("\n" + "="*60)
    print("VISUALIZATION SUMMARY")
    print("="*60)
    print(f"Dataset split: {args.dataset_split}")
    print(f"Model backbone: {config.model.backbone}")
    print(f"Number of classes: {config.model.num_classes}")
    print()
    print(f"Best VQ IoU scores: {[f'{iou:.3f}' for iou in best_ious]}")
    print(f"Worst VQ IoU scores: {[f'{iou:.3f}' for iou in worst_ious]}")
    print(f"Best IoU range: {min(best_ious):.3f} - {max(best_ious):.3f}")
    print(f"Worst IoU range: {min(worst_ious):.3f} - {max(worst_ious):.3f}")
    print()
    print(f"Visualizations saved to:")
    print(f"  Best results: {best_output_path}")
    print(f"  Worst results: {worst_output_path}")
    print("="*60)