from abc import ABC, abstractmethod
from typing import Union, Optional, Any, Dict, List, Literal
import tempfile
import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
import torch
from transformers import (
pipeline,
AutoImageProcessor,
SiglipForImageClassification,
)
from transformers.pipelines.image_classification import ImageClassificationPipeline
import piq
from vision_unlearning.metrics.base import Metric
# TODO: turn these pseudo tests and examples into automated tests
[docs]
class MetricImage(Metric, ABC):
    """
    Abstract base for metrics that are computed from the image alone.

    e.g., image quality, painting style
    """

    @abstractmethod
    def score(self, image: Image.Image) -> Dict[str, Any]:
        """
        Evaluate a single image.

        Args:
            image: the image to evaluate.

        Returns:
            A dictionary with string keys; the keys and value types are
            defined by each concrete subclass.
        """
[docs]
class MetricPaintingStyle(MetricImage):
    """
    Checks whether an image-classification pipeline assigns ``desired_style``
    to an image within its ``top_k`` predictions.

    The pipeline is built once, in ``model_post_init``, from ``model_path``
    and placed on ``device``.
    """
    metrics: List[Literal['is_desired_style', 'desired_style_confidence']] = []  # TODO: this is currently ignored
    desired_style: str
    top_k: int = 5
    model_path: str
    device: Optional[torch.device | str | int] = 'cuda'
    _pipeline: Optional[ImageClassificationPipeline] = None

    def model_post_init(self, __context: Optional[dict] = None) -> None:
        """Create the image-classification pipeline used by ``score``."""
        self._pipeline = pipeline(
            'image-classification',
            model=self.model_path,
            device=self.device,
        )

    def score(self, image: Image.Image) -> Dict[str, bool | float]:
        """
        Classify ``image`` and report whether ``desired_style`` is predicted.

        Args:
            image: the image to classify.

        Returns:
            ``is_desired_style``: True when one of the ``top_k`` predicted
            labels equals ``desired_style``, otherwise False.
            ``desired_style_confidence``: the score of that prediction, or
            0.0 when the style is not among the predictions.
        """
        assert self._pipeline is not None
        result: Dict[str, bool | float] = {
            'is_desired_style': False,
            'desired_style_confidence': 0.0,
        }
        for prediction in self._pipeline(image, top_k=self.top_k):
            if prediction['label'] != self.desired_style:
                continue
            result.update(
                is_desired_style=True,
                desired_style_confidence=float(prediction['score']),
            )
        return result
# Pseudo test
# import torch
# from PIL import Image
#
# #image = Image.open('assets/Diffusion-MU-Attack/files/dataset/vangogh/imgs/35_0.png')
# image = Image.open('assets/Diffusion-MU-Attack/files/dataset/i2p_nude/imgs/1011_0.png')
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# metric_painting_style = MetricPaintingStyle(desired_style='vincent-van-gogh', top_k=3, model_path='assets/models_pretrained/style_classifier/results/checkpoint-2800', device=device)
# result = metric_painting_style.score(image)
# print(result)
[docs]
class MetricRace(MetricImage):
    """
    Race classification backed by the DeepFace library (``DeepFace.analyze``).

    DeepFace is imported inside ``model_post_init`` rather than at module
    import time, and a descriptive ImportError is raised when it is missing.

    Requires the following additional dependencies:
    * deepface = "~0.0.95"
    * tf_keras = "~2.19.0"
    * tensorrt = "~10.13.2"
    * blinker = "~1.9.0"
    """
    # TODO: if we could do this with a HF model (e.g. syntheticbot/clip-face-attribute-classifier)
    # it would be better, no need for additional libs

    def model_post_init(self, __context: Optional[dict] = None) -> None:
        """
        Import DeepFace and keep a reference to it on the instance.

        Raises:
            ImportError: if the ``deepface`` package is not installed.
        """
        # Keep the try body limited to the import so only a missing package is translated.
        try:
            from deepface import DeepFace  # noqa
        except ImportError as e:
            raise ImportError("DeepFace library is required for MetricRace. Please install it via 'pip install deepface'. Recommended version: deepface = '~0.0.95', tf_keras = '~2.19.0', tensorrt = '~10.13.2'") from e
        self.DeepFace = DeepFace

    def score(self, image: Image.Image) -> Dict[str, Optional[str]]:
        """
        Run DeepFace's race analysis on ``image``.

        Args:
            image: the image to analyse; it is converted to RGB before analysis.

        Returns:
            ``{"race": <dominant_race>}`` taken from the first result when
            DeepFace returns a list. The value is None when the result has no
            ``dominant_race`` entry or when the returned list is empty.
        """
        results = self.DeepFace.analyze(
            np.array(image.convert('RGB')),
            actions=['race'],
            enforce_detection=False,
        )
        # DeepFace may return list if multiple faces
        if isinstance(results, list):
            if not results:
                # Nothing to report; avoid an IndexError on results[0].
                return {"race": None}
            results = results[0]
        return {
            "race": results.get("dominant_race"),
        }
# Example usage
'''
img = Image.open("assets/datasets/lfw_splits/George_W_Bush/train_forget/George_W_Bush_0001.jpg")
metric_race = MetricRace()
print(metric_race.score(img))
'''
[docs]
class MetricGender(MetricImage):
    """
    Gender classification with a SigLIP image classifier loaded from
    ``_model_name`` via ``from_pretrained``.

    Class indices are mapped to labels through ``_id2label``.
    """
    device: Optional[torch.device | str | int] = 'cpu'
    _model_name: str = "prithivMLmods/Realistic-Gender-Classification"
    _id2label = {0: 'female', 1: 'male'}
    _model: Any
    _processor: Any

    def model_post_init(self, __context):
        """Load the image processor and the model, move the model to ``device`` and set eval mode."""
        checkpoint = self._model_name
        self._processor = AutoImageProcessor.from_pretrained(checkpoint)
        self._model = SiglipForImageClassification.from_pretrained(checkpoint)
        self._model.to(self.device).eval()

    def score(self, image: Image.Image) -> Dict[str, Union[Literal['male', 'female'], float]]:
        """
        Classify ``image`` as male or female.

        Args:
            image: the image to classify; it is converted to RGB first.

        Returns:
            ``gender``: the label of the highest-probability class.
            ``gender_confidence``: the softmax probability of that class.
        """
        rgb_image = image.convert("RGB")
        model_inputs = self._processor(images=rgb_image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # Only one image is passed, so take the first row of logits.
            first_logits = self._model(**model_inputs).logits[0]
            class_probs = first_logits.softmax(dim=-1)
            best = int(class_probs.argmax())
            gender: Literal['male', 'female'] = self._id2label[best]  # type: ignore
            gender_confidence: float = float(class_probs[best])
        return {'gender': gender, 'gender_confidence': gender_confidence}
# This is how to use it
'''
import torch
from PIL import Image
image = Image.open('assets/male.jpg')
#image = Image.open('assets/female.jpg')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric_gender = MetricGender(device=device)
result = metric_gender.score(image)
print(result)
'''
[docs]
class MetricQuality(MetricImage):
    """
    No-reference image quality score computed with ``piq.brisque``.

    NOTE(review): ``self.device`` is read by ``_load_image`` and ``score_batch``
    but no ``device`` field is declared on this class; presumably it is
    provided by a base class — confirm.
    """

    def _load_image(self, img: Union[Image.Image, np.ndarray, str]) -> torch.Tensor:
        """
        Convert ``img`` into a float32 tensor of shape [C, H, W] on ``self.device``.

        Pixel values are divided by 255 after conversion to 8-bit, so the
        result lies in [0, 1]. C is 1 for grayscale input and 3 otherwise.

        Args:
            img: a PIL image, a numpy array, or a path to an image file.
                Non-uint8 arrays are multiplied by 255 before the cast
                (presumably they hold values in [0, 1] — confirm with callers).

        Returns:
            The image tensor, channels first.
        """
        img_obj: PILImage
        if isinstance(img, str):
            # Copy the pixels while the file is open so the handle is closed deterministically.
            with Image.open(img) as opened:
                img_obj = opened.copy()
        elif isinstance(img, np.ndarray):
            arr = img
            if arr.dtype != np.uint8:
                arr = (arr * 255).astype(np.uint8)
            if arr.ndim == 3 and arr.shape[2] == 1:
                # Image.fromarray does not accept [H, W, 1]; pass it as a 2-D grayscale array.
                arr = arr[:, :, 0]
            img_obj = Image.fromarray(arr)
        else:
            img_obj = img
        # Modes such as RGBA, P or CMYK would not produce 1 or 3 channels; normalise them to RGB.
        if img_obj.mode not in ('L', 'RGB'):
            img_obj = img_obj.convert('RGB')
        pixels = np.array(img_obj, copy=True)
        if pixels.ndim == 2:
            # Grayscale images come back as [H, W]; add the channel axis.
            pixels = pixels[:, :, np.newaxis]
        tensor: torch.Tensor = torch.from_numpy(pixels).float()
        # Internal invariants after the normalisation above.
        assert tensor.ndim == 3
        assert tensor.shape[2] in {1, 3}
        tensor = tensor.permute(2, 0, 1) / 255.0  # [H, W, C] -> [C, H, W]
        return tensor.to(self.device, non_blocking=True)

    def score(self, image: Union[Image.Image, np.ndarray, str]) -> Dict[str, float]:
        """
        Compute the BRISQUE score of a single image.

        Args:
            image: a PIL image, a numpy array, or a path to an image file.

        Returns:
            ``{"brisque": <score>}``.
        """
        image_tensor: torch.Tensor = self._load_image(image).unsqueeze(0)  # [1, C, H, W]
        assert image_tensor.ndim == 4
        assert image_tensor.dtype == torch.float32
        assert 0.0 <= float(image_tensor.min())
        assert float(image_tensor.max()) <= 1.0
        with torch.no_grad():
            score: torch.Tensor = piq.brisque(
                image_tensor, data_range=1.0
            )
        return {"brisque": float(score.item())}

    def score_batch(
        self,
        images: List[Union[Image.Image, np.ndarray, str]],
    ) -> List[Dict[str, float]]:
        """
        Compute BRISQUE scores for several images in one ``piq.brisque`` call.

        Args:
            images: PIL images, numpy arrays, or file paths. After loading,
                all of them must have the same [C, H, W] shape.

        Returns:
            One ``{"brisque": <score>}`` dictionary per input, in input order.
            An empty input yields an empty list.

        Raises:
            ValueError: if the loaded images do not all share one shape.
        """
        if not images:
            # Nothing to stack; avoid the misleading shape-mismatch error below.
            return []
        tensors: List[torch.Tensor] = [self._load_image(img) for img in images]
        shapes = {t.shape for t in tensors}
        if len(shapes) != 1:
            raise ValueError(
                "All images must have identical shape for batching. "
                f"Found shapes: {shapes}"
            )
        batch: torch.Tensor = torch.stack(tensors, dim=0).to(self.device, non_blocking=True)  # [N, C, H, W]
        assert batch.ndim == 4
        assert batch.dtype == torch.float32
        assert 0.0 <= float(batch.min())
        assert float(batch.max()) <= 1.0
        with torch.no_grad():
            scores: torch.Tensor = piq.brisque(  # [N]
                batch,
                data_range=1.0,
                reduction="none",
            )
        assert scores.ndim == 1, f"Return shape is {scores.ndim}"
        assert scores.shape[0] == batch.shape[0]
        results: List[Dict[str, float]] = [
            {"brisque": float(v.item())} for v in scores
        ]
        assert len(results) == len(images)
        return results
# This is how to use it
'''
import torch
from PIL import Image
image = Image.open('assets/male.jpg')
metric_quality = MetricQuality()
result = metric_quality.score(image)
print(result)
'''