vision_unlearning.unlearner.uce_sd_erase

Classes

ConceptType

Enum representing the type of concept to unlearn.

UCE

Unified Concept Editing for unlearning in Stable Diffusion models.

Module Contents

class vision_unlearning.unlearner.uce_sd_erase.ConceptType

Bases: str, enum.Enum

Enum representing the type of concept to unlearn.

Object = 'object'
Art = 'art'
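
Because ConceptType mixes in str, its members behave like ordinary strings. They can be looked up from a plain value, for example one read from a config file, and they compare equal to that value. The sketch below re-declares a local copy of the enum so that it runs without vision_unlearning installed. It is an illustration, not the library's own code.

```python
import enum


class ConceptType(str, enum.Enum):
    """Local mirror of the documented enum, for illustration only."""

    Object = "object"
    Art = "art"


# Members can be constructed from their plain string values...
concept_type = ConceptType("art")
print(concept_type is ConceptType.Art)  # True
# ...and, thanks to the str mixin, compare equal to those strings.
print(concept_type == "art")            # True
print(concept_type.value)               # art
```

An unknown value such as ConceptType("painting") raises ValueError, which makes invalid configuration fail early.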

class vision_unlearning.unlearner.uce_sd_erase.UCE(**data: Any)

Bases: vision_unlearning.unlearner.base.Unlearner

Unified Concept Editing for unlearning in Stable Diffusion models. Adapted from:

GitHub: https://github.com/rohitgandikota/unified-concept-editing
arXiv: https://arxiv.org/pdf/2308.14761.pdf
Gandikota, R., Orgad, H., Belinkov, Y., Materzyńska, J., & Bau, D. (2024). Unified concept editing in diffusion models. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 5111-5120).

This unlearner does not use LoRA and does not perform any fine-tuning; instead, it applies a closed-form update to the cross-attention weights.
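
The closed-form update can be written as below. This is a sketch of the standard UCE update as we understand it from the paper and reference code, not a transcription of this class; in particular, the correspondence between this class's parameters and the symbols (lamb as λ, erase_scale as s_e, preserve_scale as s_p) is an assumption. W_old is one cross-attention projection matrix (UCE edits the key and value projections), c_i is the embedding of edit concept i, g_i the embedding of its guide concept, and c_j the embedding of preserve concept j.

```latex
\min_{W}\; s_e \sum_{i \in E} \lVert W c_i - v_i^{*} \rVert^2
  + s_p \sum_{j \in P} \lVert W c_j - W_{\text{old}}\, c_j \rVert^2
  + \lambda \lVert W - W_{\text{old}} \rVert_F^2,
  \qquad v_i^{*} = W_{\text{old}}\, g_i

W_{\text{new}} =
  \Big(\lambda W_{\text{old}} + s_e \sum_{i \in E} v_i^{*} c_i^{\top}
       + s_p \sum_{j \in P} W_{\text{old}}\, c_j c_j^{\top}\Big)
  \Big(\lambda I + s_e \sum_{i \in E} c_i c_i^{\top}
       + s_p \sum_{j \in P} c_j c_j^{\top}\Big)^{-1}
```

Setting the gradient of the objective to zero yields the second line directly. Edit concepts are pulled toward their guide outputs v_i^*, preserve concepts are held at their original outputs, and λ > 0 keeps the second factor invertible even when there are fewer concepts than embedding dimensions.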

pretrained_model_name_or_path: str = None
erase_scale: float = None
preserve_scale: float = None
lamb: float = None
save_entire_model: bool = None
edit_concepts: str | None = None
guide_concepts: str | None = None
preserve_concepts: str | None = None
concept_type: ConceptType = None
expand_prompts: bool = True
final_eval_prompts_forget: str | List[str] = None
final_eval_prompts_retain: str | List[str] = None
output_dir: str = None
device: str = 'cuda:0'
compute_runtimes: bool = None
hub_model_id: str | None = None
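
A configuration for UCE might look like the dictionary below. All values are illustrative, not the library's defaults, and two points are assumptions: that edit_concepts, guide_concepts and preserve_concepts take comma-separated concept lists (as the reference UCE scripts do, to our knowledge), and that concept_type accepts the plain string value of a ConceptType member. The hypothetical split_concepts helper shows the parsing that convention implies; the commented lines at the end show how the documented constructor and train() would be called.

```python
from typing import List, Optional

# Illustrative values only; these are NOT the library's defaults.
uce_config = {
    "pretrained_model_name_or_path": "CompVis/stable-diffusion-v1-4",
    "edit_concepts": "Van Gogh",           # assumed comma-separated list
    "guide_concepts": "art",
    "preserve_concepts": "Monet, Picasso",
    "concept_type": "art",                 # assumed to coerce to ConceptType.Art
    "erase_scale": 1.0,
    "preserve_scale": 1.0,
    "lamb": 0.5,
    "expand_prompts": True,
    "save_entire_model": False,
    "final_eval_prompts_forget": ["a painting by Van Gogh"],
    "final_eval_prompts_retain": ["a painting by Monet"],
    "output_dir": "outputs/uce_van_gogh",
    "device": "cuda:0",
    "compute_runtimes": False,
    "hub_model_id": None,
}


def split_concepts(concepts: Optional[str]) -> List[str]:
    """Hypothetical helper: 'A, B' -> ['A', 'B']; None or '' -> []."""
    if not concepts:
        return []
    return [c.strip() for c in concepts.split(",") if c.strip()]


print(split_concepts(uce_config["preserve_concepts"]))  # ['Monet', 'Picasso']

# With the library installed, the documented API would be used as:
# from vision_unlearning.unlearner.uce_sd_erase import UCE
# unlearner = UCE(**uce_config)
# eval_results = unlearner.train()
```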

_collect_text_embeddings(pipe: Any, concepts: list[str], device: str, torch_dtype: torch.dtype) → dict[str, torch.Tensor]

Return a dict mapping each concept to the text-encoder embedding of its last token ({concept: last_token_embedding}).
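
One plausible reading of "last token" is the final content token of the prompt, i.e. the position just before the end-of-text marker. That interpretation, and the hypothetical last_token_embedding helper below, are assumptions rather than the method's confirmed behavior. The sketch uses plain NumPy arrays standing in for the text encoder's per-token outputs and the tokenizer's attention mask.

```python
import numpy as np


def last_token_embedding(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper (assumed semantics of "last token").

    hidden_states: (seq_len, dim) per-token text-encoder outputs for one prompt.
    attention_mask: (seq_len,) with 1 for real tokens (including BOS and EOS)
    and 0 for padding.
    Returns the embedding of the token just before the end-of-text marker.
    """
    n_real = int(attention_mask.sum())  # BOS + content tokens + EOS
    if n_real < 3:
        raise ValueError("prompt contains no content tokens")
    return hidden_states[n_real - 2]


# Toy prompt: BOS, two content tokens, EOS, then two padding positions.
hidden = np.arange(12, dtype=float).reshape(6, 2)
mask = np.array([1, 1, 1, 1, 0, 0])
print(last_token_embedding(hidden, mask))  # [4. 5.]
```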

_collect_guide_outputs(concepts: list[str], embeds: dict[str, torch.Tensor], modules: list[torch.nn.Module]) → dict[str, list[torch.Tensor]]

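Collect the cross-attention outputs for the guide and preserve concepts: for each concept, the output of every given module applied to that concept's embedding.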
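
Conceptually, collecting these outputs amounts to passing each concept's embedding through every edited projection. The sketch below assumes bias-free linear projections (as the cross-attention key and value projections in Stable Diffusion are, to our knowledge), so each output is simply W @ c. collect_outputs is a hypothetical stand-in operating on NumPy arrays, not the library's method.

```python
from typing import Dict, List, Sequence

import numpy as np


def collect_outputs(
    concepts: Sequence[str],
    embeds: Dict[str, np.ndarray],
    weights: Sequence[np.ndarray],
) -> Dict[str, List[np.ndarray]]:
    """Hypothetical stand-in for _collect_guide_outputs.

    Each weight matrix W has shape (d_out, d_in) and plays the role of a
    bias-free projection, so its output for embedding c is W @ c.
    Returns {concept: [W_1 @ c, W_2 @ c, ...]} in module order.
    """
    return {concept: [W @ embeds[concept] for W in weights] for concept in concepts}


embeds = {"art": np.array([1.0, 2.0])}
weights = [np.eye(2), 2.0 * np.eye(2)]
outputs = collect_outputs(["art"], embeds, weights)
print(outputs["art"][1])  # [2. 4.]
```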

_update_weights(original_modules: list[torch.nn.Module], erase_embeds: dict[str, torch.Tensor], guide_outputs: dict[str, list[torch.Tensor]], edit_concepts: list[str], guide_concepts: list[str], preserve_concepts: list[str], erase_scale: float, preserve_scale: float, lamb: float, device: str, torch_dtype: torch.dtype) → list[torch.nn.Module]

Apply the closed-form UCE weight update to each module and return the updated modules.
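
A minimal NumPy sketch of the per-module update, assuming it follows the regularized least-squares form of UCE: lamb pulls W toward the original weights, erase_scale weights the edit-to-guide mapping, and preserve_scale weights keeping preserve concepts unchanged. uce_update is a hypothetical helper for a single weight matrix of shape (d_out, d_in), the same layout as torch.nn.Linear.weight; the method itself applies the update to each module in turn.

```python
from typing import Sequence

import numpy as np


def uce_update(
    W_old: np.ndarray,
    edit_embeds: Sequence[np.ndarray],
    guide_outputs: Sequence[np.ndarray],
    preserve_embeds: Sequence[np.ndarray],
    erase_scale: float,
    preserve_scale: float,
    lamb: float,
) -> np.ndarray:
    """Hypothetical closed-form UCE update for one projection matrix.

    W_old: (d_out, d_in); embeddings: (d_in,); guide outputs: (d_out,).
    Returns the W minimizing
        erase_scale * sum_i ||W c_i - v_i||^2
      + preserve_scale * sum_j ||W c_j - W_old c_j||^2
      + lamb * ||W - W_old||_F^2.
    """
    d_in = W_old.shape[1]
    mat1 = lamb * W_old             # right-hand side, shape (d_out, d_in)
    mat2 = lamb * np.eye(d_in)      # system matrix, shape (d_in, d_in)
    for c, v in zip(edit_embeds, guide_outputs):
        mat1 = mat1 + erase_scale * np.outer(v, c)    # map c_i toward v_i
        mat2 = mat2 + erase_scale * np.outer(c, c)
    for c in preserve_embeds:
        mat1 = mat1 + preserve_scale * np.outer(W_old @ c, c)  # keep W_old c_j
        mat2 = mat2 + preserve_scale * np.outer(c, c)
    # W_new @ mat2 = mat1  <=>  mat2.T @ W_new.T = mat1.T
    return np.linalg.solve(mat2.T, mat1.T).T


rng = np.random.default_rng(0)
W_old = rng.normal(size=(4, 8))
c_edit, c_guide, c_keep = rng.normal(size=(3, 8))
v_star = W_old @ c_guide  # target: what the guide concept originally produces
W_new = uce_update(W_old, [c_edit], [v_star], [c_keep], 1.0, 1.0, lamb=1e-6)
print(np.abs(W_new @ c_edit - v_star).max())          # close to 0
print(np.abs(W_new @ c_keep - W_old @ c_keep).max())  # close to 0
```

Solving the linear system instead of forming an explicit inverse is a numerical-stability choice of this sketch; with lamb > 0 and non-negative scales the system matrix is symmetric positive definite, so the solve always succeeds.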

_save_uce_weights(uce_modules: list[torch.nn.Module], uce_module_names: list[str]) → None

Save updated module weights to a safetensors file.

train() → List[huggingface_hub.repocard_data.EvalResult]

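Run the main UCE concept-erasure procedure and return the resulting list of EvalResult objects. Despite the method name, no gradient-based training takes place; the weights are updated in closed form.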

static get_pipeline_from_modified_weights(pretrained_model_name_or_path: str, device: str | torch.device, output_dir: str) → diffusers.DiffusionPipeline

evaluate() → Tuple[List[huggingface_hub.repocard_data.EvalResult], Dict[str, PIL.Image.Image]]