
Utilities Documentation

utils

Utility functions for VQ compression.

Functions

measure_compression

measure_compression(
    vq_model, backbone, test_loader, device
)

Measure compression ratio achieved by VQ.
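
The ratio compares raw feature storage with index storage: a C-dimensional float32 feature costs 32·C bits, while its quantized form is a single codebook index of ceil(log2(K)) bits for a codebook of K vectors. With hypothetical sizes C = 64 and K = 512, that is 2048 bits versus 9 bits per feature, or roughly 228x, before the one-off codebook cost of K·C·4 bytes (128 KB in this example).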

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| vq_model | | VectorQuantizer model | required |
| backbone | | Segmentation backbone | required |
| test_loader | | Test data loader | required |
| device | | Device to run on | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| compression_ratio | | Compression ratio achieved |

Source code in embeddings_squeeze\utils\compression.py
def measure_compression(vq_model, backbone, test_loader, device):
    """
    Measure compression ratio achieved by VQ.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        test_loader: Test data loader
        device: Device to run on

    Returns:
        compression_ratio: Compression ratio achieved
    """
    vq_model.eval()

    total_original_bits = 0
    total_compressed_bits = 0
    num_features = 0

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            features = backbone.extract_features(images)

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)

            # Original size (float32 = 32 bits)
            original_bits = feat_flat.numel() * 32

            # Compressed size (only indices)
            num_codes = vq_model.num_vectors
            bits_per_index = np.ceil(np.log2(num_codes))
            compressed_bits = feat_flat.shape[0] * bits_per_index

            total_original_bits += original_bits
            total_compressed_bits += compressed_bits
            num_features += feat_flat.shape[0]

    compression_ratio = total_original_bits / total_compressed_bits

    print("="*70)
    print("COMPRESSION ANALYSIS")
    print("="*70)
    print(f"Total features processed: {num_features}")
    print(f"Feature dimension: {vq_model.vector_dim}")
    print(f"Codebook size: {vq_model.num_vectors}")
    print()
    print(f"Original storage:")
    print(f"  Per feature: {vq_model.vector_dim} × 32 bits = {vq_model.vector_dim * 32} bits = {vq_model.vector_dim * 4} bytes")
    print(f"  Total: {total_original_bits / 8 / 1024 / 1024:.2f} MB")
    print()
    print(f"Compressed storage:")
    print(f"  Per feature: {bits_per_index:.0f} bits (index)")
    print(f"  Total: {total_compressed_bits / 8 / 1024:.2f} KB")
    print(f"  + Codebook: {vq_model.num_vectors * vq_model.vector_dim * 4 / 1024:.2f} KB")
    print()
    print(f"Compression ratio: {compression_ratio:.1f}x")
    print(f"Space savings: {(1 - 1/compression_ratio)*100:.1f}%")
    print("="*70)

    return compression_ratio
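
A minimal usage sketch. The import path is inferred from the source location above, and vq_model, backbone, and test_loader are assumed to be built elsewhere by the project's model and data code:

    import torch

    from embeddings_squeeze.utils.compression import measure_compression

    # vq_model, backbone and test_loader are assumed to exist already
    # (constructed by the project's training / data-loading code).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ratio = measure_compression(vq_model, backbone, test_loader, device)
    print(f"VQ compresses features roughly {ratio:.1f}x")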

compute_iou_metrics

compute_iou_metrics(
    predictions, targets, num_classes, ignore_index=255
)

Compute IoU metrics for segmentation.
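
For each class c, IoU is computed over the non-ignored pixels as |P_c ∩ T_c| / |P_c ∪ T_c|, where P_c and T_c are the pixels predicted as and labelled as class c. mean_iou averages these per-class values; a class absent from both prediction and target contributes 0 (see the union > 0 branch in the source below).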

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| predictions | | Predicted masks [B, H, W] | required |
| targets | | Ground truth masks [B, H, W] | required |
| num_classes | | Number of classes | required |
| ignore_index | | Index to ignore in computation | 255 |

Returns:

| Name | Type | Description |
|------|------|-------------|
| metrics | | Dictionary with IoU metrics |

Source code in embeddings_squeeze\utils\compression.py
def compute_iou_metrics(predictions, targets, num_classes, ignore_index=255):
    """
    Compute IoU metrics for segmentation.

    Args:
        predictions: Predicted masks [B, H, W]
        targets: Ground truth masks [B, H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        metrics: Dictionary with IoU metrics
    """
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()

    # Flatten arrays
    pred_flat = predictions.flatten()
    target_flat = targets.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Compute mean IoU
    mean_iou = np.mean(ious)

    # Compute pixel accuracy
    pixel_acc = (pred_flat == target_flat).mean()

    metrics = {
        'mean_iou': mean_iou,
        'pixel_accuracy': pixel_acc,
        'class_ious': ious
    }

    return metrics
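
A small self-contained example with random dummy masks (the import path is inferred from the source location above):

    import torch

    from embeddings_squeeze.utils.compression import compute_iou_metrics

    # Two random 64x64 masks with 21 classes, standing in for real predictions/targets
    predictions = torch.randint(0, 21, (2, 64, 64))
    targets = torch.randint(0, 21, (2, 64, 64))

    metrics = compute_iou_metrics(predictions, targets, num_classes=21)
    print(metrics["mean_iou"], metrics["pixel_accuracy"], len(metrics["class_ious"]))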

initialize_codebook_from_data

initialize_codebook_from_data(
    vq_model,
    backbone,
    train_loader,
    device,
    max_samples=50000,
)

Initialize codebook using k-means clustering on real data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| vq_model | | VectorQuantizer model | required |
| backbone | | Segmentation backbone | required |
| train_loader | | Training data loader | required |
| device | | Device to run on | required |
| max_samples | int | Maximum number of samples for k-means | 50000 |

Source code in embeddings_squeeze\utils\initialization.py
def initialize_codebook_from_data(
    vq_model, 
    backbone, 
    train_loader, 
    device, 
    max_samples: int = 50_000
):
    """
    Initialize codebook using k-means clustering on real data.

    Args:
        vq_model: VectorQuantizer model
        backbone: Segmentation backbone
        train_loader: Training data loader
        device: Device to run on
        max_samples: Maximum number of samples for k-means
    """
    print("Collecting features for k-means initialization...")
    all_features = []

    backbone.eval()
    with torch.no_grad():
        i = 0
        for images, _ in train_loader:
            print(f"Processing batch {i}")
            features = backbone.extract_features(images.to(device))
            i += 1

            B, C, H, W = features.shape
            feat_flat = features.permute(0, 2, 3, 1).reshape(-1, C)
            all_features.append(feat_flat.cpu())

            if len(all_features) * feat_flat.shape[0] > max_samples:
                break

    all_features = torch.cat(all_features).numpy()

    print(f"Running k-means on {all_features.shape[0]} features...")
    kmeans = MiniBatchKMeans(
        n_clusters=vq_model.num_vectors,
        random_state=0,
        batch_size=1000,
        max_iter=100
    )
    kmeans.fit(all_features)

    # Update codebook with cluster centers
    vq_model.codebook.embeddings.data = torch.tensor(kmeans.cluster_centers_).to(device).float()

    print(f"Codebook initialized:")
    print(f"  Mean: {vq_model.codebook.embeddings.mean():.2f}")
    print(f"  Std: {vq_model.codebook.embeddings.std():.2f}")
    print(f"  Norm: {vq_model.codebook.embeddings.norm(dim=1).mean():.2f}")

compute_sample_iou

compute_sample_iou(
    prediction, target, num_classes, ignore_index=255
)

Compute IoU for a single sample.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| prediction | Tensor | Predicted mask [H, W] | required |
| target | Tensor | Ground truth mask [H, W] | required |
| num_classes | int | Number of classes | required |
| ignore_index | int | Index to ignore in computation | 255 |

Returns:

| Name | Type | Description |
|------|------|-------------|
| iou | float | Mean IoU across all classes |

Source code in embeddings_squeeze\utils\comparison.py
def compute_sample_iou(prediction: torch.Tensor, target: torch.Tensor, num_classes: int, ignore_index: int = 255) -> float:
    """
    Compute IoU for a single sample.

    Args:
        prediction: Predicted mask [H, W]
        target: Ground truth mask [H, W]
        num_classes: Number of classes
        ignore_index: Index to ignore in computation

    Returns:
        iou: Mean IoU across all classes
    """
    prediction = prediction.cpu().numpy()
    target = target.cpu().numpy()

    # Flatten arrays
    pred_flat = prediction.flatten()
    target_flat = target.flatten()

    # Remove ignored pixels
    valid_mask = target_flat != ignore_index
    pred_flat = pred_flat[valid_mask]
    target_flat = target_flat[valid_mask]

    # Compute IoU for each class
    ious = []
    for cls in range(num_classes):
        pred_cls = (pred_flat == cls)
        target_cls = (target_flat == cls)

        intersection = (pred_cls & target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection

        if union > 0:
            iou = intersection / union
        else:
            iou = 0.0

        ious.append(iou)

    # Return mean IoU
    return np.mean(ious)
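
The single-sample variant can be exercised the same way with random [H, W] masks (illustrative only):

    import torch

    from embeddings_squeeze.utils.comparison import compute_sample_iou

    prediction = torch.randint(0, 21, (64, 64))
    target = torch.randint(0, 21, (64, 64))
    print(f"mean IoU: {compute_sample_iou(prediction, target, num_classes=21):.3f}")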

evaluate_model

evaluate_model(model, dataloader, device, num_classes=21)

Evaluate model on dataset and collect results.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | | Model to evaluate | required |
| dataloader | | Data loader | required |
| device | | Device to run on | required |
| num_classes | int | Number of classes | 21 |

Returns:

| Name | Type | Description |
|------|------|-------------|
| results | List[Tuple[int, float, Tensor, Tensor, Tensor]] | List of (sample_idx, iou, image, mask, prediction) tuples |

Source code in embeddings_squeeze\utils\comparison.py
def evaluate_model(model, dataloader, device, num_classes: int = 21) -> List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Evaluate model on dataset and collect results.

    Args:
        model: Model to evaluate
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes

    Returns:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.squeeze(1).long().to(device)

            # Get predictions
            if hasattr(model, 'predict'):
                predictions = model.predict(images)
            else:
                # For VQ model
                predictions = model.predict_with_vq(images)

            # Process each sample in batch
            for i in range(images.shape[0]):
                image = images[i]
                mask = masks[i]
                pred = predictions[i]

                # Compute IoU
                iou = compute_sample_iou(pred, mask, num_classes)

                # Store results
                sample_idx = batch_idx * dataloader.batch_size + i
                results.append((sample_idx, iou, image, mask, pred))

    return results
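
A usage sketch. model, val_loader, and device are assumed to exist; as the branch above shows, the model must expose either predict or predict_with_vq:

    from embeddings_squeeze.utils.comparison import evaluate_model

    # model, val_loader and device are assumed to be set up elsewhere in the project.
    results = evaluate_model(model, val_loader, device, num_classes=21)
    mean_iou = sum(iou for _, iou, *_ in results) / len(results)
    print(f"{len(results)} samples, mean per-sample IoU: {mean_iou:.3f}")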

find_best_worst_samples

find_best_worst_samples(results, n_best=5, n_worst=5)

Find best and worst samples based on IoU.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| results | List[Tuple[int, float, Tensor, Tensor, Tensor]] | List of (sample_idx, iou, image, mask, prediction) tuples | required |
| n_best | int | Number of best samples to return | 5 |
| n_worst | int | Number of worst samples to return | 5 |

Returns:

| Name | Type | Description |
|------|------|-------------|
| best_samples | List | List of best sample tuples |
| worst_samples | List | List of worst sample tuples |

Source code in embeddings_squeeze\utils\comparison.py
def find_best_worst_samples(results: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor]], 
                           n_best: int = 5, n_worst: int = 5) -> Tuple[List, List]:
    """
    Find best and worst samples based on IoU.

    Args:
        results: List of (sample_idx, iou, image, mask, prediction) tuples
        n_best: Number of best samples to return
        n_worst: Number of worst samples to return

    Returns:
        best_samples: List of best sample tuples
        worst_samples: List of worst sample tuples
    """
    # Sort by IoU (descending)
    sorted_results = sorted(results, key=lambda x: x[1], reverse=True)

    # Get best and worst
    best_samples = sorted_results[:n_best]
    worst_samples = sorted_results[-n_worst:]

    return best_samples, worst_samples
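
Since only the IoU at position 1 of each tuple is inspected, the ranking can be demonstrated with dummy results (values below are illustrative):

    import torch

    from embeddings_squeeze.utils.comparison import find_best_worst_samples

    # Fake per-sample results: (sample_idx, iou, image, mask, prediction)
    dummy = torch.zeros(1)
    results = [(i, iou, dummy, dummy, dummy)
               for i, iou in enumerate([0.31, 0.72, 0.55, 0.90, 0.12])]

    best, worst = find_best_worst_samples(results, n_best=2, n_worst=2)
    print([iou for _, iou, *_ in best])   # [0.9, 0.72]
    print([iou for _, iou, *_ in worst])  # [0.31, 0.12]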

prepare_visualization_data

prepare_visualization_data(
    vq_model,
    baseline_model,
    dataloader,
    device,
    num_classes=21,
    n_best=5,
    n_worst=5,
)

Prepare data for visualization by running both models and ranking results.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| vq_model | | VQ model | required |
| baseline_model | | Baseline model | required |
| dataloader | | Data loader | required |
| device | | Device to run on | required |
| num_classes | int | Number of classes | 21 |
| n_best | int | Number of best samples | 5 |
| n_worst | int | Number of worst samples | 5 |

Returns:

| Name | Type | Description |
|------|------|-------------|
| best_samples | | List of best sample tuples with both predictions |
| worst_samples | | List of worst sample tuples with both predictions |

Source code in embeddings_squeeze\utils\comparison.py
def prepare_visualization_data(vq_model, baseline_model, dataloader, device, 
                              num_classes: int = 21, n_best: int = 5, n_worst: int = 5):
    """
    Prepare data for visualization by running both models and ranking results.

    Args:
        vq_model: VQ model
        baseline_model: Baseline model
        dataloader: Data loader
        device: Device to run on
        num_classes: Number of classes
        n_best: Number of best samples
        n_worst: Number of worst samples

    Returns:
        best_samples: List of best sample tuples with both predictions
        worst_samples: List of worst sample tuples with both predictions
    """
    # Evaluate VQ model
    print("Evaluating VQ model...")
    vq_results = evaluate_model(vq_model, dataloader, device, num_classes)

    # Evaluate baseline model
    print("Evaluating baseline model...")
    baseline_results = evaluate_model(baseline_model, dataloader, device, num_classes)

    # Combine results (assuming same order)
    combined_results = []
    for (idx1, iou1, img1, mask1, pred_vq), (idx2, iou2, img2, mask2, pred_baseline) in zip(vq_results, baseline_results):
        assert idx1 == idx2, "Sample indices don't match"
        assert torch.equal(img1, img2), "Images don't match"
        assert torch.equal(mask1, mask2), "Masks don't match"

        combined_results.append((idx1, iou1, img1, mask1, pred_baseline, pred_vq))

    # Find best and worst based on VQ IoU
    best_samples, worst_samples = find_best_worst_samples(combined_results, n_best, n_worst)

    return best_samples, worst_samples
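
A usage sketch. The models and loader are assumed to be built elsewhere; use a non-shuffled loader so both evaluations see the samples in the same order, as required by the assertions above:

    from embeddings_squeeze.utils.comparison import prepare_visualization_data

    # vq_model, baseline_model, val_loader and device are assumed to exist already.
    best, worst = prepare_visualization_data(
        vq_model, baseline_model, val_loader, device,
        num_classes=21, n_best=5, n_worst=5,
    )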

visualize_comparison

visualize_comparison(
    samples, title, output_path, num_classes=21
)

Create visualization comparing baseline and VQ predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| samples | List[Tuple[int, float, Tensor, Tensor, Tensor, Tensor]] | List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples | required |
| title | str | Figure title | required |
| output_path | str | Path to save figure | required |
| num_classes | int | Number of classes | 21 |

Source code in embeddings_squeeze\utils\comparison.py
def visualize_comparison(samples: List[Tuple[int, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], 
                        title: str, output_path: str, num_classes: int = 21):
    """
    Create visualization comparing baseline and VQ predictions.

    Args:
        samples: List of (idx, iou, image, mask, pred_baseline, pred_vq) tuples
        title: Figure title
        output_path: Path to save figure
        num_classes: Number of classes
    """
    n_samples = len(samples)
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples))

    if n_samples == 1:
        axes = axes.reshape(1, -1)

    colormap = create_segmentation_colormap(num_classes)

    for i, (idx, iou, image, mask, pred_baseline, pred_vq) in enumerate(samples):
        # Denormalize image
        image_vis = denormalize_image(image)

        # Convert to numpy for plotting
        image_np = image_vis.permute(1, 2, 0).cpu().numpy()
        mask_np = mask.cpu().numpy()
        pred_baseline_np = pred_baseline.cpu().numpy()
        pred_vq_np = pred_vq.cpu().numpy()

        # Plot original image
        axes[i, 0].imshow(image_np)
        axes[i, 0].set_title(f"Original Image\nSample {idx}")
        axes[i, 0].axis('off')

        # Plot ground truth
        axes[i, 1].imshow(mask_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # Plot baseline prediction
        axes[i, 2].imshow(pred_baseline_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 2].set_title("Baseline Prediction")
        axes[i, 2].axis('off')

        # Plot VQ prediction
        axes[i, 3].imshow(pred_vq_np, cmap=colormap, vmin=0, vmax=num_classes-1)
        axes[i, 3].set_title(f"VQ Prediction\nIoU: {iou:.3f}")
        axes[i, 3].axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Visualization saved to: {output_path}")