Suppose the model consists of two parts: a wall generator and a room allocator.
The `PlanGenerator` class below holds them as its `wall_generator` and `room_allocator` attributes.
class PlanGenerator(nn.Module):
    """Floor plan generator consisting of the `WallGenerator` and `RoomAllocator`.

    Both sub-networks are built from hyper-parameters read off a
    ``Configuration`` object. The whole module is moved to
    ``configuration.DEVICE`` at construction time.
    """

    def __init__(self, configuration: Configuration):
        """Build the wall generator and room allocator from ``configuration``.

        Args:
            configuration: Source of hyper-parameters. Attributes read here are
                ``IMAGE_SIZE`` and ``DEVICE``, plus
                ``WALL_GENERATOR_{IN_CHANNELS,OUT_CHANNELS,CHANNELS_STEP,REPEAT}``
                and ``ROOM_ALLOCATOR_{IN_CHANNELS,OUT_CHANNELS,CHANNELS_STEP,REPEAT}``.
        """
        super().__init__()
        self.configuration = configuration
        self.wall_generator = WallGenerator(
            in_channels=configuration.WALL_GENERATOR_IN_CHANNELS,
            out_channels=configuration.WALL_GENERATOR_OUT_CHANNELS,
            size=configuration.IMAGE_SIZE,
            channels_step=configuration.WALL_GENERATOR_CHANNELS_STEP,
            encoder_repeat=configuration.WALL_GENERATOR_REPEAT,
        )
        self.room_allocator = RoomAllocator(
            in_channels=configuration.ROOM_ALLOCATOR_IN_CHANNELS,
            out_channels=configuration.ROOM_ALLOCATOR_OUT_CHANNELS,
            size=configuration.IMAGE_SIZE,
            channels_step=configuration.ROOM_ALLOCATOR_CHANNELS_STEP,
            encoder_repeat=configuration.ROOM_ALLOCATOR_REPEAT,
        )
        # Move all registered parameters/buffers of both sub-networks at once.
        self.to(configuration.DEVICE)

    def forward(
        self, floor_batch: torch.Tensor, walls_batch: torch.Tensor, masking: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both sub-networks on their respective inputs.

        The two sub-networks run independently. Walls are generated from
        ``floor_batch``, while rooms are allocated from the caller-supplied
        ``walls_batch``, not from the freshly generated walls.

        Args:
            floor_batch: Input to the wall generator. It is also passed to
                ``self.mask`` as the mask source when ``masking`` is true.
            walls_batch: Input to the room allocator.
            masking: If true, both outputs are passed through
                ``self.mask(output, floor_batch)`` before being returned.

        Returns:
            ``(generated_walls, allocated_rooms)``, masked when ``masking`` is true.
        """
        generated_walls = self.wall_generator(floor_batch)
        allocated_rooms = self.room_allocator(walls_batch)
        if masking:
            # NOTE(review): `mask` is not defined in this snippet; presumably it
            # is provided elsewhere in the class — confirm before relying on it.
            generated_walls = self.mask(generated_walls, floor_batch)
            allocated_rooms = self.mask(allocated_rooms, floor_batch)
        return generated_walls, allocated_rooms