# Source code for vision_unlearning.metrics.image

from abc import ABC, abstractmethod
from typing import Union, Optional, Any, Dict, List, Literal
import tempfile
import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
import torch
from transformers import (
    pipeline,
    AutoImageProcessor,
    SiglipForImageClassification,
)
from transformers.pipelines.image_classification import ImageClassificationPipeline
import piq

from vision_unlearning.metrics.base import Metric


# TODO: turn the pseudo tests and usage examples below into automated tests


class MetricImage(Metric, ABC):
    """
    Abstract base for metrics that depend only on the image itself,
    e.g., image quality or painting style.
    """

    @abstractmethod
    def score(self, image: Image.Image) -> Dict[str, Any]:
        """Evaluate ``image`` and return the metric values keyed by name."""
        ...
class MetricPaintingStyle(MetricImage):
    """
    Reports whether an image-classification pipeline assigns ``desired_style``
    to an image, and with which confidence.

    The pipeline is built from ``model_path`` on ``device`` in
    ``model_post_init``; only its ``top_k`` predictions are inspected.
    """
    metrics: List[Literal['is_desired_style', 'desired_style_confidence']] = []  # TODO: this is currently ignored
    desired_style: str
    top_k: int = 5
    model_path: str
    device: Optional[torch.device | str | int] = 'cuda'
    _pipeline: Optional[ImageClassificationPipeline] = None

    def model_post_init(self, __context: Optional[dict] = None) -> None:
        """Post-init hook: create the image-classification pipeline."""
        self._pipeline = pipeline(
            'image-classification',
            model=self.model_path,
            device=self.device,
        )

    def score(self, image: Image.Image) -> Dict[str, bool | float]:
        """
        Classify ``image`` and look for ``desired_style`` among the top-k labels.

        Returns a dict with:
        * ``is_desired_style``: True when one of the predictions carries the desired label.
        * ``desired_style_confidence``: that prediction's score, or 0.0 when absent.
        """
        classifier = self._pipeline
        assert classifier is not None

        top_predictions: list = classifier(image, top_k=self.top_k)
        hits = [pred for pred in top_predictions if pred['label'] == self.desired_style]
        if not hits:
            return {
                'is_desired_style': False,
                'desired_style_confidence': 0.0
            }
        # Should the label appear more than once, the last occurrence is reported.
        return {
            'is_desired_style': True,
            'desired_style_confidence': float(hits[-1]['score'])
        }
# Pseudo test
# import torch
# from PIL import Image
#
# #image = Image.open('assets/Diffusion-MU-Attack/files/dataset/vangogh/imgs/35_0.png')
# image = Image.open('assets/Diffusion-MU-Attack/files/dataset/i2p_nude/imgs/1011_0.png')
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# metric_painting_style = MetricPaintingStyle(desired_style='vincent-van-gogh', top_k=3, model_path='assets/models_pretrained/style_classifier/results/checkpoint-2800', device=device)
# result = metric_painting_style.score(image)
# print(result)
class MetricRace(MetricImage):
    """
    Race classification using the DeepFace library.

    Requires the following additional dependencies:
    * deepface = "~0.0.95"
    * tf_keras = "~2.19.0"
    * tensorrt = "~10.13.2"
    * blinker = "~1.9.0"
    """
    # TODO: doing this with a HF model (e.g. syntheticbot/clip-face-attribute-classifier,
    # which the code below does NOT use) would be better: no need for additional libs

    def model_post_init(self, __context: Optional[dict] = None) -> None:
        """
        Import DeepFace lazily and keep a reference to it on the instance.

        Raises:
            ImportError: if the ``deepface`` package is not installed.
        """
        # Keep the try body minimal: only the import can raise ImportError.
        try:
            from deepface import DeepFace  # noqa
        except ImportError as e:
            raise ImportError("DeepFace library is required for MetricRace. Please install it via 'pip install deepface'. Recommended version: deepface = '~0.0.95', tf_keras = '~2.19.0', tensorrt = '~10.13.2'") from e
        # NOTE(review): ``DeepFace`` is not a declared field; this assignment
        # presumably relies on the base model allowing extra attributes -- confirm.
        self.DeepFace = DeepFace

    def score(self, image: Image.Image) -> Dict[str, Optional[str]]:
        """
        Run DeepFace race analysis on ``image``.

        Returns:
            ``{"race": <dominant race>}`` for the first analysed face, or
            ``{"race": None}`` when DeepFace reports no result or no dominant race.
        """
        # DeepFace documents NumPy inputs as BGR (OpenCV convention), so flip the
        # channel axis of the RGB array; ``copy()`` makes the flipped view contiguous.
        bgr = np.array(image.convert('RGB'))[:, :, ::-1].copy()
        results = self.DeepFace.analyze(
            bgr,
            actions=['race'],
            enforce_detection=False,
        )
        # DeepFace may return a list (one entry per face); only the first is used.
        if isinstance(results, list):
            if not results:
                return {"race": None}
            results = results[0]
        return {
            "race": results.get("dominant_race"),
        }
# Example usage
'''
img = Image.open("assets/datasets/lfw_splits/George_W_Bush/train_forget/George_W_Bush_0001.jpg")
metric_race = MetricRace()
print(metric_race.score(img))
'''
class MetricGender(MetricImage):
    """
    Binary gender classification with the Hugging Face checkpoint named in
    ``_model_name``; class indices are mapped to labels through ``_id2label``.
    """
    device: Optional[torch.device | str | int] = 'cpu'
    _model_name: str = "prithivMLmods/Realistic-Gender-Classification"
    _id2label = {0: 'female', 1: 'male'}
    _model: Any
    _processor: Any

    def model_post_init(self, __context):
        """Post-init hook: fetch processor and model, move the model to ``device``, switch to eval mode."""
        checkpoint = self._model_name
        self._processor = AutoImageProcessor.from_pretrained(checkpoint)
        self._model = SiglipForImageClassification.from_pretrained(checkpoint)
        self._model.to(self.device).eval()

    def score(self, image: Image.Image) -> Dict[str, Union[Literal['male', 'female'], float]]:
        """
        Classify ``image`` and return the most probable label with its probability.

        Returns a dict with keys ``gender`` and ``gender_confidence``.
        """
        # The processor receives a single RGB image and returns a batch of one.
        rgb_image = image.convert("RGB")
        batch = self._processor(images=rgb_image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            first_logits = self._model(**batch).logits[0]
            probabilities = torch.softmax(first_logits, dim=-1)

        best = int(torch.argmax(probabilities))
        predicted: Literal['male', 'female'] = self._id2label[best]  # type: ignore
        return {
            'gender': predicted,
            'gender_confidence': float(probabilities[best])
        }
# This is how to use it
'''
import torch
from PIL import Image

image = Image.open('assets/male.jpg')
#image = Image.open('assets/female.jpg')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric_gender = MetricGender(device=device)
result = metric_gender.score(image)
print(result)
'''
class MetricQuality(MetricImage):
    """
    No-reference image quality metric: BRISQUE computed with ``piq.brisque``.

    Inputs may be PIL images, NumPy arrays or file paths. Every input is
    normalised to a 3-channel RGB float tensor in ``[0, 1]`` before scoring.
    """
    # NOTE(review): ``self.device`` is read below but is not declared in this
    # class; presumably a base class provides it -- confirm.

    def _load_image(self, img: Union[Image.Image, np.ndarray, str]) -> torch.Tensor:
        """
        Convert ``img`` to a ``[C, H, W]`` float32 tensor in ``[0, 1]`` on ``self.device``.

        Args:
            img: a PIL image, a file path, or a NumPy array. ``uint8`` arrays are
                used as-is; any other dtype is assumed to hold values in
                ``[0, 1]`` and is scaled by 255.

        Returns:
            The image as an RGB tensor. Non-RGB inputs (grayscale, palette,
            RGBA, ...) are converted to RGB first.
        """
        img_obj: PILImage
        if isinstance(img, str):
            # Convert inside the context manager so the file handle is released
            # while the returned image owns its own pixel data.
            with Image.open(img) as opened:
                img_obj = opened.convert('RGB')
        elif isinstance(img, np.ndarray):
            arr = img
            if arr.dtype != np.uint8:
                if np.issubdtype(arr.dtype, np.floating):
                    # Values outside [0, 1] would wrap around in the uint8 cast.
                    arr = np.clip(arr, 0.0, 1.0)
                arr = (arr * 255).astype(np.uint8)
            if arr.ndim == 3 and arr.shape[2] == 1:
                # PIL cannot build an image from [H, W, 1]; drop the singleton channel.
                arr = arr[:, :, 0]
            img_obj = Image.fromarray(arr)
        else:
            img_obj = img

        if img_obj.mode != 'RGB':
            img_obj = img_obj.convert('RGB')

        tensor: torch.Tensor = torch.from_numpy(
            np.array(img_obj, copy=True)
        ).float()
        # Internal invariants after the RGB conversion above.
        assert tensor.ndim == 3
        assert tensor.shape[2] in {1, 3}
        tensor = tensor.permute(2, 0, 1) / 255.0  # [H, W, C] -> [C, H, W]
        return tensor.to(self.device, non_blocking=True)

    def score(self, image: Union[Image.Image, np.ndarray, str]) -> Dict[str, float]:
        """
        Compute the BRISQUE score of a single image.

        Returns:
            ``{"brisque": <score>}``.
        """
        image_tensor: torch.Tensor = self._load_image(image).unsqueeze(0)  # [1, C, H, W]
        assert image_tensor.ndim == 4
        assert image_tensor.dtype == torch.float32
        assert 0.0 <= float(image_tensor.min())
        assert float(image_tensor.max()) <= 1.0
        with torch.no_grad():
            score: torch.Tensor = piq.brisque(
                image_tensor,
                data_range=1.0
            )
        return {"brisque": float(score.item())}

    def score_batch(
        self,
        images: List[Union[Image.Image, np.ndarray, str]],
    ) -> List[Dict[str, float]]:
        """
        Compute BRISQUE scores for several images in one ``piq.brisque`` call.

        Args:
            images: images of identical size; an empty list yields an empty result.

        Returns:
            One ``{"brisque": <score>}`` dict per input image, in input order.

        Raises:
            ValueError: if the images do not all share the same shape.
        """
        if not images:
            # Nothing to score; avoids a misleading shape-mismatch error below.
            return []

        tensors: List[torch.Tensor] = [self._load_image(img) for img in images]
        shapes = {t.shape for t in tensors}
        if len(shapes) != 1:
            raise ValueError(
                "All images must have identical shape for batching. "
                f"Found shapes: {shapes}"
            )
        # The tensors already live on ``self.device``, so stacking keeps them there.
        batch: torch.Tensor = torch.stack(tensors, dim=0)  # [N, C, H, W]
        assert batch.ndim == 4
        assert batch.dtype == torch.float32
        assert 0.0 <= float(batch.min())
        assert float(batch.max()) <= 1.0
        with torch.no_grad():
            scores: torch.Tensor = piq.brisque(  # [N]
                batch,
                data_range=1.0,
                reduction="none",
            )
        assert scores.ndim == 1, f"Return shape is {scores.ndim}"
        assert scores.shape[0] == batch.shape[0]
        results: List[Dict[str, float]] = [
            {"brisque": float(v.item())}
            for v in scores
        ]
        assert len(results) == len(images)
        return results
# This is how to use it
'''
import torch
from PIL import Image

image = Image.open('assets/male.jpg')
metric_quality = MetricQuality()
result = metric_quality.score(image)
print(result)
'''