# Source code for vision_unlearning.utils.training

import random
from typing import Optional
import numpy as np
import torch
from diffusers.utils.torch_utils import is_compiled_module
from accelerate import Accelerator


def tokenize_captions(examples, tokenizer, caption_column, is_train=True):
    '''
    Choose one caption per example and tokenize the resulting batch.

    Adapted from The HuggingFace Inc. team. All rights reserved.
    Licensed under the Apache License, Version 2.0.
    Source: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py

    Each entry of ``examples[caption_column]`` is either a plain string, which
    is used as is, or a ``list`` / ``np.ndarray`` of candidate captions. For a
    collection, a random candidate is drawn when ``is_train`` is true;
    otherwise the first candidate is used.

    Returns the ``input_ids`` attribute of the tokenizer output. The tokenizer
    is asked to pad/truncate to ``tokenizer.model_max_length`` and to return
    PyTorch tensors (``return_tensors="pt"``).

    Raises ``ValueError`` when an entry is neither a string nor a
    ``list`` / ``np.ndarray``.
    '''

    def _select(entry):
        # Plain strings pass through untouched.
        if isinstance(entry, str):
            return entry
        if not isinstance(entry, (list, np.ndarray)):
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
        # Several candidates: random one while training, first one otherwise.
        return random.choice(entry) if is_train else entry[0]

    chosen = [_select(entry) for entry in examples[caption_column]]
    encoded = tokenizer(
        chosen,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
def unwrap_model(model, accelerator):
    '''
    Return the underlying model, without accelerator and compile wrappers.

    Adapted from The HuggingFace Inc. team. All rights reserved.
    Licensed under the Apache License, Version 2.0.
    Source: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py

    The model is first passed through ``accelerator.unwrap_model``. If the
    result is reported as a compiled module by ``is_compiled_module``, its
    ``_orig_mod`` attribute is returned instead.
    '''
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # Compiled modules keep the original module under `_orig_mod`.
        return unwrapped._orig_mod
    return unwrapped
def forget_tokens(examples, tokenizer, caption_column, forget_prompt: str):
    '''
    Tokenize ``forget_prompt`` once for every example in the batch.

    ``examples[caption_column]`` is only used to determine how many copies of
    the prompt are needed; its contents are not read.

    Returns the ``input_ids`` attribute of the tokenizer output. The tokenizer
    is asked to pad/truncate to ``tokenizer.model_max_length`` and to return
    PyTorch tensors (``return_tensors="pt"``).
    '''
    batch_size = len(examples[caption_column])
    repeated_prompts = [forget_prompt for _ in range(batch_size)]
    encoded = tokenizer(
        repeated_prompts,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
def preprocess_train(
    examples,
    tokenizer,
    caption_column,
    image_column,
    train_transforms,
    overwrite_column: Optional[str] = None,
    concept_overwrite: Optional[str] = None,
):
    '''
    Add model inputs (pixels and token ids) to a batch of examples, in place.

    Adapted from The HuggingFace Inc. team. All rights reserved.
    Licensed under the Apache License, Version 2.0.
    Source: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py

    Keys written to ``examples``:
      - ``pixel_values``: every image from ``examples[image_column]`` converted
        to RGB and passed through ``train_transforms``.
      - ``input_ids``: tokenized captions from ``caption_column``.
      - ``forget_ids``: only when an overwrite is requested. If
        ``overwrite_column`` is given, its captions are tokenized; otherwise,
        if ``concept_overwrite`` is given, the fixed prompt
        "An image of {concept_overwrite}" is tokenized once per example.
        ``overwrite_column`` takes precedence when both are provided.

    concept_overwrite: concept to be used for overwriting, described as a
    textual string (used to modify the prompt).

    TODO: this handling of concept_overwrite is weird... I wish this were
    somewhat more structured/organized/clear. For example, the overwriting
    string may need a more complex prompt than just
    "an image of f{concept_overwrite}", or with a different article

    Returns the same ``examples`` object that was passed in.
    '''
    rgb_images = [picture.convert("RGB") for picture in examples[image_column]]
    examples["pixel_values"] = [train_transforms(picture) for picture in rgb_images]
    examples["input_ids"] = tokenize_captions(examples, tokenizer, caption_column)

    if overwrite_column is not None:
        # Per-example overwrite captions come from a dedicated column.
        forget_ids = tokenize_captions(examples, tokenizer, overwrite_column)
    elif concept_overwrite is not None:
        # Same hardcoded prompt, built from the concept, for every example.
        forget_ids = forget_tokens(
            examples, tokenizer, caption_column, f"An image of {concept_overwrite}"
        )
    else:
        # No overwrite requested: leave `forget_ids` out entirely.
        return examples

    examples["forget_ids"] = forget_ids
    return examples
def collate_fn(examples):
    '''
    Stack per-example tensors into a single batch dictionary.

    Adapted from The HuggingFace Inc. team. All rights reserved.
    Licensed under the Apache License, Version 2.0.
    Source: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py

    Returns a dict with:
      - ``pixel_values``: stacked, in contiguous memory format, cast to float.
      - ``input_ids``: stacked as is.
      - ``forget_ids``: stacked as is, only if the first example has that key
        (``preprocess_train`` adds it when called with an ``overwrite_column``
        or a ``concept_overwrite``).
    '''

    def _stack(key):
        return torch.stack([sample[key] for sample in examples])

    batch = {
        "pixel_values": _stack("pixel_values").to(memory_format=torch.contiguous_format).float(),
        "input_ids": _stack("input_ids"),
    }
    # The first example decides whether the batch carries forget ids.
    if "forget_ids" in examples[0]:
        batch["forget_ids"] = _stack("forget_ids")
    return batch
def launch_accelerated_training(
    unlearner: 'Unlearner',  # type: ignore
    *,
    mixed_precision: str = "fp16",
    dynamo_backend: str = "no",
):
    '''
    Wrap your training function with the accelerator.

    Creates an ``Accelerator`` and runs ``unlearner.train()`` under it, then
    waits for every process before returning.

    unlearner: object exposing a ``train()`` method.
    mixed_precision: forwarded to ``Accelerator(mixed_precision=...)``.
        Defaults to "fp16", the value that used to be hard-coded.
    dynamo_backend: forwarded to ``Accelerator(dynamo_backend=...)``.
        Defaults to "no", the value that used to be hard-coded.

    Only the process for which ``accelerator.is_local_main_process`` is true
    calls ``unlearner.train()``; every other process skips training and goes
    straight to the final barrier.
    NOTE(review): presumably intentional (single training process per node) -
    confirm before relying on this for multi-GPU training.
    '''
    accelerator = Accelerator(mixed_precision=mixed_precision, dynamo_backend=dynamo_backend)
    with accelerator.local_main_process_first():
        if accelerator.is_local_main_process:
            unlearner.train()
    # Wait for all processes to finish
    accelerator.wait_for_everyone()