vision_unlearning.utils.gradient_weighting

Attributes

logger

Classes

GradientWeightingMethod

Base class for methods that combine (harmonize, weight) the gradients of the different tasks, i.e. the forget and retain gradients, into a single gradient.

GradientWeightingMethodNone

No weighting is applied; only the forget gradients are used.

GradientWeightingMethodSimple

Applies a fixed weight to each component (the forget and retain gradients).

GradientWeightingMethodMunba

Inspired by MUNBa: Machine Unlearning via Nash Bargaining (Wu and Harandi, 2025).

Module Contents

vision_unlearning.utils.gradient_weighting.logger
class vision_unlearning.utils.gradient_weighting.GradientWeightingMethod(/, **data: Any)

Bases: pydantic.BaseModel, abc.ABC

Base class for methods that combine (harmonize, weight) the gradients of the different tasks, i.e. the forget and retain gradients, into a single gradient.

Inspired by Aviv Navon, Aviv Shamsian, Idan Achituve, Haggai Maron, Kenji Kawaguchi, Gal Chechik, and Ethan Fetaya, Multi-Task Learning as a Bargaining Game, arXiv preprint arXiv:2202.01017, 2022.

Source: https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py

abstract weight_grads(grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) → torch.Tensor

Returns: scaled_grad (torch.Tensor), the combined gradient computed from the forget and retain gradients.

class vision_unlearning.utils.gradient_weighting.GradientWeightingMethodNone(/, **data: Any)

Bases: GradientWeightingMethod

No weighting is applied; only the forget gradients are used.

Intended for debugging and as a baseline for comparison.

weight_grads(grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) → torch.Tensor

Returns: scaled_grad (torch.Tensor), built from the forget gradients only; the retain gradients are ignored.
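A dependency-free sketch of the weight_grads contract and of the forget-only behavior described above. Plain lists of floats stand in for torch.Tensor, the accelerator argument is ignored, and the returned value is assumed to be the per-parameter gradients concatenated into one flat vector. WeightingSketch, NoneSketch and flatten are hypothetical names, not part of vision_unlearning.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional

# Assumption: flat lists of floats stand in for torch.Tensor so the sketch has no dependencies.
Vector = List[float]


def flatten(grads: List[Vector]) -> Vector:
    """Concatenate per-parameter gradients into one flat vector (assumed output layout)."""
    return [x for g in grads for x in g]


class WeightingSketch(ABC):
    """Hypothetical mirror of the GradientWeightingMethod contract, not the library class."""

    @abstractmethod
    def weight_grads(self, grads_forget: List[Vector], grads_retain: List[Vector],
                     accelerator: Optional[Any] = None) -> Vector:
        """Combine forget and retain gradients into one scaled gradient."""


class NoneSketch(WeightingSketch):
    """Mirrors GradientWeightingMethodNone: the retain gradients are ignored."""

    def weight_grads(self, grads_forget: List[Vector], grads_retain: List[Vector],
                     accelerator: Optional[Any] = None) -> Vector:
        return flatten(grads_forget)


forget = [[1.0, -2.0], [0.5]]
retain = [[9.0, 9.0], [9.0]]
print(NoneSketch().weight_grads(forget, retain))  # → [1.0, -2.0, 0.5]
```

Because the sketch derives from abc.ABC, calling WeightingSketch() directly raises TypeError; every subclass has to provide weight_grads.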

class vision_unlearning.utils.gradient_weighting.GradientWeightingMethodSimple(/, **data: Any)

Bases: GradientWeightingMethod

Applies a fixed weight to each component (the forget and retain gradients).

forget_weight: float = 1.0
retain_weight: float = 1.0
weight_grads(grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) → torch.Tensor

Returns: scaled_grad (torch.Tensor), the forget and retain gradients combined using forget_weight and retain_weight.
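A sketch of fixed weighting, assuming the combination is the elementwise weighted sum forget_weight * g_forget + retain_weight * g_retain of the flattened gradients. The documentation above does not state the exact rule, so treat it as an assumption. weight_grads_simple and flatten are hypothetical helpers, and plain float lists stand in for torch.Tensor.

```python
from typing import Any, List, Optional

# Assumption: flat lists of floats stand in for torch.Tensor.
Vector = List[float]


def flatten(grads: List[Vector]) -> Vector:
    """Concatenate per-parameter gradients into one flat vector (assumed output layout)."""
    return [x for g in grads for x in g]


def weight_grads_simple(grads_forget: List[Vector], grads_retain: List[Vector],
                        accelerator: Optional[Any] = None,
                        forget_weight: float = 1.0, retain_weight: float = 1.0) -> Vector:
    """Assumed rule: forget_weight * g_forget + retain_weight * g_retain, elementwise."""
    g_f, g_r = flatten(grads_forget), flatten(grads_retain)
    if len(g_f) != len(g_r):
        raise ValueError("forget and retain gradients must have the same total size")
    return [forget_weight * f + retain_weight * r for f, r in zip(g_f, g_r)]


print(weight_grads_simple([[1.0, 2.0]], [[3.0, 4.0]], forget_weight=2.0, retain_weight=0.5))
# → [3.5, 6.0]
```

In the library itself the two weights are pydantic fields with default 1.0, so they would be set at construction, e.g. GradientWeightingMethodSimple(forget_weight=2.0, retain_weight=0.5).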

class vision_unlearning.utils.gradient_weighting.GradientWeightingMethodMunba(/, **data: Any)

Bases: GradientWeightingMethod

Inspired by Jing Wu and Mehrtash Harandi, MUNBa: Machine Unlearning via Nash Bargaining, arXiv:2411.15537 [cs.CV], 2025. https://arxiv.org/pdf/2411.15537v1

The closed-form solution is implemented as described in version 1 (v1) of the paper.
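An illustrative derivation of a two-player closed form, assuming the Nash-MTL bargaining condition G^T G α = 1/α (elementwise) with G = [g_r, g_f], where g_r and g_f are the flattened retain and forget gradients, and writing a = ‖g_r‖², b = g_r^T g_f, c = ‖g_f‖². The expression in v1 of the paper may be written differently.

```latex
\begin{aligned}
a\,\alpha_r + b\,\alpha_f &= 1/\alpha_r, \qquad b\,\alpha_r + c\,\alpha_f = 1/\alpha_f \\
\Rightarrow\; a\,\alpha_r^2 + b\,\alpha_r\alpha_f &= 1 = b\,\alpha_r\alpha_f + c\,\alpha_f^2
  \;\Rightarrow\; \alpha_f = \alpha_r\sqrt{a/c} \\
\alpha_r &= \bigl(a + b\sqrt{a/c}\bigr)^{-1/2}, \qquad
  \alpha_f = \sqrt{a/c}\,\bigl(a + b\sqrt{a/c}\bigr)^{-1/2} \\
g &= \alpha_r\, g_r + \alpha_f\, g_f
\end{aligned}
```

The positive solution exists whenever both gradients are non-zero and not exactly opposite, since a + b√(a/c) = √(a/c)(√(ac) + b) and |b| ≤ √(ac).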

weight_grads(grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) → torch.Tensor

Returns: scaled_grad (torch.Tensor), the forget and retain gradients combined using the Nash bargaining weights.
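A dependency-free sketch of two-player Nash bargaining weights, assuming the Nash-MTL condition G^T G α = 1/α with G = [g_r, g_f] and a final update α_r g_r + α_f g_f. It is not the library's implementation; MUNBa's v1 closed form may be expressed differently or add extra handling (for example of the forget gradient's sign). nash_weights, weight_grads_nash, flatten and dot are hypothetical names, and plain float lists stand in for torch.Tensor.

```python
import math
from typing import Any, List, Optional, Tuple

# Assumption: flat lists of floats stand in for torch.Tensor.
Vector = List[float]


def flatten(grads: List[Vector]) -> Vector:
    """Concatenate per-parameter gradients into one flat vector (assumed layout)."""
    return [x for g in grads for x in g]


def dot(u: Vector, v: Vector) -> float:
    return sum(x * y for x, y in zip(u, v))


def nash_weights(g_r: Vector, g_f: Vector, eps: float = 1e-12) -> Tuple[float, float]:
    """Closed-form positive solution of G^T G alpha = 1/alpha for G = [g_r, g_f]."""
    a, b, c = dot(g_r, g_r), dot(g_r, g_f), dot(g_f, g_f)
    if a < eps or c < eps:
        raise ValueError("both gradients must be non-zero")
    ratio = math.sqrt(a / c)  # alpha_f / alpha_r, from a*alpha_r^2 = c*alpha_f^2
    denom = a + b * ratio     # 0 only when g_f points exactly opposite to g_r
    if denom < eps:
        raise ValueError("gradients are anti-parallel: no positive solution")
    alpha_r = 1.0 / math.sqrt(denom)
    return alpha_r, alpha_r * ratio


def weight_grads_nash(grads_forget: List[Vector], grads_retain: List[Vector],
                      accelerator: Optional[Any] = None) -> Vector:
    """Hypothetical MUNBa-style combination: alpha_r * g_r + alpha_f * g_f."""
    g_f, g_r = flatten(grads_forget), flatten(grads_retain)
    if len(g_f) != len(g_r):
        raise ValueError("forget and retain gradients must have the same total size")
    alpha_r, alpha_f = nash_weights(g_r, g_f)
    return [alpha_r * r + alpha_f * f for r, f in zip(g_r, g_f)]


# Orthogonal gradients: a = 1, b = 0, c = 4, so alpha_r = 1 and alpha_f = 0.5.
print(weight_grads_nash([[0.0, 2.0]], [[1.0, 0.0]]))  # → [1.0, 1.0]
```

With two players the condition reduces to two scalar equations, so no iterative solver is needed; the sketch raises ValueError in the degenerate cases (a zero gradient, or exactly opposite gradients) where no positive solution exists.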