Source code for vision_unlearning.datasets.base

import os
import pickle
import json
from enum import Enum
from typing import Dict, List, Optional, Union, Literal, Sequence
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset
from torch.utils.data import Subset, DataLoader
from vision_unlearning.utils.logger import get_logger


# Module-level logger, obtained from the project's get_logger helper with the name 'datasets'.
logger = get_logger('datasets')


class UnlearnDatasetSplit(Enum):
    '''
    Identifiers of every split an unlearning dataset can expose.

    The string value of each member is also used as a file/directory name
    when a dataset is saved to disk.
    '''

    # Base (unsplit) partitions
    Train = "train"
    Validation = "validation"
    Test = "test"

    # Partitions derived from the training data
    Train_retain = "train_retain"
    # Random sample of the train-retain data, sized like Test_retain
    # (MIA presumably stands for membership-inference attack - TODO confirm)
    Train_retain_MIA = "train_retain_mia"
    Train_forget = "train_forget"

    # Partitions derived from the test data
    Test_retain = "test_retain"
    Test_forget = "test_forget"

    # Partitions derived from the validation data
    Validation_retain = "validation_retain"
    Validation_forget = "validation_forget"
class UnlearnDatasetSplitMode(Enum):
    '''
    Strategy used to decide which samples belong to the forget set.
    '''

    Class = "class"        # forget set selected by class
    Random = "random"      # forget set selected at random
    Temporal = "temporal"  # forget set selected by a temporal criterion
class SplitNotAvailableError(Exception):
    '''
    Raised when a requested dataset split is not available.
    '''
class UnlearnDataset(BaseModel, ABC):
    '''
    Dataset wrapper that organizes the forget/retain splits.

    Concrete subclasses implement `_load` to fill the base Train, Validation
    and Test splits; `_split` then derives the forget/retain splits from them
    according to `split_mode`. Both steps run automatically right after
    construction (see `model_post_init`).
    '''
    model_config = ConfigDict(arbitrary_types_allowed=True)

    split_mode: UnlearnDatasetSplitMode
    # Keyword arguments forwarded to the mode-specific method
    # (_split_class, _split_random, _split_temporal).
    # NOTE(review): the `{}` defaults below rely on pydantic copying mutable
    # defaults per instance - confirm for the pydantic version in use.
    split_kwargs: dict = {}

    _dataset_splits: Dict[UnlearnDatasetSplit, Union[Subset, VisionDataset]] = {}
    _classes: Optional[List[str]] = None
    _n_classes: int = 0

    # Per-channel normalization statistics; both are needed to denormalize.
    mean: Optional[Sequence[float]] = None
    std: Optional[Sequence[float]] = None

    def model_post_init(self, __context: Optional[dict]) -> None:
        '''
        Load the data and build every split once the model is constructed.

        Side effects: populates _dataset_splits, _classes, _n_classes (through _load and _split)
        Raised exceptions: whatever _load/_split raise; AssertionError if some split was left unfilled
        '''
        # TODO: using pydantic's model_post_init makes this hard to debug... maybe just overwritting the constructor is better
        self._load()
        self._split()
        expected = {member.value for member in UnlearnDatasetSplit}
        available = {key.value for key in self._dataset_splits}
        assert expected == available, "All possible splits should be filled"

    @abstractmethod
    def _load(self) -> None:
        '''
        Load the dataset from disk or download it.
        Side effects: updates the properties _dataset_splits, _classes, _n_classes
        '''
        pass

    def _split(self) -> None:
        '''
        Split the dataset based on the specified mode.

        `split_kwargs` is forwarded as keyword arguments to the mode-specific method.

        Side effects: updates the property _dataset_splits
        Raised exceptions: NotImplementedError for modes without an implementation,
            plus whatever the mode-specific method raises
        '''
        if self.split_mode == UnlearnDatasetSplitMode.Class:
            self._split_class(**self.split_kwargs)
        elif self.split_mode == UnlearnDatasetSplitMode.Random:
            self._split_random(**self.split_kwargs)
        elif self.split_mode == UnlearnDatasetSplitMode.Temporal:
            self._split_temporal(**self.split_kwargs)
        else:
            raise NotImplementedError(f"Split mode {self.split_mode} not implemented")

    def _resolve_forget_classes(self, forget: List[str | int] | str | int) -> List[int]:
        '''
        Translate the `forget` specification into a non-empty list of class indices.

        Accepts a class name, a class index, or a list made only of names or only of indices.

        Side effects: logs a warning when some (but not all) of the given names are unknown
        Raised exceptions: ValueError for empty/mixed input, unknown names or out-of-range indices
        '''
        assert self._classes is not None, "Classes should be loaded before splitting"

        requested = forget if isinstance(forget, list) else [forget]
        if not requested:
            raise ValueError("Forget should contain at least one class")

        if all(isinstance(item, str) for item in requested):
            wanted = set(requested)
            class_idx = [i for i, class_name in enumerate(self._classes) if class_name in wanted]
            if not class_idx:
                raise ValueError(f"None of the classes {requested} exist in the dataset")
            unknown = sorted(wanted.difference(self._classes))
            if unknown:
                # Unknown names are skipped as long as at least one name matched
                logger.warning("Ignoring unknown classes in forget: %s", unknown)
            return class_idx

        # bool is a subclass of int, but True/False are not meaningful class indices
        if all(isinstance(item, int) and not isinstance(item, bool) for item in requested):
            out_of_range = [item for item in requested if not 0 <= item < self._n_classes]
            if out_of_range:
                raise ValueError(
                    f"Class indices {out_of_range} are out of range for {self._n_classes} classes"
                )
            return list(requested)  # type: ignore

        raise ValueError("Forget should be a list of strings or integers")

    @staticmethod
    def _partition_by_class(dataset: VisionDataset, class_idx: List[int]):
        '''
        Return (forget_indices, retain_indices) for `dataset` as numpy arrays.

        A sample goes to the forget side when its entry in `dataset.targets` is in `class_idx`.
        '''
        forget_mask = np.isin(np.array(dataset.targets), class_idx)  # type: ignore
        all_idx = np.arange(len(dataset))
        return all_idx[forget_mask], all_idx[~forget_mask]

    def _split_class(self, forget: List[str | int] | str | int, seed: Optional[int] = None) -> None:
        '''
        Build the forget/retain splits by class.

        forget: class name(s) or class index(es) whose samples form the forget splits
        seed: optional seed for drawing the Train_retain_MIA sample; when None the
            global numpy random state is used

        Side effects: adds the retain/forget/MIA splits to _dataset_splits
        Raised exceptions: ValueError if `forget` is invalid; KeyError if a base split is missing
        '''
        class_idx = self._resolve_forget_classes(forget)

        train = self._dataset_splits[UnlearnDatasetSplit.Train]
        valid = self._dataset_splits[UnlearnDatasetSplit.Validation]
        test = self._dataset_splits[UnlearnDatasetSplit.Test]
        # The `.targets` attribute is read below, so plain Subsets are not accepted here
        assert isinstance(train, VisionDataset), "Train should be a VisionDataset"
        assert isinstance(valid, VisionDataset), "Valid should be a VisionDataset"
        assert isinstance(test, VisionDataset), "Test should be a VisionDataset"

        train_f_idx, train_r_idx = self._partition_by_class(train, class_idx)
        valid_f_idx, valid_r_idx = self._partition_by_class(valid, class_idx)
        test_f_idx, test_r_idx = self._partition_by_class(test, class_idx)

        # The MIA split is a sample of train-retain with as many items as test-retain.
        # Sampling is without replacement, so it cannot exceed the train-retain size.
        n_mia = len(test_r_idx)
        if n_mia > len(train_r_idx):
            logger.warning(
                "Train retain has %d samples but test retain has %d; using all train retain samples for MIA",
                len(train_r_idx), n_mia,
            )
            n_mia = len(train_r_idx)
        if seed is None:
            idxs_mia = np.random.choice(train_r_idx, n_mia, replace=False)
        else:
            idxs_mia = np.random.default_rng(seed).choice(train_r_idx, n_mia, replace=False)

        # .tolist() stores plain Python ints in the Subsets instead of numpy scalars
        self._dataset_splits.update({
            UnlearnDatasetSplit.Train_retain: Subset(train, train_r_idx.tolist()),
            UnlearnDatasetSplit.Train_retain_MIA: Subset(train, idxs_mia.tolist()),
            UnlearnDatasetSplit.Train_forget: Subset(train, train_f_idx.tolist()),
            UnlearnDatasetSplit.Validation_retain: Subset(valid, valid_r_idx.tolist()),
            UnlearnDatasetSplit.Validation_forget: Subset(valid, valid_f_idx.tolist()),
            UnlearnDatasetSplit.Test_retain: Subset(test, test_r_idx.tolist()),
            UnlearnDatasetSplit.Test_forget: Subset(test, test_f_idx.tolist()),
        })

    def _split_random(self, n_forget: int, seed: int = 42) -> None:
        '''
        Build the forget/retain splits at random (not implemented yet).
        Raised exceptions: NotImplementedError, always
        '''
        raise NotImplementedError("Random split not implemented")

    def _split_temporal(self, n_forget: int) -> None:
        '''
        Build the forget/retain splits by a temporal criterion (not implemented yet).
        Raised exceptions: NotImplementedError, always
        '''
        raise NotImplementedError("Temporal split not implemented")

    def get_loader(self, split: UnlearnDatasetSplit, batchsize: int, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = True) -> Optional[DataLoader]:
        '''
        Return a DataLoader over the requested split of this dataset.

        The Optional return annotation is kept for compatibility; this
        implementation never returns None.

        Side effects: none
        Raised exceptions: SplitNotAvailableError, if the requested split is not available
        '''
        if split not in self._dataset_splits:
            raise SplitNotAvailableError(f"Split {split} not available")
        return DataLoader(
            self._dataset_splits[split],
            batch_size=batchsize,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    def get_splits(self) -> Dict[UnlearnDatasetSplit, Union[Subset, VisionDataset]]:
        '''
        Return the available splits.

        The internal dictionary itself is returned (not a copy).

        Side effects: none
        Raised exceptions: none
        '''
        return self._dataset_splits

    def denormalize(self, normalized: torch.Tensor) -> torch.Tensor:
        '''
        Undo the normalization: normalized * std + mean.

        The statistics are reshaped to (-1, 1, 1), so each entry applies to one
        slice along the third-to-last dimension of `normalized`.

        Side effects: none
        Raised exceptions: ValueError, if mean or std is not set
        '''
        if self.mean is None or self.std is None:
            raise ValueError("Both mean and std should be set to denormalize")
        # Create the statistics on the input's device so GPU tensors work as well
        std = torch.tensor(list(self.std), dtype=torch.float32, device=normalized.device).view(-1, 1, 1)
        mean = torch.tensor(list(self.mean), dtype=torch.float32, device=normalized.device).view(-1, 1, 1)
        return normalized * std + mean

    def save(self, path: str, format: Literal['pkl', 'jpg'] = 'pkl', save_unsplit: bool = False) -> None:
        '''
        Save each split to disk.

        path: output directory (created if needed)
        format: 'pkl' pickles each split object to `<split>.pkl`; 'jpg' writes one
            folder per split with the images and a metadata.jsonl file
        save_unsplit: also save the base Train/Validation/Test splits

        Side effects: saves files to disk
        Raised exceptions: ValueError for an unsupported format; OS-related errors
        '''
        assert self._classes is not None, "Classes should be loaded before saving"
        # Validate up front so a bad format is reported even when no split gets saved
        if format not in ('pkl', 'jpg'):
            raise ValueError(f"Format {format} not supported")

        os.makedirs(path, exist_ok=True)
        unsplit = (UnlearnDatasetSplit.Train, UnlearnDatasetSplit.Validation, UnlearnDatasetSplit.Test)
        for split, data in self._dataset_splits.items():
            if split in unsplit and not save_unsplit:
                continue
            if format == 'pkl':
                # NOTE: pickle files should only ever be loaded back from trusted sources
                with open(os.path.join(path, f"{split.value}.pkl"), 'wb') as f:
                    pickle.dump(data, f)
            else:
                self._save_split_as_jpg(data, os.path.join(path, split.value))

    def _save_split_as_jpg(self, data: Union[Subset, VisionDataset], split_path: str) -> None:
        '''
        Write every sample of `data` as `<idx>.jpg` plus a metadata.jsonl with one prompt per image.

        Side effects: creates `split_path` and writes files into it
        Raised exceptions: TypeError for non-integer labels; OS-related errors
        '''
        os.makedirs(split_path, exist_ok=True)
        # Denormalizing needs both statistics; with only one of them it would fail
        can_denormalize = self.mean is not None and self.std is not None
        to_pil = transforms.ToPILImage()

        metadata = []
        for idx in range(len(data)):  # type: ignore
            image, label = data[idx]
            # Accept integer labels coming as numpy scalars or tensors
            if isinstance(label, (np.integer, torch.Tensor)):
                label = int(label)
            if not isinstance(label, int):
                raise TypeError(f"Label should be an integer, got {type(label).__name__}")

            file_name = f"{idx}.jpg"
            # Convert tensor to PIL image; anything else is expected to offer .save() already
            if isinstance(image, torch.Tensor):
                if can_denormalize:
                    image = self.denormalize(image)
                image = to_pil(image)
            image.save(os.path.join(split_path, file_name))
            metadata.append({
                "file_name": file_name,
                "text": self.make_prompt_for_label(label)
            })

        with open(os.path.join(split_path, "metadata.jsonl"), 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(entry) + "\n" for entry in metadata)

    def make_prompt_for_label(self, label: int) -> str:
        '''
        Return the text prompt describing the class with index `label`.

        Side effects: none
        Raised exceptions: IndexError, if label is not a valid class index
        '''
        assert self._classes is not None
        return f"an image of {self._classes[label]}"