Latent Shapes
Objectives
Latent vectors are typically initialized from a normal distribution and then updated during training to minimize the loss function.
Although latent vectors can be linearly interpolated, and similar shapes tend to end up close to each other in latent space, the latent vectors themselves do not have a specific geometric shape or structure.
This project takes inspiration from this perspective and explores two main ideas:
1) Latent vectors with geometric shape: exploring what happens if latent vectors are given a geometric structure, and how this can help us better understand, interpret, or control the results;
2) Interactive manipulation using the latent vector: making it possible for users to modify latent vectors easily, so they can immediately see the effects and create new shapes or designs in an interactive way.
Latent Vector Manipulation in real-time with Mouse Dragging
Data Preprocessing
This project uses the chair dataset (03001627) from ShapeNetCore. The original data is available here.
First, I examine the meshes in ShapeNetCore. The chair category contains 6778 meshes; all of them are already scaled to a small coordinate range, but they are not zero-centered.
In addition, some meshes are not aligned with the axes of their axis-aligned bounding box (AABB).
In this project, the DeepSDF model architecture is used as the mesh reconstruction network, so all meshes must also be watertight.
Therefore, every mesh goes through the following steps before training.
- watertight: each mesh is made watertight to ensure correct computation of SDF values.
- centralize: each mesh is translated to be zero-centered.
- orient: each mesh is aligned to the AABB axis.
- sdf values compute: the resulting SDF values are used to train a model for mesh reconstruction.
Data processing steps of 1007e20d5e811b308351982a6e40cf41
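As a rough illustration of the centralize step, the sketch below translates a set of vertices so that the center of their axis-aligned bounding box lands on the origin. Whether the project centers on the bounding-box center or on the vertex mean is not stated here, so the bounding-box choice and the function name `centralize` are assumptions.

```python
from typing import List, Tuple

Vertex = Tuple[float, float, float]


def centralize(vertices: List[Vertex]) -> List[Vertex]:
    """Translate vertices so the bounding-box center sits at the origin (assumed convention)."""
    # per-axis minimum and maximum of the axis-aligned bounding box
    mins = [min(v[axis] for v in vertices) for axis in range(3)]
    maxs = [max(v[axis] for v in vertices) for axis in range(3)]
    center = [(lo + hi) / 2.0 for lo, hi in zip(mins, maxs)]

    return [tuple(v[axis] - center[axis] for axis in range(3)) for v in vertices]


# toy vertices, not taken from the dataset
vertices = [(1.0, 2.0, 3.0), (3.0, 6.0, 5.0), (2.0, 3.0, 4.0)]
centered = centralize(vertices)
print(centered)  # [(-1.0, -2.0, -1.0), (1.0, 2.0, 1.0), (0.0, -1.0, 0.0)]
```

After this step the bounding box is symmetric around the origin, which keeps the later bounding-box scaling and subdivision independent of where the mesh originally sat.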
Now, as the final step, the latent shapes are computed. The process consists of the following steps:
1) creating a bounding box of the mesh;
2) applying a scale matrix to the bounding box;
3) subdividing the bounding box n times;
4) moving each bounding box vertex onto the mesh: vertices on the top and bottom faces of the bounding mesh are moved to the intersection point of a ray cast along the negative normal vector, and all other vertices are moved to the nearest point on the mesh.
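Steps 1 to 3 can be sketched without a mesh library. The example assumes that each subdivision halves every edge of the box, so n subdivisions give 2^n segments per edge; under that assumption n = 2 yields 98 surface vertices, which is consistent with the 98×3 latent shape used later. The function name, the bounding box values, and the scale of 1.1 are hypothetical, and step 4 (moving the vertices onto the mesh) is left out.

```python
from itertools import product
from typing import List, Tuple

Point = Tuple[float, float, float]


def subdivided_box_vertices(bbox_min: Point, bbox_max: Point, scale: float, n: int) -> List[Point]:
    """Return the surface vertices of a scaled bounding box subdivided n times."""
    segments = 2 ** n  # assumption: each subdivision halves every edge
    center = [(lo + hi) / 2.0 for lo, hi in zip(bbox_min, bbox_max)]
    half = [(hi - lo) / 2.0 * scale for lo, hi in zip(bbox_min, bbox_max)]

    vertices = []
    for index in product(range(segments + 1), repeat=3):
        # keep only the grid points that lie on at least one face of the box
        if not any(i in (0, segments) for i in index):
            continue
        vertices.append(
            tuple(c - h + 2.0 * h * i / segments for c, h, i in zip(center, half, index))
        )
    return vertices


vertices = subdivided_box_vertices((-0.2, -0.4, -0.2), (0.2, 0.4, 0.2), scale=1.1, n=2)
print(len(vertices))  # 98
```

The count follows from a 5×5×5 grid with its 3×3×3 interior removed: 125 - 27 = 98.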
After all processing steps, each chair is saved as an *.npz file containing the following arrays: xyz coordinates, sdf values, latent shape, faces (latent shape face indices), and class number (shape embedding number).
The code for processing the data is available here.
Latent shape computation of 100b18376b885f206ae9ad7e32c4139d
Model Architecture
Two modules are optimized during training: the SDFDecoder and the LatentShapes.
The SDFDecoder follows the architecture used in DeepSDF, where the initial input data is skip-connected to each block input.
The number of blocks and the number of layers in each block can be adjusted in config.py.
The LatentShapes class is defined as follows. The embedding attribute of LatentShapes is initialized with the pre-computed latent shapes plus uniform noise (in this project, the noise range is set to noise_min = -0.1 and noise_max = 0.1).
In the forward pass, it returns the embeddings corresponding to the given classes.
from typing import Optional

import torch
from torch import nn


class LatentShapes(nn.Module):
    def __init__(
        self, latent_shapes: torch.Tensor, noise_min: Optional[float] = None, noise_max: Optional[float] = None
    ):
        super().__init__()

        self.noise = torch.zeros_like(latent_shapes)
        if None not in (noise_min, noise_max):
            # sample uniform noise between noise_min and noise_max
            self.noise = noise_min + torch.rand_like(latent_shapes) * (noise_max - noise_min)

            # check if the noise range is valid
            assert torch.all(self.noise >= noise_min)
            assert torch.all(self.noise <= noise_max)

        # initialize latent shapes with noise to multi-vector embeddings
        self.embedding = nn.Parameter(latent_shapes + self.noise)

    def forward(self, class_number: torch.Tensor) -> torch.Tensor:
        return self.embedding[class_number]
LatentShapes with multi-vector embeddings
The diagram below shows the overall training process of the models.
Each point at the upper left in the figure represents a 3D \(\text{xyz (1, 3)}\) coordinate used as input.
This coordinate is concatenated with the flattened \(\text{latent shape (1, 98×3)}\), that is \(\text{(1, 294)}\), so the initial input for the forward pass is a \(\text{(1, 297)}\) vector.
The SDFDecoder is trained on this combined input to predict an \(\text{sdf}\) value.
Training process
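The shape arithmetic above can be checked with plain lists. This is only a sketch of how the combined input is laid out; the coordinate values and the variable names are placeholders, not project data.

```python
from typing import List

N_LATENT_VERTICES = 98  # vertices per latent shape, as stated in the text

# one query point and one latent shape filled with placeholder values
xyz: List[float] = [0.1, -0.2, 0.3]
latent_shape: List[List[float]] = [[0.0, 0.0, 0.0] for _ in range(N_LATENT_VERTICES)]

# flatten the (98, 3) latent shape to 294 values and append it to the coordinate
flattened = [component for vertex in latent_shape for component in vertex]
combined = xyz + flattened

print(len(flattened), len(combined))  # 294 297
```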
The latent shape at the lower left indicates that the \(\text{latent shape}\) embedding is initialized with noise, as described above.
During training, the latent shapes are gradually pulled back toward the original shapes. The intent is a more flexible representation: the decoder also sees latent shapes that deviate slightly from the originals, so reconstruction should remain usable when a latent shape does not match them perfectly.
Lastly, there are two losses used for optimization. The first is \( \mathbf{Loss_{sdf}} \), the L1 difference between the clamped predicted value and the clamped ground truth value.
The second is \( \mathbf{Loss_{shape}} \), the mean squared error between the embedding and the original latent shape, which guides the embedding toward that shape.
Through these losses, the LatentShapes and the SDFDecoder are optimized and improved at the same time.
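Read off the training code shown later in this post, the two terms can be written as follows. The symbols are notation introduced here for illustration, not taken from the project: \( f_\theta \) is the SDFDecoder, \( x_i \) a query coordinate, \( s_i \) its ground truth SDF value, \( z_{c_i} \) the learned latent shape of the chair the sample belongs to, \( \hat{z}_{c_i} \) the pre-computed latent shape, \( \delta \) the clamp value, and \( B \) the batch size.

\[
\mathbf{Loss_{sdf}} = \frac{1}{B} \sum_{i=1}^{B} \left| \operatorname{clamp}\big(f_\theta([x_i, z_{c_i}]), -\delta, \delta\big) - \operatorname{clamp}(s_i, -\delta, \delta) \right|
\]

\[
\mathbf{Loss_{shape}} = \frac{1}{B \cdot 98 \cdot 3} \sum_{i=1}^{B} \left\lVert z_{c_i} - \hat{z}_{c_i} \right\rVert_F^2
\]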
Training & Memory Strategy
In this project, an NVIDIA GeForce RTX 4090 is used to train the models for 120 epochs on only \(50\) chairs.
Each chair has \(36^3 = 46656\) xyz coordinates, so the total number of xyz coordinates is \(36^3 \times 50 = 2332800\).
Since the GPU does not have enough memory to hold all of this data at once (OOM), the data is loaded from disk only when it is needed.
The following code shows the implementation.
First, cumulative_length is pre-computed at initialization time from the number of samples in each data file.
Then, given an index, the corresponding file is loaded, and the index is mapped to the correct sample within that file by taking it modulo the number of samples per file (N_TOTAL_SAMPLING).
from typing import Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


class SDFDataset(Dataset):
    (...)

    def __len__(self) -> int:
        return self.total_length

    def __getitem__(self, _idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # find the first cumulative length above _idx; the file to load is the one before it
        for file_idx, cumulative_length in enumerate(self.cumulative_length):
            if _idx < cumulative_length:
                file_idx -= 1
                break

        # load only the file that contains the requested sample
        data = np.load(self.data_path[file_idx])
        # position of the sample inside that file
        idx = _idx % self.configuration.N_TOTAL_SAMPLING

        xyz = torch.tensor(data["xyz"][idx], dtype=torch.float32)
        sdf = torch.tensor(data["sdf"][idx], dtype=torch.float32)
        class_number = torch.tensor(data["class_number"], dtype=torch.long)
        latent_shape = torch.tensor(data["latent_shape"], dtype=torch.float32)
        faces = torch.tensor(data["faces"], dtype=torch.long)

        # move every tensor of this sample to the configured device
        xyz = xyz.to(self.configuration.DEVICE)
        sdf = sdf.to(self.configuration.DEVICE)
        class_number = class_number.to(self.configuration.DEVICE)
        latent_shape = latent_shape.to(self.configuration.DEVICE)
        faces = faces.to(self.configuration.DEVICE)

        return xyz, sdf, class_number, latent_shape, faces
Dataset class
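The index mapping can also be sketched in isolation. The version below swaps the linear scan for a binary search with `bisect` and subtracts the preceding cumulative length instead of using modulo, which would also work if files held different numbers of samples. The function name `locate` and the layout of the cumulative list (no leading zero) are assumptions and may differ from the project's own `cumulative_length`.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def locate(global_index: int, samples_per_file: List[int]) -> Tuple[int, int]:
    """Map a global sample index to (file index, index inside that file)."""
    # cumulative[i] is the number of samples stored in files 0..i
    cumulative = list(accumulate(samples_per_file))
    if not 0 <= global_index < cumulative[-1]:
        raise IndexError(global_index)

    # first file whose cumulative length is strictly greater than the index
    file_index = bisect_right(cumulative, global_index)
    previous = cumulative[file_index - 1] if file_index > 0 else 0
    return file_index, global_index - previous


samples_per_file = [36 ** 3] * 50  # 50 chairs with 36^3 samples each
print(locate(0, samples_per_file))          # (0, 0)
print(locate(36 ** 3, samples_per_file))    # (1, 0)
print(locate(2_332_799, samples_per_file))  # (49, 46655)
```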
Below is the code used for training at each epoch.
The variable cxyz represents the combined input described in the Model Architecture section above.
The SDFDecoder model takes this combined input and predicts sdf values in batches.
The decoder's weights are updated based on the loss that corresponds to \( \mathbf{Loss_{sdf}} \); its gradients are accumulated over ACCUMULATION_STEPS batches before each optimizer step.
Since the LatentShapes are initialized with noise, the loss_shape that corresponds to \( \mathbf{Loss_{shape}} \) is used to update them so that they gradually become closer to the original latent shapes.
def _train_each_epoch(self) -> Tuple[float, float]:
    """training method for each epoch

    Returns:
        Tuple[float, float]: training loss for the decoder, and the latent shape
    """

    losses = []
    losses_shape = []

    iterator_train = tqdm(
        enumerate(self.sdf_dataset.train_dataloader), total=len(self.sdf_dataset.train_dataloader)
    )

    for batch_index, data in iterator_train:
        xyz_batch, sdf_batch, class_number_batch, latent_shapes_batch_r, _ = data

        # look up the learned latent shapes and flatten each one to a single vector
        latent_shapes_batch = self.latent_shapes(class_number_batch)
        latent_shapes_batch = latent_shapes_batch.reshape(latent_shapes_batch.shape[0], -1)

        # combined input: xyz coordinates followed by the flattened latent shape
        cxyz = torch.cat((xyz_batch, latent_shapes_batch), dim=1)

        sdf_preds = self.sdf_decoder(cxyz)
        sdf_preds = torch.clamp(sdf_preds, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)
        sdf_batch = torch.clamp(sdf_batch, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)

        # Loss_shape: pull the embeddings toward the pre-computed latent shapes
        loss_shape = torch.nn.functional.mse_loss(self.latent_shapes(class_number_batch), latent_shapes_batch_r)
        loss_shape.backward()
        self.latent_shapes_optimizer.step()
        self.latent_shapes_optimizer.zero_grad()

        # Loss_sdf: divided by ACCUMULATION_STEPS because gradients are accumulated
        loss = torch.nn.functional.l1_loss(sdf_preds, sdf_batch.unsqueeze(-1))
        loss = loss / self.configuration.ACCUMULATION_STEPS
        loss.backward()

        # step the decoder every ACCUMULATION_STEPS batches and on the last batch
        if (batch_index + 1) % self.configuration.ACCUMULATION_STEPS == 0 or (batch_index + 1) == len(
            self.sdf_dataset.train_dataloader
        ):
            self.sdf_decoder_optimizer.step()
            self.sdf_decoder_optimizer.zero_grad()

        losses.append(loss.item() * self.configuration.ACCUMULATION_STEPS)
        losses_shape.append(loss_shape.item())

    loss_mean = torch.tensor(losses).mean().item()
    loss_shape_mean = torch.tensor(losses_shape).mean().item()

    return loss_mean, loss_shape_mean
Training iteration
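The stepping condition of the decoder optimizer is easy to misread, so here is a small standalone sketch that lists the batches after which a step happens. The helper name and the example numbers are hypothetical; only the condition itself mirrors the training code.

```python
from typing import List


def optimizer_step_batches(n_batches: int, accumulation_steps: int) -> List[int]:
    """Return the batch indices after which the decoder optimizer would step."""
    return [
        batch_index
        for batch_index in range(n_batches)
        # step every accumulation_steps batches, and always on the final batch
        if (batch_index + 1) % accumulation_steps == 0 or (batch_index + 1) == n_batches
    ]


print(optimizer_step_batches(n_batches=10, accumulation_steps=4))  # [3, 7, 9]
```

The extra check for the final batch makes sure the gradients of a trailing, incomplete group (batches 8 and 9 in this example) are still applied instead of being carried into the next epoch.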
Whether the decoder model's state is saved is decided by the weighted sum of the training loss and the validation loss.
The weights are set to 0.1 and 1.0, respectively, so that the model is saved with a focus on validation generalization performance,
while a small weight on the training loss ensures that the model can also reconstruct the training data well.
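A minimal sketch of such a saving criterion is shown below. The weights 0.1 and 1.0 come from the text; the function name, the "save when the score improves" rule, and the loss values are assumptions made up for illustration and are not results from this project.

```python
from typing import List, Tuple


def checkpoint_score(loss_train: float, loss_val: float, w_train: float = 0.1, w_val: float = 1.0) -> float:
    """Weighted sum of training and validation loss used to decide on saving."""
    return w_train * loss_train + w_val * loss_val


# made-up (training loss, validation loss) pairs for three epochs
history: List[Tuple[float, float]] = [(0.030, 0.040), (0.010, 0.045), (0.008, 0.035)]

best_score = float("inf")
saved_epochs = []
for epoch, (loss_train, loss_val) in enumerate(history):
    score = checkpoint_score(loss_train, loss_val)
    if score < best_score:  # assumed rule: save only when the weighted score improves
        best_score = score
        saved_epochs.append(epoch)

print(saved_epochs)  # [0, 2]
```

In the second epoch the training loss drops sharply, but the validation loss rises, so the weighted score gets worse and no checkpoint is written.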
The following are the loss graphs per epoch. All of the losses (loss_mean, loss_mean_val, loss_shape_mean, and loss_shape_mean_val) show stable convergence.
In the next section, the reconstruction performance is evaluated qualitatively.
Losses for 120 epochs
Reconstruction & Linear Interpolation
The following figure shows the results of reconstructing 50 trained chairs using latent shape embeddings and the decoder (the marching cubes algorithm was applied at a resolution of \(192^3\)).
Details such as the legs and armrests of the chairs are properly reconstructed.
Reconstructed chairs at 192 resolution
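To give a feel for what a resolution of \(192^3\) means, the sketch below counts the SDF queries a dense grid needs before marching cubes can run, and splits them into chunks so that they do not have to be evaluated in one batch. The chunk size of 262144 points is a hypothetical value, and whether the project evaluates the grid in chunks is an assumption.

```python
from typing import Iterator, Tuple


def grid_query_count(resolution: int) -> int:
    """Number of SDF evaluations for a dense resolution^3 grid."""
    return resolution ** 3


def chunk_ranges(total: int, chunk_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) index ranges that cover `total` points in chunks."""
    for start in range(0, total, chunk_size):
        yield start, min(start + chunk_size, total)


CHUNK_SIZE = 262_144  # hypothetical number of points per decoder call

for resolution in (64, 128, 192):
    total = grid_query_count(resolution)
    n_chunks = sum(1 for _ in chunk_ranges(total, CHUNK_SIZE))
    print(resolution, total, n_chunks)
# 64 262144 1
# 128 2097152 8
# 192 7077888 27
```

The query count grows with the cube of the resolution, so going from 64 to 192 multiplies the work by 27.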
Next, I examine several results obtained by linearly interpolating the trained shape embeddings with given ratios.
The shapes on the left and right are the inputs to be interpolated, while the center shows the interpolated result for the specified ratio.
The following ratios were used: \( \text{[[0.50, 0.50], [0.30, 0.70], [0.60, 0.40], [0.45, 0.55]]} \).
Although this method works well for interpolation between shapes, it is still difficult to predict the specific changes when only part of the latent vector is modified.
In particular, it is not easy to determine which components of the vector correspond to specific geometric parts.
Interpolated chairs
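Ratio-based linear interpolation of two latent shapes amounts to a per-vertex weighted sum. The sketch below applies the four ratio pairs listed above to two tiny placeholder shapes; which ratio belongs to the left and which to the right input is an assumption, as are the function name and the toy vertices.

```python
from typing import List

LatentShape = List[List[float]]


def interpolate(shape_a: LatentShape, shape_b: LatentShape, ratio_a: float, ratio_b: float) -> LatentShape:
    """Blend two latent shapes vertex by vertex with the given ratios."""
    if abs(ratio_a + ratio_b - 1.0) > 1e-9:
        raise ValueError("ratios are expected to sum to 1")
    return [
        [ratio_a * a + ratio_b * b for a, b in zip(vertex_a, vertex_b)]
        for vertex_a, vertex_b in zip(shape_a, shape_b)
    ]


# two-vertex placeholder shapes; real latent shapes have 98 vertices
shape_a = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
shape_b = [[1.0, 0.0, 2.0], [1.0, 3.0, 1.0]]

ratios = [[0.50, 0.50], [0.30, 0.70], [0.60, 0.40], [0.45, 0.55]]
results = [interpolate(shape_a, shape_b, ratio_a, ratio_b) for ratio_a, ratio_b in ratios]

print(results[0])  # [[0.5, 0.0, 1.0], [1.0, 2.0, 1.0]]
```

Because every latent value is a vertex coordinate, the blended result is itself a mesh that can be inspected before it is passed to the decoder.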
Instead of relying on these static interpolation methods,
the next section examines the core implementation of the interface that allows interactive manipulation of the latent vector.
Interactive Manipulation UI
Here, I briefly describe the implementation of the demo GIF shown in the introduction section.
First, FastAPI is used as the communication layer connecting the interface with the model.
The API returns results in JSON format, where both the reconstruction outputs and the latent shape data are represented as mesh data based on vertices and faces.
The reconstruct API request is defined as follows.
Because the Marching Cubes algorithm runs on a grid whose size is set by the resolution, there is a trade-off between reconstruction quality and reconstruction speed.
Therefore, the resolution is exposed as a request parameter so that the user can balance fidelity and performance.
import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel


class ReconstructRequest(BaseModel):
    latent_shapes: List[List[float]]
    rescale: bool
    map_z_to_y: bool
    ensure_watertight: bool
    resolution: int

(...)

@app.post("/api/reconstruct")
def reconstruct(request: ReconstructRequest):
    try:
        configuration.RECONSTRUCTION_GRID_SIZE = request.resolution

        latent_shapes_tensor = torch.tensor(request.latent_shapes).to(configuration.DEVICE)

        # map z to y to match the loaded latent shape into the xyz system
        latent_shapes_tensor[:, [1, 2]] = latent_shapes_tensor[:, [2, 1]]

        reconstruction_results = sdf_decoder.reconstruct(
            latent_shapes=latent_shapes_tensor.unsqueeze(0),
            save_path=os.path.join(os.path.dirname(__file__)),
            check_watertight=request.ensure_watertight,
            map_z_to_y=request.map_z_to_y,
            add_noise=False,
            rescale=request.rescale,
        )

        if reconstruction_results[0] is None:
            raise HTTPException(status_code=400, detail="Reconstruction failed")

        # Extract mesh data from the first result
        mesh = reconstruction_results[0]
        vertices = mesh.vertices.tolist()
        faces = mesh.faces.tolist()
        edges = mesh.edges.tolist()

        return {
            "message": "Reconstruction successful",
            "vertices": vertices,
            "faces": faces,
            "edges": edges,
        }

    except HTTPException:
        # re-raise as is so the detail set above is not wrapped a second time
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Reconstruction failed: {str(e)}")
Reconstruction API
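For reference, a request body matching the ReconstructRequest fields could be assembled as in the sketch below. The field names come from the model above; the zero-filled latent shape, the flag values, and the resolution of 64 are placeholders chosen for illustration. The resulting JSON string could then be sent as the body of a POST to /api/reconstruct with any HTTP client.

```python
import json

# placeholder latent shape: 98 vertices with 3 coordinates each
latent_shapes = [[0.0, 0.0, 0.0] for _ in range(98)]

payload = {
    "latent_shapes": latent_shapes,
    "rescale": True,             # placeholder flag value
    "map_z_to_y": True,          # placeholder flag value
    "ensure_watertight": False,  # placeholder flag value
    "resolution": 64,            # lower values respond faster, higher values add detail
}

body = json.dumps(payload)
print(len(payload["latent_shapes"]), payload["resolution"])  # 98 64
```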
Finally, the last part is the construction of the Three.js-based interface.
The latent shapes are generated by shrink-wrapping a subdivided mesh box around the chair data, resulting in a mesh that can be rendered directly on a Three.js canvas.
The main objective is to allow users to manipulate the vertices of this mesh, thereby creating an intuitive and interactive interpolation experience.
Three.js SelectionBox and TransformControls are the main modules used in this implementation, enabling precise vertex selection and transformation within the interface.
The key logic of the SelectionBox lies in converting the browser's mouse coordinates to Normalized Device Coordinates (NDC).
The browser reports the pointer position in pixels (clientX, clientY); dividing by the window size normalizes it to the 0 to 1 range, whereas Three.js requires coordinates in the -1 to 1 range for Raycaster or SelectionBox operations.
In the browser, the y-coordinate increases toward the bottom of the screen, whereas in NDC the y-coordinate increases toward the top. Therefore, the sign of the y-coordinate must be flipped during the conversion.
function on_pointerup(event) {
    is_mouse_down = false;

    if (!is_selection_mode) {
        return;
    }

    // convert the pointer position from pixels to NDC (y is flipped)
    selection_box.endPoint.set(
        (event.clientX / window.innerWidth) * 2 - 1,
        -(event.clientY / window.innerHeight) * 2 + 1,
        0.5
    );

    const selections = selection_box.select();

    // reset the color of the previously selected vertices
    for (const selected_index of selected_indices) {
        latent_shape[selected_index].material.color.set(LATENT_SHAPE_SPHERE_COLOR);
    }

    selected_shape = [];
    selected_indices = [];

    for (const selection of selections) {
        if (selection.name === LATENT_SHAPE) {
            let selected_index = -1;
            selected_shape.push(selection);
            selection.material.color.set(LATENT_SHAPE_SPHERE_SELECTED_COLOR);

            // find the index of the selected sphere by comparing positions
            for (let i = 0; i < latent_shape.length; i++) {
                const vertex = latent_shape[i];
                if (
                    Math.abs(selection.position.x - vertex.position.x) <= 0.001
                    && Math.abs(selection.position.y - vertex.position.y) <= 0.001
                    && Math.abs(selection.position.z - vertex.position.z) <= 0.001
                ) {
                    selected_index = i;
                    break;
                }
            }

            if (0 <= selected_index && selected_index < latent_shape.length) {
                selected_indices.push(selected_index);
            }
        }
    }

    if (selected_shape.length > 0) {
        // remember the starting positions so offsets can be applied later
        first_selected_latent_shape = {
            position: selected_shape[0].position.clone(),
            relative_positions: selected_shape.map(shape => shape.position.clone())
        };
        transform_controls.attach(selected_shape[0]);
    } else {
        transform_controls.detach();
    }
}
SelectionBox & TransformControls
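The pixel-to-NDC conversion used for the selection box end point can be checked numerically. The sketch below restates the same formula in Python, the language used for the other added examples in this post; the function name and the 800×600 window size are illustrative.

```python
from typing import Tuple


def to_ndc(client_x: float, client_y: float, width: float, height: float) -> Tuple[float, float]:
    """Convert a pointer position in pixels to Normalized Device Coordinates."""
    x = (client_x / width) * 2.0 - 1.0
    # screen y grows downward while NDC y grows upward, hence the sign flip
    y = -(client_y / height) * 2.0 + 1.0
    return x, y


print(to_ndc(0, 0, 800, 600))      # (-1.0, 1.0)   top-left corner
print(to_ndc(400, 300, 800, 600))  # (0.0, 0.0)    center
print(to_ndc(800, 600, 800, 600))  # (1.0, -1.0)   bottom-right corner
```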
Because TransformControls in Three.js can only be attached to one object at a time, this implementation attaches the controls to the first selected object.
When the user moves this object, the same movement offset (dx, dy, dz) is applied to the other selected objects so that their relative positions are preserved.
This approach allows users to interactively manipulate multiple vertices at once, even though TransformControls is attached to only one object.
function on_object_change(event) {
    if (selected_shape.length === 0 || !first_selected_latent_shape) {
        return;
    }

    // offset of the controlled object relative to where it was at selection time
    const dx = selected_shape[0].position.x - first_selected_latent_shape.position.x;
    const dy = selected_shape[0].position.y - first_selected_latent_shape.position.y;
    const dz = selected_shape[0].position.z - first_selected_latent_shape.position.z;

    // apply the same offset to every other selected object
    for (let i = 1; i < selected_shape.length; i++) {
        const original_pos = first_selected_latent_shape.relative_positions[i];
        selected_shape[i].position.x = original_pos.x + dx;
        selected_shape[i].position.y = original_pos.y + dy;
        selected_shape[i].position.z = original_pos.z + dz;
    }

    const latent_shape_vertices = latent_shape.map(
        child => [child.position.x, child.position.y, child.position.z]
    );

    // rebuild the wireframe from the updated vertex positions
    scene.remove(latent_shape_wireframe);
    latent_shape_wireframe = create_mesh(
        latent_shape_vertices,
        latent_shape_faces,
        LATENT_SHAPE_WIREFRAME_MATERIAL,
    );
    scene.add(latent_shape_wireframe);
}
Updating LatentShapes
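The offset logic can be isolated from the rendering code. The following Python sketch, with a hypothetical function name and toy positions, computes the offset of the controlled vertex from its position at selection time and applies it to every selected vertex.

```python
from typing import List, Tuple

Position = Tuple[float, float, float]


def translate_selection(initial_positions: List[Position], moved_first: Position) -> List[Position]:
    """Apply the first vertex's movement offset to all selected vertices."""
    # offset of the controlled vertex relative to its position at selection time
    dx, dy, dz = (moved - start for moved, start in zip(moved_first, initial_positions[0]))
    return [(x + dx, y + dy, z + dz) for x, y, z in initial_positions]


initial = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
moved = translate_selection(initial, moved_first=(0.5, 0.25, -1.0))
print(moved)  # [(0.5, 0.25, -1.0), (1.5, 0.25, -1.0), (0.5, 2.25, -1.0)]
```

Computing the offset from the stored starting positions, rather than from the previous frame, means that repeated change events during one drag do not accumulate rounding drift.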
This project can be explored in detail in the linked repository.
I conclude this post by sharing a demo GIF illustrating the interactive latent shape manipulation.
Latent Shapes demo
References