# Source code for vision_unlearning.datasets.imagenette

from typing import ClassVar
import os
import requests
import tarfile
import numpy as np
from torchvision import datasets, transforms
from vision_unlearning.datasets.base import UnlearnDataset, UnlearnDatasetSplit


class UnlearnDatasetImagenette(UnlearnDataset):
    '''
    Imagenette (320px variant) exposed as train / validation / test splits.

    The archive is downloaded and extracted under ``download_path``. The
    archive's ``train`` folder is divided into the Train and Validation
    splits, and the archive's ``val`` folder is used as the Test split.
    '''

    # Folder the archive is downloaded to and extracted in
    download_path: str

    # Class folder name (as found in the extracted archive) -> human-readable name
    class_mapping: ClassVar[dict] = {
        "n01440764": "tench",
        "n02102040": "english_springer",
        "n02979186": "cassette_player",
        "n03000684": "chain_saw",
        "n03028079": "church",
        "n03394916": "french_horn",
        "n03417042": "garbage_truck",
        "n03425413": "gas_pump",
        "n03445777": "golf_ball",
        "n03888257": "parachute",
    }

    def _download_imagenette(self, temp_download_path: str) -> None:
        '''
        Download Imagenette from FastAI's S3 bucket and extract it.

        Nothing is done when the extracted ``imagenette2-320`` folder already
        exists; the download step alone is skipped when the archive is
        already present.

        Args:
            temp_download_path: Folder receiving the archive and its
                extracted content. Created if missing.

        Raises:
            requests.HTTPError: The server answered with an error status.
            requests.RequestException: Any other network failure (including
                a connect/read timeout).
            tarfile.TarError: The archive could not be read or extracted.
        '''
        os.makedirs(temp_download_path, exist_ok=True)
        url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
        tar_path = os.path.join(temp_download_path, "imagenette2-320.tgz")
        extracted_path = os.path.join(temp_download_path, "imagenette2-320")

        # Check if the dataset is already downloaded and extracted
        if os.path.isdir(extracted_path):
            print("Dataset already downloaded and extracted. Skipping download.")
            return

        # Download the dataset if not already downloaded
        if not os.path.exists(tar_path):
            print("Downloading Imagenette...")
            # Stream into a side file and rename only on success, so that an
            # interrupted or failed download never leaves a truncated archive
            # at `tar_path` that later runs would mistake for a complete one.
            part_path = tar_path + ".part"
            # timeout=(connect, read): the read timeout bounds the wait
            # between chunks, not the total download time.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Without this, an HTTP error page would be saved as the archive.
                response.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(part_path, tar_path)
            print("Download complete.")

        # Extract the dataset
        print("Extracting dataset...")
        with tarfile.open(tar_path, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would escape the destination folder (absolute paths, "..",
                # links pointing outside) on Pythons that support filters.
                tar.extractall(path=temp_download_path, filter="data")
            else:
                tar.extractall(path=temp_download_path)
        print("Extraction complete.")

    def _load(self) -> None:
        '''
        Download the data if needed and build the three dataset splits.

        Sets ``self.mean``, ``self.std``, ``self._classes``,
        ``self._n_classes`` and ``self._dataset_splits``. For every class,
        10% of the archive's training images (rounded down) are drawn with a
        fixed seed and moved to the Validation split.
        '''
        # Download Imagenette to the download folder
        self._download_imagenette(self.download_path)

        # Locate the extracted folders
        extracted_path = os.path.join(self.download_path, "imagenette2-320")
        train_path = os.path.join(extracted_path, "train")
        val_path = os.path.join(extracted_path, "val")

        # Define the transform
        # NOTE(review): presumably the standard ImageNet channel statistics - confirm
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

        # Load the dataset using ImageFolder
        train_set = datasets.ImageFolder(train_path, transform=transform)
        test_set = datasets.ImageFolder(val_path, transform=transform)
        self._classes = train_set.classes
        assert self._classes is not None
        self._n_classes = len(self._classes)

        # Split the training set into training and validation.
        # The fixed seed keeps the split identical across runs.
        rng = np.random.RandomState(42)
        targets = np.asarray(train_set.targets)  # built once, reused for every class
        val_idxs = []
        for i in range(self._n_classes):
            class_idx = np.where(targets == i)[0]
            val_idxs.append(rng.choice(class_idx, int(0.1 * len(class_idx)), replace=False))
        # Plain Python ints, so they index lists without relying on numpy scalars
        val_idxs_stacked = [int(idx) for idx in np.hstack(val_idxs)]
        held_out = set(val_idxs_stacked)
        # Explicit ascending order instead of depending on set iteration order
        train_idxs = [idx for idx in range(len(train_set)) if idx not in held_out]

        # Create validation set as a new ImageFolder instance
        valid_data = [train_set.samples[i] for i in val_idxs_stacked]
        valid_targets = [train_set.targets[i] for i in val_idxs_stacked]
        valid_set = datasets.ImageFolder(
            train_path,
            transform=transform,
            # Same loader as the training set, passed directly: wrapping it
            # in a lambda would make the dataset impossible to pickle.
            loader=train_set.loader,
        )
        valid_set.samples = valid_data
        valid_set.targets = valid_targets
        valid_set.imgs = valid_data  # Update the internal list of images

        # Update training set to exclude validation samples
        train_set.samples = [train_set.samples[i] for i in train_idxs]
        train_set.targets = [train_set.targets[i] for i in train_idxs]
        train_set.imgs = train_set.samples  # Update the internal list of images

        self._dataset_splits = {
            UnlearnDatasetSplit.Train: train_set,
            UnlearnDatasetSplit.Validation: valid_set,
            UnlearnDatasetSplit.Test: test_set,
        }

    def make_prompt_for_label(self, label: int) -> str:
        '''
        Return the text prompt for a class index.

        Args:
            label: Index into ``self._classes``.

        Returns:
            ``"an image of <name>"``, where ``<name>`` is the
            ``class_mapping`` entry for that class.

        Raises:
            AssertionError: ``self._classes`` has not been set.
            IndexError: ``label`` is out of range.
            KeyError: The class is not present in ``class_mapping``.
        '''
        assert self._classes is not None
        return f"an image of {self.class_mapping[self._classes[label]]}"