Latent Shapes
Introduction
To the best of my knowledge, latent vectors are typically initialized from a normal distribution and then updated during training to minimize the loss.
Although linear interpolation can be used to manipulate latent vectors,
and similar vectors tend to sit in spatially close locations, the latent vectors themselves do not have a specific shape.
This project takes inspiration from this perspective and explores two main ideas:
1) Latent vectors with geometric shape: exploring what happens if we give the latent vectors a geometric structure, and how that structure might help us better understand or control the results;
2) Interactive manipulation using the latent vector: making it possible for people to easily change and play with the latent vectors, so they can see and understand the effects right away and create new shapes or designs interactively.
Real-time Latent Vector Manipulation with Mouse Dragging
Data Preparation
This project is based on the chair dataset (03001627) from ShapeNetCore. You can access the original data
here.
First, let's look at the meshes in ShapeNetCore. The dataset contains 6778 data points, all of which are already scaled to small values but are not zero-centered.
Additionally, a few data points are not aligned with the AABB axes. Therefore, all meshes must first pass through the following steps.
In this project, the DeepSDF architecture will be used as the mesh reconstruction network, so all meshes must also be watertight (a rough code sketch of these steps follows below).
- watertight: make each mesh watertight to ensure correct SDF value computation
- centralize: translate each mesh so that it is zero-centered
- orient: align each mesh with the AABB axes
- sdf values compute: compute the SDF values used to train a model for mesh reconstruction
Data processing steps
of 1007e20d5e811b308351982a6e40cf41
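The snippet below is a rough sketch of these steps, assuming trimesh is used for the geometry handling; the linked processing code is the authoritative pipeline, and the watertight repair itself is omitted here. The sign convention of the computed SDF values may also need flipping to match the convention used by DeepSDF.

import numpy as np
import trimesh

def preprocess(mesh: trimesh.Trimesh, query_points: np.ndarray) -> np.ndarray:
    # watertight: correct SDF computation requires a closed surface (repair omitted in this sketch)
    assert mesh.is_watertight
    # centralize: translate the mesh so that it is zero-centered
    mesh.apply_translation(-mesh.bounds.mean(axis=0))
    # orient: one way to align the mesh with its bounding-box axes
    mesh.apply_obb()
    # sdf values compute: signed distances at the sampled xyz coordinates
    return trimesh.proximity.signed_distance(mesh, query_points)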
Now, as the final step, let's compute the latent shapes. The process for computing them is as follows (a rough code sketch appears after the list):
1) creating a bounding box of the mesh;
2) applying a scale matrix to the bounding box;
3) subdividing the bounding box n times;
4) snapping each bounding-box vertex either to the point where a ray cast from it intersects the mesh, or to the nearest point on the mesh (rays along the negative vertex normals are used for the top and bottom faces of the bounding mesh; the nearest point on the mesh is used otherwise).
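The following is a rough sketch of this computation, again assuming trimesh; the ray-casting rule for the top and bottom faces is simplified to nearest-point snapping here, so the exact behavior is in the linked code. With two subdivisions, the box mesh ends up with 98 vertices, which matches the \(98 \times 3\) latent shape size used later.

import trimesh

def compute_latent_shape(mesh: trimesh.Trimesh, scale: float = 1.05, n_subdivisions: int = 2) -> trimesh.Trimesh:
    # 1) create a bounding box of the mesh
    bbox = mesh.bounding_box
    box = trimesh.Trimesh(vertices=bbox.vertices.copy(), faces=bbox.faces.copy())
    # 2) apply a scale matrix to the bounding box (the factor here is illustrative)
    box.apply_scale(scale)
    # 3) subdivide the bounding box n times
    for _ in range(n_subdivisions):
        box = box.subdivide()
    # 4) snap every bounding-box vertex to the closest point on the mesh
    closest_points, _, _ = trimesh.proximity.closest_point(mesh, box.vertices)
    return trimesh.Trimesh(vertices=closest_points, faces=box.faces)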
After all processing steps, the data is saved in *.npz format with the following data: xyz coordinates, sdf values, latent shape, faces (latent shape face indices), and class number (shape embedding number).
The code for processing the data can be found
here.
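For orientation, one processed data point could roughly be written like this; the key names follow the description above, while the shapes are my assumptions based on the \(36^3\) sampling and the 98-vertex latent shape, not values taken from the repository.

import numpy as np

np.savez(
    "1007e20d5e811b308351982a6e40cf41.npz",
    xyz=np.zeros((36 ** 3, 3), dtype=np.float32),      # sampled xyz coordinates
    sdf=np.zeros((36 ** 3,), dtype=np.float32),        # sdf value per coordinate
    latent_shape=np.zeros((98, 3), dtype=np.float32),  # latent shape vertices
    faces=np.zeros((192, 3), dtype=np.int64),          # latent shape face indices
    class_number=np.array(0),                          # shape embedding number
)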
Latent shape computation
of 100b18376b885f206ae9ad7e32c4139d
Architecture
The training targets to optimize are the SDFDecoder and the LatentShapes.
The SDFDecoder follows the architecture used in DeepSDF, where the initial input is skip-connected to the input of each block.
The number of blocks and the number of layers per block can be adjusted in config.py.
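As a reference, here is a minimal sketch of that skip-connection pattern; the layer sizes, defaults, and class name are illustrative, and the real SDFDecoder is configured through config.py.

import torch
import torch.nn as nn

class SDFDecoderSketch(nn.Module):
    def __init__(self, input_dim: int = 297, hidden_dim: int = 512, num_blocks: int = 2, layers_per_block: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList()
        for block_index in range(num_blocks):
            layers = []
            # every block after the first re-consumes the original input via a skip connection
            in_dim = input_dim if block_index == 0 else hidden_dim + input_dim
            for _ in range(layers_per_block):
                layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
                in_dim = hidden_dim
            self.blocks.append(nn.Sequential(*layers))
        self.head = nn.Linear(hidden_dim, 1)  # predicts a single sdf value

    def forward(self, cxyz: torch.Tensor) -> torch.Tensor:
        x = cxyz
        for block_index, block in enumerate(self.blocks):
            if block_index > 0:
                # skip-connect the initial input to the input of each block
                x = torch.cat((x, cxyz), dim=1)
            x = block(x)
        return self.head(x)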
The LatentShapes class is defined as follows. Its embedding attribute is initialized with the pre-computed latent shapes from above plus noise (in this project, noise_min = -0.1 and noise_max = 0.1 are used).
In the forward pass, it returns the embeddings corresponding to the given class numbers.
from typing import Optional

import torch
import torch.nn as nn


class LatentShapes(nn.Module):
    def __init__(
        self, latent_shapes: torch.Tensor, noise_min: Optional[float] = None, noise_max: Optional[float] = None
    ):
        super().__init__()
        # no noise is added unless both bounds are given
        self.noise = torch.zeros_like(latent_shapes)
        if None not in (noise_min, noise_max):
            # add uniform noise in [noise_min, noise_max) to the latent shapes
            self.noise = noise_min + torch.rand_like(latent_shapes) * (noise_max - noise_min)
            # check if the noise range is valid
            assert torch.all(self.noise >= noise_min)
            assert torch.all(self.noise <= noise_max)
        # initialize latent shapes with noise
        self.embedding = nn.Parameter(latent_shapes + self.noise)

    def forward(self, class_number: torch.Tensor) -> torch.Tensor:
        # return the embeddings corresponding to the given class numbers
        return self.embedding[class_number]
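Illustrative usage of the class might look like the following; precomputed_latent_shapes stands in for the stacked pre-computed latent shapes, assumed here to be one (98, 3) vertex set per training chair.

precomputed_latent_shapes = torch.zeros(50, 98, 3)  # placeholder for the 50 pre-computed latent shapes
latent_shapes = LatentShapes(precomputed_latent_shapes, noise_min=-0.1, noise_max=0.1)
batch_embeddings = latent_shapes(torch.tensor([0, 3, 7]))  # embeddings for class numbers 0, 3, 7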
The diagram below illustrates the entire training process of the models.
Each point at the upper left of the figure represents a 3D \(\text{xyz (1, 3)}\) coordinate used as an input.
This coordinate is combined with the \(\text{latent shape (1, 98×3)}\), so the initial input for the forward pass is a \(\text{(1, 297)}\)-shaped vector.
The SDFDecoder is trained on this combined data to predict an \(\text{sdf}\) value.
Training process
The latent shape at the lower left indicates that the \(\text{latent shape}\) embedding is initialized with noise, as described above.
During training, the latent shapes are gradually pulled back toward the originals; this approach enables a flexible representation, allowing better reconstruction performance even when the latent shapes do not perfectly match the originals.
Lastly, there are two losses to optimize. \( \mathbf{Loss_{sdf}} \) computes the difference between the predicted and ground-truth SDF values,
and \( \mathbf{Loss_{shape}} \) guides the embedding toward the original latent shape.
Through these losses, the LatentShapes and the SDFDecoder are optimized and improved simultaneously.
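Concretely, following the training code shown later (an L1 loss on clamped SDF values and an MSE loss on the embedding), the two losses can be written as follows, with \( f_\theta \) the SDFDecoder, \( \delta \) the clamping value, \( \mathbf{z} \) the current latent shape embedding, and \( \mathbf{z}_0 \) the pre-computed latent shape:

\[ \mathbf{Loss_{sdf}} = \bigl|\, \operatorname{clamp}(f_\theta(\mathbf{x}, \mathbf{z}), \delta) - \operatorname{clamp}(s, \delta) \,\bigr|, \qquad \mathbf{Loss_{shape}} = \lVert \mathbf{z} - \mathbf{z}_0 \rVert_2^2 \]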
Training Models
In this project, an NVIDIA GeForce RTX 4090 is used to train the models for 120 epochs using only \(50\) chairs of data.
Each data point has \(36^3\) xyz coordinates, so the total number of xyz samples is \(36^3 \times 50 = 2332800\).
Since the GPU doesn't have enough memory to load this much data all at once (it would run out of memory, OOM),
the data is loaded only when needed.
The following code shows the implementation.
First, the
cumulative_length
is pre-computed based on the total number of samples per data point at initialization time.
Then, given an index, the corresponding file is loaded,
and the index is mapped to the correct sample within that file by using modulo with the total sampling number.
class SDFDataset(Dataset):
    (...)

    def __len__(self) -> int:
        return self.total_length

    def __getitem__(self, _idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # find the file that contains the `_idx`-th sample
        for file_idx, cumulative_length in enumerate(self.cumulative_length):
            if _idx < cumulative_length:
                file_idx -= 1
                break
        # load that file only when it is actually needed
        data = np.load(self.data_path[file_idx])
        # map the global index to the sample index within the file
        idx = _idx % self.configuration.N_TOTAL_SAMPLING

        xyz = torch.tensor(data["xyz"][idx], dtype=torch.float32)
        sdf = torch.tensor(data["sdf"][idx], dtype=torch.float32)
        class_number = torch.tensor(data["class_number"], dtype=torch.long)
        latent_shape = torch.tensor(data["latent_shape"], dtype=torch.float32)
        faces = torch.tensor(data["faces"], dtype=torch.long)

        xyz = xyz.to(self.configuration.DEVICE)
        sdf = sdf.to(self.configuration.DEVICE)
        class_number = class_number.to(self.configuration.DEVICE)
        latent_shape = latent_shape.to(self.configuration.DEVICE)
        faces = faces.to(self.configuration.DEVICE)

        return xyz, sdf, class_number, latent_shape, faces
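The __init__ is elided above. One way cumulative_length might be pre-computed so that this lookup works is sketched below; this is an assumption on my part, not the repository's code, and it presumes every file holds the same number of samples:

n_total_sampling = 36 ** 3  # samples per data point, as described above
num_files = 50              # number of chairs used for training
# running totals with a leading zero: [0, 46656, 93312, ...]; the loop in __getitem__
# finds the first boundary greater than _idx and then steps back one file with `file_idx -= 1`
cumulative_length = [i * n_total_sampling for i in range(num_files + 1)]
total_length = cumulative_length[-1]  # 2332800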
Below is the code used for
training at each epoch.
The variable
cxyz
represents the combined input described in the Architecture section above.
The SDFDecoder model takes this combined input and predicts sdf values in batches.
The model's weights are updated based on the
loss
that corresponds to \( \mathbf{Loss_{sdf}} \).
Since the LatentShapes are initialized with noise, the
loss_shape
that corresponds to \( \mathbf{Loss_{shape}} \) is used to update them
so that they gradually become closer to the original latent shapes.
def _train_each_epoch(self) -> Tuple[float, float]:
    """training method for each epoch

    Returns:
        Tuple[float, float]: training loss for the decoder, and the latent shape
    """
    losses = []
    losses_shape = []

    iterator_train = tqdm(
        enumerate(self.sdf_dataset.train_dataloader), total=len(self.sdf_dataset.train_dataloader)
    )

    for batch_index, data in iterator_train:
        xyz_batch, sdf_batch, class_number_batch, latent_shapes_batch_r, _ = data

        # look up the current latent shape embeddings and flatten them per sample
        latent_shapes_batch = self.latent_shapes(class_number_batch)
        latent_shapes_batch = latent_shapes_batch.reshape(latent_shapes_batch.shape[0], -1)

        # combined input: xyz coordinates concatenated with the latent shape
        cxyz = torch.cat((xyz_batch, latent_shapes_batch), dim=1)

        sdf_preds = self.sdf_decoder(cxyz)
        sdf_preds = torch.clamp(sdf_preds, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)
        sdf_batch = torch.clamp(sdf_batch, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)

        # Loss_shape: pull the embedding toward the original (pre-computed) latent shape
        loss_shape = torch.nn.functional.mse_loss(self.latent_shapes(class_number_batch), latent_shapes_batch_r)
        loss_shape.backward()
        self.latent_shapes_optimizer.step()
        self.latent_shapes_optimizer.zero_grad()

        # Loss_sdf: L1 difference between predicted and ground-truth sdf values
        loss = torch.nn.functional.l1_loss(sdf_preds, sdf_batch.unsqueeze(-1))
        loss = loss / self.configuration.ACCUMULATION_STEPS
        loss.backward()

        # gradient accumulation for the decoder
        if (batch_index + 1) % self.configuration.ACCUMULATION_STEPS == 0 or (batch_index + 1) == len(
            self.sdf_dataset.train_dataloader
        ):
            self.sdf_decoder_optimizer.step()
            self.sdf_decoder_optimizer.zero_grad()

        losses.append(loss.item() * self.configuration.ACCUMULATION_STEPS)
        losses_shape.append(loss_shape.item())

    loss_mean = torch.tensor(losses).mean().item()
    loss_shape_mean = torch.tensor(losses_shape).mean().item()

    return loss_mean, loss_shape_mean
The flag for saving the decoder model's state uses a weighted sum of the training loss and the validation loss.
The weights are set to 0.1 and 1.0, respectively, so that the model is saved with a focus on generalization performance on the validation set,
while the small weight on the training loss ensures that the model can also reconstruct the training data well.
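As an illustration, the save criterion could look roughly like this; the variable and file names are mine, not the repository's, while the weights 0.1 and 1.0 follow the description above.

# weighted sum of training and validation losses; lower is better
save_score = 0.1 * loss_mean + 1.0 * loss_mean_val
if save_score < best_save_score:
    best_save_score = save_score
    torch.save(sdf_decoder.state_dict(), "sdf_decoder_best.pth")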
The following are the loss graphs per epoch. All the losses
loss_mean
,
loss_mean_val
,
loss_shape_mean
, and
loss_shape_mean_val
show stable convergence.
In the next section, let's qualitatively evaluate the reconstruction performance.
Losses for 120 epochs
Qualitative Evaluation
The following figure shows the results of reconstructing 50 trained chairs using latent shape embeddings and the decoder (the marching cubes algorithm was applied at a resolution of \(192^3\)).
You can see that even the details, such as the legs and armrests of the chairs, are properly reconstructed.
Reconstructed chairs at 192 resolution
Let me now examine several results generated using the trained shape embeddings through ratio-based linear interpolation.
The data on the left and right represent the inputs to be interpolated, while the center shows the result interpolated according to the specified ratios.
The following ratios were used: \( \text{[[0.50, 0.50], [0.30, 0.70], [0.60, 0.40], [0.45, 0.55]]} \).
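The interpolation itself is a simple weighted sum; in the minimal sketch below, z_a and z_b stand in for two trained latent shape embeddings (the names and the placeholder tensors are illustrative).

import torch

z_a = torch.randn(98, 3)  # placeholder for the left latent shape
z_b = torch.randn(98, 3)  # placeholder for the right latent shape

ratios = [[0.50, 0.50], [0.30, 0.70], [0.60, 0.40], [0.45, 0.55]]
interpolated = [r_a * z_a + r_b * z_b for r_a, r_b in ratios]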
Although this approach is effective for data generation, it remains difficult to predict the specific changes that occur when only a portion of the latent vector is manipulated.
In particular, it isn't easy to predict which components of the vector correspond to which geometric parts.
Interpolated chairs
So instead of relying on these static interpolation methods,
let's finally look at the core implementation of the interface that allows interactive manipulation of the latent vector.
Interactive Manipulation
Here, I briefly describe the implementation of the demo GIF shown in the introduction section.
First,
FastAPI
is used as the communication layer connecting the interface with the model.
The API returns results in JSON format, where both the reconstruction outputs and latent shape data are represented as mesh data based on vertices and faces.
The reconstruct API request is defined as follows.
Due to the Marching Cubes algorithm, there is a trade-off between reconstruction quality and reconstruction speed depending on the resolution size.
Therefore, this parameter is made adjustable by the user to balance fidelity and performance.
class ReconstructRequest(BaseModel):
    latent_shapes: List[List[float]]
    rescale: bool
    map_z_to_y: bool
    ensure_watertight: bool
    resolution: int
    (...)


@app.post("/api/reconstruct")
def reconstruct(request: ReconstructRequest):
    try:
        configuration.RECONSTRUCTION_GRID_SIZE = request.resolution

        latent_shapes_tensor = torch.tensor(request.latent_shapes).to(configuration.DEVICE)
        # map z to y to match the loaded latent shape into the xyz system
        latent_shapes_tensor[:, [1, 2]] = latent_shapes_tensor[:, [2, 1]]

        reconstruction_results = sdf_decoder.reconstruct(
            latent_shapes=latent_shapes_tensor.unsqueeze(0),
            save_path=os.path.join(os.path.dirname(__file__)),
            check_watertight=request.ensure_watertight,
            map_z_to_y=request.map_z_to_y,
            add_noise=False,
            rescale=request.rescale,
        )

        if reconstruction_results[0] is None:
            raise HTTPException(status_code=400, detail="Reconstruction failed")

        # Extract mesh data from the first result
        mesh = reconstruction_results[0]
        vertices = mesh.vertices.tolist()
        faces = mesh.faces.tolist()
        edges = mesh.edges.tolist()

        return {
            "message": "Reconstruction successful",
            "vertices": vertices,
            "faces": faces,
            "edges": edges,
        }

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Reconstruction failed: {str(e)}")
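For orientation, a client call to this endpoint could look like the following; the host, port, resolution, and the placeholder latent shape are assumptions, and the request model above contains further fields that are elided here.

import requests

# placeholder 98 x 3 latent shape (in practice, the vertex positions edited in the interface)
latent_shape_vertices = [[0.0, 0.0, 0.0]] * 98

payload = {
    "latent_shapes": latent_shape_vertices,
    "rescale": True,
    "map_z_to_y": True,
    "ensure_watertight": False,
    "resolution": 128,
}
response = requests.post("http://localhost:8000/api/reconstruct", json=payload)
mesh_data = response.json()  # contains "vertices", "faces", and "edges" on success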
Finally, the last part is the construction of the Three.js-based interface.
The latent shapes are generated by naively shrink-wrapping a subdivided mesh box for the chair data, resulting in a mesh that can be directly rendered on a Three.js canvas.
The main objective is to allow users to manipulate the vertices of this mesh, thereby creating an
intuitive and interactive interpolation experience.
Three.js's SelectionBox and TransformControls form the central modules of this implementation, enabling precise vertex selection and transformation within the interface.
The key logic of the SelectionBox is in converting the browser's mouse coordinates to
Normalized Device Coordinates (NDC).
The browser reports pointer positions in pixels; dividing by the window size maps them into a 0-to-1 range, while Three.js requires coordinates in the -1 to 1 range for Raycaster or SelectionBox operations.
In the browser, the y-coordinate increases as you move down the screen, but in NDC it increases as you move up, so the sign of the y-coordinate must be flipped during the conversion.
function on_pointerup(event) {
    is_mouse_down = false;
    if (!is_selection_mode) {
        return;
    }

    selection_box.endPoint.set(
        (event.clientX / window.innerWidth) * 2 - 1,
        -(event.clientY / window.innerHeight) * 2 + 1,
        0.5
    );

    const selections = selection_box.select();

    // reset the color of the previously selected vertices
    for (const selected_index of selected_indices) {
        latent_shape[selected_index].material.color.set(LATENT_SHAPE_SPHERE_COLOR);
    }
    selected_shape = [];
    selected_indices = [];

    for (const selection of selections) {
        if (selection.name === LATENT_SHAPE) {
            let selected_index = -1;
            selected_shape.push(selection);
            selection.material.color.set(LATENT_SHAPE_SPHERE_SELECTED_COLOR);

            // find the latent shape vertex matching the selected object's position
            for (let i = 0; i < latent_shape.length; i++) {
                const vertex = latent_shape[i];
                if (
                    Math.abs(selection.position.x - vertex.position.x) <= 0.001
                    && Math.abs(selection.position.y - vertex.position.y) <= 0.001
                    && Math.abs(selection.position.z - vertex.position.z) <= 0.001
                ) {
                    selected_index = i;
                    break;
                }
            }

            if (0 <= selected_index && selected_index < latent_shape.length) {
                selected_indices.push(selected_index);
            }
        }
    }

    if (selected_shape.length > 0) {
        first_selected_latent_shape = {
            position: selected_shape[0].position.clone(),
            relative_positions: selected_shape.map(shape => shape.position.clone())
        };
        transform_controls.attach(selected_shape[0]);
    } else {
        transform_controls.detach();
    }
}
Because TransformControls in Three.js can only be attached to one object at a time,
this implementation attaches the controls to the first selected object.
When the user moves this object, the same movement offset
(dx, dy, dz)
is applied to the other selected objects, preserving their relative positions.
This approach ensures that users can interactively manipulate multiple vertices at once, even though TransformControls is attached to a single object.
function on_object_change(event) {
    if (selected_shape.length === 0 || !first_selected_latent_shape) {
        return;
    }

    // offset of the attached (first selected) object since the selection was made
    const dx = selected_shape[0].position.x - first_selected_latent_shape.position.x;
    const dy = selected_shape[0].position.y - first_selected_latent_shape.position.y;
    const dz = selected_shape[0].position.z - first_selected_latent_shape.position.z;

    // apply the same offset to every other selected vertex
    for (let i = 1; i < selected_shape.length; i++) {
        const original_pos = first_selected_latent_shape.relative_positions[i];
        selected_shape[i].position.x = original_pos.x + dx;
        selected_shape[i].position.y = original_pos.y + dy;
        selected_shape[i].position.z = original_pos.z + dz;
    }

    // rebuild the latent shape wireframe from the updated vertex positions
    const latent_shape_vertices = latent_shape.map(
        child => [child.position.x, child.position.y, child.position.z]
    );
    scene.remove(latent_shape_wireframe);
    latent_shape_wireframe = create_mesh(
        latent_shape_vertices,
        latent_shape_faces,
        LATENT_SHAPE_WIREFRAME_MATERIAL,
    );
    scene.add(latent_shape_wireframe);
}
This project can be explored in detail in the
repository linked.
I will conclude this post by sharing a demo GIF illustrating the interactive latent shape manipulation.
Latent Shapes demo
References