Row Details #4865

Data:

{
  "text": "I found an implementation of Adam in [this SO question](https://stackoverflow.com/questions/51387194/implementing-adam-in-pytorch):\n\n    class ADAMOptimizer(torch.optim.Optimizer):\n        \"\"\"\n        implements ADAM Algorithm, as a preceding step.\n        \"\"\"\n        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):\n            defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n            super(ADAMOptimizer, self).__init__(params, defaults)\n    \n        def step(self):\n            \"\"\"\n            Perform a single optimization step.\n            \"\"\"\n            loss = None\n            for group in self.param_groups:\n    \n                for p in group['params']:\n                    grad = p.grad.data\n                    state = self.state[p]\n    \n                    # State initialization\n                    if len(state) == 0:\n                        state['step'] = 0\n                        # Momentum (Exponential MA of gradients)\n                        state['exp_avg'] = torch.zeros_like(p.data)\n    \n                        # RMS Prop componenet. (Exponential MA of squared gradients). 
Denominator.\n                        state['exp_avg_sq'] = torch.zeros_like(p.data)\n    \n                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n    \n                    b1, b2 = group['betas']\n                    state['step'] += 1\n    \n                    # Add weight decay if any\n                    if group['weight_decay'] != 0:\n                        grad = grad.add(group['weight_decay'], p.data)\n    \n                    # Momentum\n                    exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad\n                    \n                    # RMS\n                    exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)\n    \n                    mhat = exp_avg / (1 - b1 ** state['step'])\n                    vhat = exp_avg_sq / (1 - b2 ** state['step'])\n                    \n                    denom = torch.sqrt( vhat + group['eps'] )\n    \n                    p.data = p.data - group['lr'] * mhat / denom \n                    \n                    # Save state\n                    state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq \n    \n            return loss\n\nMy issue is that a lot of my gradients have a 0 value, which messes  up the momentum and velocity terms. What I'm interested in is modifying  the code such that **0 values will not be taken into account when  calculating the momentum and velocity terms (i.e., first and  second-moment estimates).**\n\nThough, I'm unsure how to do that. If it's a simple network where the gradients are just simple dimensions I can check whether p.grad.data=0,  but since this is going to be a multi-dimension tensor I'm unsure how  to remove the zeros in the calculations and not mess something else  (e.g., the remaining updates).",
  "label": "r/pytorch",
  "dataType": "post",
  "communityName": "r/pytorch",
  "datetime": "2024-04-23",
  "username_encoded": "Z0FBQUFBQm5LakwyanRhcURaVzVqNURySUlyaFcwZXlnUzVLTVJRMEZlcUd1RU5uR3VyZ0hpdzM0WlNLOWlDRWZjZWZLSHVFNzBlY1cycWVleEtxWlRRa2hLdUlKNVVaOGc9PQ==",
  "url_encoded": "Z0FBQUFBQm5Lak9GWXBkcEVBTVB6eVJDVWlPU3hyaDFRRWlVcWJVTnlRQkZabVY3ZlEzS1hTZXd4ckg2UWtFN1dNbVYyMTc4LXZublpVeVBFSTllMS1vbk9BbXhYWTNjRUFzdnRwdDI2OVhFQk5QV0N5QTdtMmo1OGxzUEI3dzVLN1F6NGV4ZlU5VFNFelN6QWZMNTFBYTlKOFAzRHR4clVsNl9uZkNRWHFnUld3d2hZVTRuMC02LWMzX0Q4cWpIMmRtOWVEVkZvUzljRm43OVFYTy1YcUNzWW9xejJOM2pQZz09"
}
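
The post asks how to keep zero-valued gradient entries from polluting Adam's first and second moment estimates. A minimal sketch of one way to do this (not from the original post; the `masked_adam_step` helper and its signature are hypothetical, standing in for the `step` body of the quoted optimizer): build a boolean mask of nonzero gradient entries and apply the moment updates and the parameter update only at those positions.

```python
import torch

def masked_adam_step(p, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Hypothetical single Adam step that ignores zero gradient entries.

    Entries where p.grad is exactly zero leave the moment estimates
    (state['exp_avg'], state['exp_avg_sq']) and the parameter untouched.
    """
    grad = p.grad
    b1, b2 = betas
    mask = grad != 0  # boolean tensor, same shape as the parameter

    state['step'] += 1
    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

    # Update the moments in place, only at masked positions.
    exp_avg[mask] = b1 * exp_avg[mask] + (1 - b1) * grad[mask]
    exp_avg_sq[mask] = b2 * exp_avg_sq[mask] + (1 - b2) * grad[mask] ** 2

    # Bias correction as in the quoted code.
    mhat = exp_avg / (1 - b1 ** state['step'])
    vhat = exp_avg_sq / (1 - b2 ** state['step'])

    # Apply the update only where the gradient was nonzero.
    p.data[mask] -= lr * (mhat / (vhat.sqrt() + eps))[mask]
```

Note one caveat with this approach: the global `state['step']` counter no longer matches the number of updates each individual entry has received, so the bias correction is only approximate for entries that are frequently masked out; a per-entry step counter (a tensor the same shape as `p`) would fix that at the cost of extra memory.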