In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linear = nn.Linear(5, 5, bias=False)
        self.embeds = nn.Embedding(5, 5)
        self.params = nn.Parameter(torch.FloatTensor(5, 5))
In [2]:
model = Model()
for parameter in model.parameters():
    print(parameter, "\n")
Parameter containing:
tensor([[-2.5368e+30,  9.8651e-43,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
       requires_grad=True) 

Parameter containing:
tensor([[-0.0252, -0.1112,  0.2171,  0.3021, -0.4214],
        [-0.2401,  0.0265,  0.3273,  0.3718, -0.3386],
        [-0.2192, -0.4462, -0.1716,  0.0509,  0.3626],
        [-0.3824,  0.3818,  0.4069, -0.0667, -0.1942],
        [ 0.3055,  0.3448, -0.1002,  0.0351,  0.2229]], requires_grad=True) 

Parameter containing:
tensor([[ 0.1807, -1.0871, -0.2045,  1.4291, -0.9593],
        [-1.2964, -0.8838,  0.9947, -0.3255,  0.8467],
        [ 2.5978,  0.0065,  0.1966,  1.0692,  0.8780],
        [ 1.3668, -1.8618, -1.5452, -1.5488,  0.0560],
        [ 0.0663, -1.4776,  1.1258, -0.0208, -0.7309]], requires_grad=True) 

In [3]:
optimzer = optim.Adam(
    [
        {"params": model.linear.parameters(), "lr": 1e-3},
        {"params": model.embeds.parameters(), "lr": 1e-4},
        {"params": model.params, "lr": 1e-5},
    ]
)

# Adam (
# Parameter Group 0
#     amsgrad: False
#     betas: (0.9, 0.999)
#     capturable: False
#     differentiable: False
#     eps: 1e-08
#     foreach: None
#     fused: None
#     lr: 0.001
#     maximize: False
#     weight_decay: 0

# Parameter Group 1
#     amsgrad: False
#     betas: (0.9, 0.999)
#     capturable: False
#     differentiable: False
#     eps: 1e-08
#     foreach: None
#     fused: None
#     lr: 0.0001
#     maximize: False
#     weight_decay: 0

# Parameter Group 2
#     amsgrad: False
#     betas: (0.9, 0.999)
#     capturable: False
#     differentiable: False
#     eps: 1e-08
#     foreach: None
#     fused: None
#     lr: 1e-05
#     maximize: False
#     weight_decay: 0
# )