from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List
import torch
from vision_unlearning.utils.logger import get_logger
logger = get_logger('gradient_weighting')
[docs]
class GradientWeightingMethod(BaseModel, ABC):
    '''
    Abstract strategy for merging the forget-task and retain-task gradients
    into a single update direction.

    Concrete subclasses implement `weight_grads`.

    Inspired by @article{navon2022multi,
        title={Multi-Task Learning as a Bargaining Game},
        author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
        journal={arXiv preprint arXiv:2202.01017},
        year={2022}
    }
    Source: https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py
    '''

    @abstractmethod
    def weight_grads(
        self,
        grads_forget: List[torch.Tensor],
        grads_retain: List[torch.Tensor],
        accelerator,
    ) -> torch.Tensor:
        '''
        Combine the two gradient lists into one tensor.

        @param grads_forget gradient tensors of the forget objective
        @param grads_retain gradient tensors of the retain objective
        @param accelerator handle passed through by the caller; whether it is
               used depends on the subclass
        @return scaled_grad
        '''
        ...
[docs]
class GradientWeightingMethodNone(GradientWeightingMethod):
    '''
    No weighting is applied: only the forget gradients are returned.
    For debugging/comparison purposes.
    '''

    def weight_grads(self, grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) -> torch.Tensor:
        '''
        Flatten and concatenate the forget gradients, ignoring everything else.

        @param grads_forget gradient tensors of the forget objective
        @param grads_retain unused by this method
        @param accelerator unused by this method
        @return scaled_grad: 1-D tensor holding every element of `grads_forget`,
                in list order
        '''
        # reshape(-1) instead of view(-1): view raises a RuntimeError on
        # non-contiguous tensors, reshape falls back to a copy in that case.
        return torch.cat([g.reshape(-1) for g in grads_forget])
[docs]
class GradientWeightingMethodSimple(GradientWeightingMethod):
    '''
    Fixed weights for each component: the result is
    `forget_weight * forget_gradient + retain_weight * retain_gradient`.
    '''
    # Multiplier applied to the flattened forget gradient
    forget_weight: float = 1.0
    # Multiplier applied to the flattened retain gradient
    retain_weight: float = 1.0

    def weight_grads(self, grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) -> torch.Tensor:
        '''
        Return the fixed-weight linear combination of both flattened gradients.

        @param grads_forget gradient tensors of the forget objective
        @param grads_retain gradient tensors of the retain objective; must
               flatten to the same total number of elements as `grads_forget`
               for the final addition to succeed
        @param accelerator unused by this method
        @return scaled_grad: 1-D tensor
        '''
        # reshape(-1) instead of view(-1): view raises a RuntimeError on
        # non-contiguous tensors, reshape falls back to a copy in that case.
        flat_forget = torch.cat([g.reshape(-1) for g in grads_forget])
        flat_retain = torch.cat([g.reshape(-1) for g in grads_retain])
        grads_forget_scaled = self.forget_weight * flat_forget
        grads_retain_scaled = self.retain_weight * flat_retain
        return grads_forget_scaled + grads_retain_scaled
[docs]
class GradientWeightingMethodMunba(GradientWeightingMethod):
    '''
    Inspired by @misc{wu2025munbamachineunlearningnash,
        title={MUNBa: Machine Unlearning via Nash Bargaining},
        author={Jing Wu and Mehrtash Harandi},
        year={2025},
        eprint={2411.15537},
        archivePrefix={arXiv},
        primaryClass={cs.CV},
        url={https://arxiv.org/pdf/2411.15537v1},
    }
    The closed-form solution is implemented as described in the V1 of the paper.
    '''

    def weight_grads(self, grads_forget: List[torch.Tensor], grads_retain: List[torch.Tensor], accelerator) -> torch.Tensor:
        '''
        Combine the retain and forget gradients with the closed-form
        two-player bargaining coefficients `alpha`.

        @param grads_forget gradient tensors of the forget objective
        @param grads_retain gradient tensors of the retain objective; must
               flatten to the same total number of elements as `grads_forget`
        @param accelerator object exposing `.device`; the result lives on that device
        @return scaled_grad: tensor of shape (P, 1), where P is the total number
                of gradient elements (a column, not a 1-D vector)
        '''
        # Stack gradients to form matrix G of shape (2, P):
        # row 0 = flattened retain gradient, row 1 = flattened forget gradient.
        # reshape(-1) instead of view(-1): view raises a RuntimeError on
        # non-contiguous tensors, reshape falls back to a copy in that case.
        G = torch.stack([
            torch.cat([g.reshape(-1) for g in grads_retain]),
            torch.cat([g.reshape(-1) for g in grads_forget]),
        ])
        K = G @ G.T  # Compute K = G G^T; it is the 2x2 Gram matrix of the two gradients
        # Possible variation: K /= torch.norm(K)  # As recommended here: https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py#L231

        # Solve for alpha using the Nash equation.
        # The denominator equals k11 * (k11 * k22 - k12**2): it is zero when the
        # two gradients are collinear (or the retain gradient is zero), and
        # alpha_forget divides by k12, which is zero for orthogonal gradients.
        # NOTE(review): alpha_forget below enforces the first row of
        # K @ alpha = 1 / alpha; the alpha_retain expression does not appear to
        # satisfy the second row in general (solving the 2x2 system by hand gives
        # k11 * k22 instead of 2 * k11 * k22 in the numerator). Left unchanged;
        # confirm against the paper.
        k11, k12, k22 = K[0, 0], K[0, 1], K[1, 1]
        alpha_retain = torch.sqrt((2 * k11 * k22 + k12 * torch.sqrt(k11 * k22)) / (k11**2 * k22 - k11 * k12**2))  # 0-dim tensor
        alpha_forget = (1 - k11 * alpha_retain**2) / (k12 * alpha_retain)

        # torch.stack keeps alpha on G's device and dtype instead of rebuilding a
        # tensor from scalars; detach() keeps alpha a constant w.r.t. autograd.
        alpha = torch.stack([alpha_retain, alpha_forget]).detach().reshape(2, 1)  # Typical values seem to be things like [0.0016, -0.0029]
        if not torch.isfinite(alpha).all():
            # Surface the degenerate cases described above instead of silently
            # returning a NaN/inf gradient.
            logger.warning(f"Non-finite alpha in this iteration: {alpha}")
        logger.debug(f"Alpha in this iteration: {alpha}")

        G = G.to(accelerator.device)
        alpha = alpha.to(accelerator.device)
        scaled_grad = G.T @ alpha  # (P, 2) @ (2, 1) -> (P, 1)
        # Possible variations:
        # scaled_grad /= 2*torch.abs(alpha).min()
        # scaled_grad /= 2*alpha.min()
        # scaled_grad /= torch.norm(alpha)
        # if <differentSign>: scaled_grad = G.T @ alpha; else: ...  # As recommended here: https://github.com/JingWu321/MUNBa/blob/d691e13885a373d97e4177cb051bd0dc64a9c732/SD/MUNBa_cls.py#L271
        return scaled_grad