register_buffer
-
torch.nn.Module.register_buffer
-
This is typically used to register a buffer that should not be considered a model parameter.
-
The registered tensors
-
are not updated by an optimizer, since they are not weights (parameters).
-
are saved into
torch.nn.Module.state_dict.
-
are moved to the device (e.g., CPU/GPU) along with the module when calling the
.to() method.
-
can be accessed by the
.named_buffers() method.
>>> import torch
>>> import torch.nn as nn
>>>
>>> class ModelWithoutRegisterBuffer(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(5, 5)
...         self.table = torch.randn(5, 5)
...
>>> model_wo = ModelWithoutRegisterBuffer()
>>> model_wo.state_dict
>>> model_wo.state_dict().keys()
odict_keys(['linear.weight', 'linear.bias'])
>>>
>>> next(model_wo.parameters())  # `self.linear` parameters only
Parameter containing:
tensor([[ 0.2549, 0.0874, -0.2376, 0.1055, 0.2452],
[-0.1870, 0.2541, 0.1141, -0.2168, 0.0588],
[ 0.3015, -0.4444, -0.0167, -0.1264, -0.0926],
[-0.1709, 0.0621, -0.3716, 0.1197, 0.3451],
[ 0.3995, 0.3425, -0.1166, -0.0362, 0.0492]], requires_grad=True)
>>>
>>> model_wo.to("cuda")
>>> model_wo.table.is_cuda
False
>>>
>>>
>>>
>>> class ModelWithRegisterBuffer(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(5, 5)
...         self.register_buffer("table", torch.randn(5, 5))
...
>>> model_w = ModelWithRegisterBuffer()
>>> model_w.state_dict
>>> model_w.state_dict().keys()
odict_keys(['table', 'linear.weight', 'linear.bias'])
>>>
>>> next(model_w.parameters())  # `self.linear` parameters only
Parameter containing:
tensor([[ 0.1504, -0.0401, -0.1466, -0.1018, 0.3968],
[ 0.2359, 0.1086, 0.2259, 0.2023, -0.0552],
[ 0.2171, 0.0971, -0.3293, 0.1233, -0.0934],
[ 0.3125, -0.0515, -0.3245, 0.2853, -0.0398],
[-0.3445, 0.3251, 0.3733, 0.0704, -0.2582]], requires_grad=True)
>>>
>>> model_w.to("cuda")
>>> model_w.table.is_cuda
True
>>>
>>> next(model_w.named_buffers()) # access to the registered tensors
('table', tensor([[ 0.5122, 0.0684, -0.5352, 0.9809, 0.6826],
[-0.1609, 0.1757, -1.5572, 0.5408, -0.6608],
[-0.1418, -0.7139, 0.0371, -1.2819, 0.6693],
[-0.1744, 0.3132, -1.0427, -1.2012, -0.9954],
[ 2.0804, -0.2535, -0.2862, 0.4413, -0.7968]], device='cuda:0'))