import os
import json
from typing import List, Dict, Optional
from PIL import Image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub.repocard_data import EvalResult, ModelCardData
from vision_unlearning.utils.logger import get_logger
logger = get_logger('utils')
[docs]
def save_model_card(
repo_id: str,
base_model: str,
dataset_forget_name: str,
dataset_retain_name: str,
repo_folder: str,
images: Dict[str, Image.Image] = {},
eval_results: List[EvalResult] = [], # whenever possible, should have this names: https://huggingface.co/metrics
tags: List[str] = [],
hyperparameters: dict = {},
similarities_gr: List[float] = [],
similarities_gf: List[float] = [],
):
'''
The resulting file looks like this: https://github.com/huggingface/hub-docs/blob/main/modelcard.md
This looks hugginface-specific, so you may think it should be in `integrations/huggingface.py`, but it is actually a generic Readme saving
Adapted from The HuggingFace Inc. team. All rights reserved.
Licensed under the Apache License, Version 2.0.
Source: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
'''
os.makedirs(os.path.join(repo_folder, "images"), exist_ok=True)
img_str = ""
for name, image in images.items():
path_relative = os.path.join("images", f"{name}.png")
image.save(os.path.join(repo_folder, path_relative))
img_str += f"\n"
# TODO: this description is not appearing in the model card
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}.
The weights were fine-tuned for forgetting {dataset_forget_name} dataset, while retaining {dataset_retain_name}.
You can find some example images in the following.\n
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
model_card = populate_model_card(model_card, tags=tags)
model_card.data = ModelCardData(
model_name=repo_id,
eval_results=eval_results,
hyperparameters=hyperparameters, # goes as kwargs
)
model_card.save(os.path.join(repo_folder, "README.md"))
logger.info(f"Model card saved to {repo_folder}")
if len(similarities_gf) > 0 or len(similarities_gr) > 0:
logger.info('Saving gradient conflicts')
with open(os.path.join(repo_folder, "gradient_conflicts.json"), "w") as f:
json.dump({"forget": similarities_gf, "retain": similarities_gr}, f)
# TODO: code for easily displaying metrics and images; https://colab.research.google.com/drive/1jMhjn5uJ16dhCetSqasOemV82yH15Jg2