latentspace.

Latent Shapes

Objectives

Latent vectors are typically initialized from a normal distribution and then updated during training to minimize the loss function. Although linear interpolation can be used between latent vectors, and similar vectors can be placed in spatially close locations, the latent vectors themselves do not have a specific geometric shape or structure.

This project starts from this observation and explores two main ideas:
  1. Latent vectors with geometric shape: exploring what happens if latent vectors are given a geometric structure, and how this can help us better understand, interpret, or control the results.
  2. Interactive manipulation using the latent vector: making it possible for users to modify latent vectors easily, so they can immediately see the effects and create new shapes or designs in an interactive way.

Latent Vector Manipulation in real-time with Mouse Dragging


Data Preprocessing

This project uses the chair dataset (03001627) from ShapeNetCore. The original data is available here. First, I examine the meshes in ShapeNetCore. The dataset contains 6778 meshes. All of them are already scaled to small values, but they are not zero-centered, and some are not aligned with the axes of their axis-aligned bounding box (AABB). In addition, the DeepSDF model architecture is used as the mesh reconstruction network, so all meshes must also be watertight. Therefore, every mesh goes through the following steps before training.
  1. watertight: each mesh is made watertight so that SDF values can be computed correctly.
  2. centralize: each mesh is translated so that it is zero-centered.
  3. orient: each mesh is aligned to the AABB axes.
  4. compute sdf values: SDF values are computed from the processed mesh and used to train the mesh reconstruction model.
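
As a rough illustration of the centralize step, the sketch below translates a list of vertices so that the center of their axis-aligned bounding box sits at the origin. The function name `center_vertices` is hypothetical, and centering on the AABB center (rather than, for example, the vertex mean) is an assumption; the project's own processing code may do this differently, for instance through a mesh library.

```python
from typing import List, Tuple

Vertex = Tuple[float, float, float]


def center_vertices(vertices: List[Vertex]) -> List[Vertex]:
    """Translate vertices so the AABB center lies at the origin (assumed convention)."""
    mins = [min(v[axis] for v in vertices) for axis in range(3)]
    maxs = [max(v[axis] for v in vertices) for axis in range(3)]
    center = [(lo + hi) / 2.0 for lo, hi in zip(mins, maxs)]
    return [tuple(v[axis] - center[axis] for axis in range(3)) for v in vertices]


# two corner points of a small box; after centering they are symmetric around 0
centered = center_vertices([(1.0, 2.0, 3.0), (3.0, 6.0, 5.0)])
```

After this step the bounding box is symmetric around zero on every axis, which keeps the later bounding box construction simple.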


Data processing steps of 1007e20d5e811b308351982a6e40cf41


Now, as the final step, the latent shapes are computed. The process consists of the following steps:
  1. create a bounding box of the mesh;
  2. apply a scale matrix to the bounding box;
  3. subdivide the bounding box n times;
  4. move each bounding box vertex onto the mesh: for the top and bottom faces of the bounding box, cast a ray along the negative normal vector and use its intersection point with the mesh; for all other vertices, use the nearest point on the mesh.
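
The subdivide and snap steps can be sketched in plain Python as below. This is a simplified illustration, not the project's implementation: `box_surface_vertices` and `snap_to_nearest` are hypothetical names, the box is a unit cube, ray casting for the top and bottom faces is omitted, and the nearest point on the mesh is approximated by the nearest of a list of sample points. It also assumes that each subdivision splits every edge in two. Under that assumption n = 2 yields 98 surface vertices, which matches the 98 vertices of the latent shape used later in this post, but that correspondence is an inference rather than a confirmed setting.

```python
from itertools import product
from typing import List, Tuple

Point = Tuple[float, float, float]


def box_surface_vertices(segments: int) -> List[Point]:
    """Grid points on the surface of the unit cube [-0.5, 0.5]^3."""
    steps = [i / segments - 0.5 for i in range(segments + 1)]
    return [
        (steps[i], steps[j], steps[k])
        for i, j, k in product(range(segments + 1), repeat=3)
        if 0 in (i, j, k) or segments in (i, j, k)  # keep only points on a face
    ]


def snap_to_nearest(vertices: List[Point], samples: List[Point]) -> List[Point]:
    """Move each vertex to its nearest sample (stand-in for the nearest point on the mesh)."""
    def squared_distance(a: Point, b: Point) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [min(samples, key=lambda s: squared_distance(v, s)) for v in vertices]


n_subdivisions = 2  # assumed: every subdivision halves each edge
cage = box_surface_vertices(2 ** n_subdivisions)
print(len(cage))  # 98 = 5^3 - 3^3
```

Because the vertex order and the face indices of the box do not change when vertices are moved, the same faces can be reused for every chair.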

After all processing steps, the data is saved in *.npz format with the following data: xyz coordinates, sdf values, latent shape, faces (latent shape face indices), and class number (shape embedding number). The code for processing the data is available here.

Latent shape computation of 100b18376b885f206ae9ad7e32c4139d


Model Architecture

The training target to optimize consists of the SDFDecoder and the LatentShapes. The SDFDecoder follows the architecture used in DeepSDF, where the initial input data is skip-connected to each block input. The number of blocks and the number of layers in each block can be adjusted in config.py.

The LatentShapes class is defined as follows. The embedding attribute of LatentShapes is initialized with pre-computed latent shapes with added noise (in this project, the noise values are set to noise_min = -0.1 and noise_max = 0.1). In the forward pass, it returns the embeddings corresponding to the given classes.

    from typing import Optional

    import torch
    import torch.nn as nn


    class LatentShapes(nn.Module):
        def __init__(
            self, latent_shapes: torch.Tensor, noise_min: Optional[float] = None, noise_max: Optional[float] = None
        ):
            super().__init__()

            self.noise = torch.zeros_like(latent_shapes)
            if None not in (noise_min, noise_max):
                # add noise to latent shapes
                self.noise = noise_min + torch.rand_like(latent_shapes) * (noise_max - noise_min)

                # check if the noise range is valid
                assert torch.all(self.noise >= noise_min)
                assert torch.all(self.noise <= noise_max)

            # initialize latent shapes with noise to multi-vector embeddings
            self.embedding = nn.Parameter(latent_shapes + self.noise)

        def forward(self, class_number: torch.Tensor) -> torch.Tensor:
            return self.embedding[class_number]
LatentShapes with multi-vector embeddings


The diagram below shows the overall training process of the models. Each point at the upper left in the figure represents a 3D \(\text{xyz (1, 3)}\) coordinate used as input. This coordinate is combined with the \(\text{latent shape (1, 98×3)}\), so the initial input for the forward pass is a \(\text{(1, 297)}\) vector. The SDFDecoder is trained on this combined input to predict an \(\text{sdf}\) value.

Training process


The latent shape at the lower left of the figure shows that the \(\text{latent shape}\) embedding is initialized with noise, as described above. During training, the latent shapes are gradually updated toward the original shapes. Because the embedding is a trainable parameter rather than a fixed input, the representation stays flexible, and the decoder can reconstruct the mesh even when a latent shape does not perfectly match its original.

Lastly, there are two losses used for optimization. The first is \( \mathbf{Loss_{sdf}} \), the L1 difference between the clamped predicted value and the clamped ground truth value. The second is \( \mathbf{Loss_{shape}} \), a mean squared error that guides the embedding toward the original latent shape. Through these losses, the LatentShapes and the SDFDecoder are optimized at the same time.
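
Written out for a batch of \(B\) samples, and following the training code shown later in this post (an L1 loss on clamped SDF values and a mean squared error over all latent shape coordinates), the two losses can be sketched as below. The symbols are notation introduced here for illustration: \(f_\theta\) is the SDFDecoder, \(x_i\) an xyz coordinate, \(s_i\) its ground truth SDF value, \(z_{c_i}\) the learned latent shape of class \(c_i\), \(\hat{z}_{c_i}\) the pre-computed latent shape, and \(\delta\) the clamp value, with \(\operatorname{clamp}(v, \delta) = \min(\max(v, -\delta), \delta)\).

```latex
\[
\mathbf{Loss_{sdf}} = \frac{1}{B} \sum_{i=1}^{B}
  \left| \operatorname{clamp}\big(f_\theta([x_i, z_{c_i}]), \delta\big) - \operatorname{clamp}(s_i, \delta) \right|
\]
\[
\mathbf{Loss_{shape}} = \frac{1}{B \cdot 98 \cdot 3} \sum_{i=1}^{B}
  \left\| z_{c_i} - \hat{z}_{c_i} \right\|_2^2
\]
```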


Training & Memory Strategy

In this project, an NVIDIA GeForce RTX 4090 is used to train the models for 120 epochs on only \(50\) chairs. Each chair has \(36^3\) xyz coordinates, so the total number of xyz coordinates is \(36^3 \times 50 = 2{,}332{,}800\). Since the GPU does not have enough memory to hold all of this data at once (it runs out of memory, OOM), each sample is loaded only when it is needed.

The following code shows the implementation. First, the cumulative_length is pre-computed at initialization time based on the total number of samples per data point. Then, given an index, the corresponding file is loaded, and the index is mapped to the correct sample within that file using modulo with the total sampling number.

    class SDFDataset(Dataset):

        (...)

        def __len__(self) -> int:
            return self.total_length

        def __getitem__(self, _idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            for file_idx, cumulative_length in enumerate(self.cumulative_length):
                if _idx < cumulative_length:
                    file_idx -= 1
                    break

            data = np.load(self.data_path[file_idx])

            idx = _idx % self.configuration.N_TOTAL_SAMPLING

            xyz = torch.tensor(data["xyz"][idx], dtype=torch.float32)
            sdf = torch.tensor(data["sdf"][idx], dtype=torch.float32)
            class_number = torch.tensor(data["class_number"], dtype=torch.long)
            latent_shape = torch.tensor(data["latent_shape"], dtype=torch.float32)
            faces = torch.tensor(data["faces"], dtype=torch.long)

            xyz = xyz.to(self.configuration.DEVICE)
            sdf = sdf.to(self.configuration.DEVICE)
            class_number = class_number.to(self.configuration.DEVICE)
            latent_shape = latent_shape.to(self.configuration.DEVICE)
            faces = faces.to(self.configuration.DEVICE)

            return xyz, sdf, class_number, latent_shape, faces
Dataset class
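
The index lookup can also be expressed with a binary search instead of a linear scan. The sketch below is one possible variant, not the project's code: `locate_sample` is a hypothetical helper, and it assumes that the cumulative lengths start at 0 (for example `[0, N, 2N, ...]`), which is consistent with the `file_idx -= 1` adjustment in the loop above. Subtracting the file's starting offset gives the same result as the modulo when every file holds the same number of samples.

```python
from bisect import bisect_right
from typing import List, Tuple


def locate_sample(global_index: int, cumulative_length: List[int]) -> Tuple[int, int]:
    """Map a global sample index to (file_index, index_within_file).

    cumulative_length is assumed to start at 0, e.g. [0, N, 2N, ...].
    """
    file_index = bisect_right(cumulative_length, global_index) - 1
    local_index = global_index - cumulative_length[file_index]
    return file_index, local_index


samples_per_file = 36 ** 3  # 46656 xyz coordinates per chair
cumulative = [i * samples_per_file for i in range(51)]  # offsets for 50 files

first_of_second_file = locate_sample(samples_per_file, cumulative)
last_sample = locate_sample(50 * samples_per_file - 1, cumulative)
```

With 50 files the difference in speed is negligible, so this is mainly a readability choice.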


Below is the code used for training at each epoch. The variable cxyz represents the combined input described in the Architecture section above. The SDFDecoder model takes this combined input and predicts sdf values in batches. The model's weights are updated based on the loss that corresponds to \( \mathbf{Loss_{sdf}} \). Since the LatentShapes are initialized with noise, the loss_shape that corresponds to \( \mathbf{Loss_{shape}} \) is used to update them so that they gradually become closer to the original latent shapes.

    def _train_each_epoch(self) -> Tuple[float, float]:
        """training method for each epoch

        Returns:
            Tuple[float, float]: training loss for the decoder, and the latent shape
        """

        losses = []
        losses_shape = []

        iterator_train = tqdm(
            enumerate(self.sdf_dataset.train_dataloader), total=len(self.sdf_dataset.train_dataloader)
        )

        for batch_index, data in iterator_train:
            xyz_batch, sdf_batch, class_number_batch, latent_shapes_batch_r, _ = data

            latent_shapes_batch = self.latent_shapes(class_number_batch)
            latent_shapes_batch = latent_shapes_batch.reshape(latent_shapes_batch.shape[0], -1)

            cxyz = torch.cat((xyz_batch, latent_shapes_batch), dim=1)

            sdf_preds = self.sdf_decoder(cxyz)

            sdf_preds = torch.clamp(sdf_preds, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)
            sdf_batch = torch.clamp(sdf_batch, min=-self.configuration.CLAMP, max=self.configuration.CLAMP)

            loss_shape = torch.nn.functional.mse_loss(self.latent_shapes(class_number_batch), latent_shapes_batch_r)

            loss_shape.backward()
            self.latent_shapes_optimizer.step()
            self.latent_shapes_optimizer.zero_grad()

            loss = torch.nn.functional.l1_loss(sdf_preds, sdf_batch.unsqueeze(-1))
            loss = loss / self.configuration.ACCUMULATION_STEPS

            loss.backward()
            if (batch_index + 1) % self.configuration.ACCUMULATION_STEPS == 0 or (batch_index + 1) == len(
                self.sdf_dataset.train_dataloader
            ):
                self.sdf_decoder_optimizer.step()
                self.sdf_decoder_optimizer.zero_grad()

            losses.append(loss.item() * self.configuration.ACCUMULATION_STEPS)
            losses_shape.append(loss_shape.item())

        loss_mean = torch.tensor(losses).mean().item()
        loss_shape_mean = torch.tensor(losses_shape).mean().item()

        return loss_mean, loss_shape_mean
Training iteration
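
To make the gradient accumulation condition in the loop above easier to follow, the sketch below lists the batch indices after which the decoder optimizer would step. `optimizer_step_batches` is a hypothetical helper written only for illustration, and the batch count and accumulation value are made-up numbers rather than the project's configuration.

```python
from typing import List


def optimizer_step_batches(n_batches: int, accumulation_steps: int) -> List[int]:
    """Return the batch indices after which the decoder optimizer would step."""
    return [
        batch_index
        for batch_index in range(n_batches)
        # step every accumulation_steps batches, and once more on the final batch
        if (batch_index + 1) % accumulation_steps == 0 or (batch_index + 1) == n_batches
    ]


steps = optimizer_step_batches(n_batches=10, accumulation_steps=4)
print(steps)  # [3, 7, 9]
```

The second condition makes sure that the gradients of a trailing, incomplete group of batches are still applied at the end of the epoch.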


Whether to save the decoder model's state is decided by a weighted sum of the training loss and the validation loss. The weights are set to 0.1 and 1.0, respectively, so that checkpoints are chosen mainly for validation (generalization) performance, while the small weight on the training loss still favors models that reconstruct the training data well.
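
A minimal sketch of this criterion is shown below. The function names are hypothetical, and saving whenever the score improves on the best score seen so far is an assumption about how the weighted sum is used; only the two weights come from the description above.

```python
def checkpoint_score(
    train_loss: float, val_loss: float, train_weight: float = 0.1, val_weight: float = 1.0
) -> float:
    """Weighted sum of training and validation loss (lower is better)."""
    return train_weight * train_loss + val_weight * val_loss


def should_save(train_loss: float, val_loss: float, best_score: float) -> bool:
    """Assumed rule: save only when the score beats the best one seen so far."""
    return checkpoint_score(train_loss, val_loss) < best_score


score = checkpoint_score(train_loss=0.02, val_loss=0.03)
```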

The following are the loss graphs per epoch. All four losses (loss_mean, loss_mean_val, loss_shape_mean, and loss_shape_mean_val) converge stably. In the next section, the reconstruction performance is evaluated qualitatively.
Losses for 120 epochs


Reconstruction & Linear Interpolation

The following figure shows the results of reconstructing 50 trained chairs using latent shape embeddings and the decoder (the marching cubes algorithm was applied at a resolution of \(192^3\)). Details such as the legs and armrests of the chairs are properly reconstructed.
Reconstructed chairs at 192 resolution


Next, I examine several results generated by ratio-based linear interpolation between trained shape embeddings. The data on the left and right represent the input shapes to be interpolated, while the center shows the interpolated result according to the specified ratios. The following ratios were used: \( \text{[[0.50, 0.50], [0.30, 0.70], [0.60, 0.40], [0.45, 0.55]]} \). Although this method works well for interpolation between shapes, it is still difficult to predict the specific changes when only part of the latent vector is modified. In particular, it is not easy to determine which components of the vector correspond to specific geometric parts.
Interpolated chairs
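
The ratio-based interpolation itself is a weighted sum of two latent shapes, vertex by vertex. The sketch below shows the idea on nested lists; `interpolate_shapes` is a hypothetical name, and which ratio in each pair belongs to the left or the right shape is an assumption, since the post only lists the ratio pairs.

```python
from typing import List

Shape = List[List[float]]  # one [x, y, z] row per latent shape vertex


def interpolate_shapes(shape_a: Shape, shape_b: Shape, ratio_a: float, ratio_b: float) -> Shape:
    """Blend two latent shapes vertex by vertex with the given ratios."""
    return [
        [ratio_a * a + ratio_b * b for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(shape_a, shape_b)
    ]


# a single-vertex example with the [0.50, 0.50] ratio pair
blended = interpolate_shapes([[0.0, 0.0, 0.0]], [[1.0, 2.0, 4.0]], 0.5, 0.5)
```

The blended latent shape is then passed to the decoder in the same way as a trained embedding.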


Instead of relying on these static interpolation methods, the next section examines the core implementation of the interface that allows interactive manipulation of the latent vector.


Interactive Manipulation UI

Here, I briefly describe the implementation of the demo GIF shown in the introduction section. First, FastAPI is used as the communication layer connecting the interface with the model. The API returns results in JSON format, where both the reconstruction outputs and the latent shape data are represented as mesh data based on vertices and faces.

The reconstruct API request is defined as follows. Marching Cubes evaluates the SDF on a regular grid, so the number of decoder queries grows with the cube of the resolution, and there is a trade-off between reconstruction quality and reconstruction speed. Therefore, the resolution is exposed as a user-adjustable parameter to balance fidelity and performance.

    class ReconstructRequest(BaseModel):
        latent_shapes: List[List[float]]
        rescale: bool
        map_z_to_y: bool
        ensure_watertight: bool
        resolution: int

    (...)
    
    @app.post("/api/reconstruct")
    def reconstruct(request: ReconstructRequest):
        try:
            configuration.RECONSTRUCTION_GRID_SIZE = request.resolution

            latent_shapes_tensor = torch.tensor(request.latent_shapes).to(configuration.DEVICE)

            # map z to y to match the loaded latent shape into the xyz system
            latent_shapes_tensor[:, [1, 2]] = latent_shapes_tensor[:, [2, 1]]

            reconstruction_results = sdf_decoder.reconstruct(
                latent_shapes=latent_shapes_tensor.unsqueeze(0),
                save_path=os.path.join(os.path.dirname(__file__)),
                check_watertight=request.ensure_watertight,
                map_z_to_y=request.map_z_to_y,
                add_noise=False,
                rescale=request.rescale,
            )

            if reconstruction_results[0] is None:
                raise HTTPException(status_code=400, detail="Reconstruction failed")

            # Extract mesh data from the first result
            mesh = reconstruction_results[0]

            vertices = mesh.vertices.tolist()
            faces = mesh.faces.tolist()
            edges = mesh.edges.tolist()

            return {
                "message": "Reconstruction successful",
                "vertices": vertices,
                "faces": faces,
                "edges": edges,
            }

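        except HTTPException:
            # re-raise the explicit failure above without wrapping it a second time
            raise
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Reconstruction failed: {str(e)}")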
Reconstruction API
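
From the client's side, a request body for this endpoint could be assembled as in the sketch below. The helper names are hypothetical and the boolean values are purely illustrative; only the field names come from `ReconstructRequest`. The `swap_y_and_z` helper mirrors, on plain lists, the column swap that the endpoint applies to its tensor.

```python
import json
from typing import List


def swap_y_and_z(latent_shape: List[List[float]]) -> List[List[float]]:
    """Return a copy with columns 1 and 2 swapped, as the endpoint does on its tensor."""
    return [[x, z, y] for x, y, z in latent_shape]


def build_reconstruct_payload(latent_shape: List[List[float]], resolution: int) -> str:
    """Serialize a request body with the fields of ReconstructRequest."""
    return json.dumps(
        {
            "latent_shapes": latent_shape,
            "rescale": True,  # illustrative value
            "map_z_to_y": True,  # illustrative value
            "ensure_watertight": False,  # illustrative value
            "resolution": resolution,
        }
    )


body = build_reconstruct_payload([[0.1, 0.2, 0.3]], resolution=64)
swapped = swap_y_and_z([[1.0, 2.0, 3.0]])
```

A lower resolution such as 64 is a reasonable choice while dragging, with a higher value reserved for a final reconstruction.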


The last part is the Three.js-based interface. The latent shapes are generated by shrink-wrapping a subdivided mesh box for the chair data, resulting in a mesh that can be directly rendered on a Three.js canvas. The main objective is to allow users to manipulate the vertices of this mesh, which gives an intuitive, interactive way to explore interpolations.

Three.js SelectionBox and TransformControls are the main modules used in this implementation, enabling precise vertex selection and transformation within the interface.

The key logic of the SelectionBox lies in converting the browser's mouse coordinates to Normalized Device Coordinates (NDC). The browser reports the pointer position in pixels (clientX, clientY); dividing by the window width and height maps it to the 0 to 1 range, whereas Three.js requires coordinates in the -1 to 1 range for Raycaster or SelectionBox operations. In the browser, the y-coordinate increases toward the bottom of the screen, whereas in NDC the y-coordinate increases toward the top. Therefore, the sign of the y-coordinate must be flipped during the conversion.


    function on_pointerup(event) {
        is_mouse_down = false;

        if (!is_selection_mode) {
            return;
        }

        selection_box.endPoint.set(
            (event.clientX / window.innerWidth) * 2 - 1,
            -(event.clientY / window.innerHeight) * 2 + 1,
            0.5
        );

        const selections = selection_box.select();

        for (const selected_index of selected_indices) {
            latent_shape[selected_index].material.color.set(LATENT_SHAPE_SPHERE_COLOR);
        }

        selected_shape = [];
        selected_indices = [];

        for (const selection of selections) {
            if (selection.name === LATENT_SHAPE) {
                let selected_index = -1;
                selected_shape.push(selection);
                selection.material.color.set(LATENT_SHAPE_SPHERE_SELECTED_COLOR);
                for (let i = 0; i < latent_shape.length; i++) {
                    const vertex = latent_shape[i];
                    if (
                        Math.abs(selection.position.x - vertex.position.x) <= 0.001
                        && Math.abs(selection.position.y - vertex.position.y) <= 0.001
                        && Math.abs(selection.position.z - vertex.position.z) <= 0.001
                    ) {
                        selected_index = i;
                        break;
                    }
                }

                if (0 <= selected_index && selected_index < latent_shape.length) {
                    selected_indices.push(selected_index);
                }
            }
        }

        if (selected_shape.length > 0) {
            first_selected_latent_shape = {
                position: selected_shape[0].position.clone(),
                relative_positions: selected_shape.map(shape => shape.position.clone())
            };
            transform_controls.attach(selected_shape[0]);
        } else {
            transform_controls.detach();
        }
    }
SelectionBox & TransformControls
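
The pixel-to-NDC conversion used for `selection_box.endPoint` can be checked in isolation. The sketch below restates the same arithmetic in Python rather than JavaScript, purely so that the corner cases are easy to verify; `pixel_to_ndc` is a hypothetical helper and the 800 by 600 window is an arbitrary example size.

```python
from typing import Tuple


def pixel_to_ndc(client_x: float, client_y: float, width: float, height: float) -> Tuple[float, float]:
    """Convert pixel coordinates (origin top-left, y down) to NDC (origin center, y up)."""
    ndc_x = (client_x / width) * 2 - 1
    ndc_y = -(client_y / height) * 2 + 1  # flip the sign because y points up in NDC
    return ndc_x, ndc_y


top_left = pixel_to_ndc(0, 0, 800, 600)
center = pixel_to_ndc(400, 300, 800, 600)
bottom_right = pixel_to_ndc(800, 600, 800, 600)
```

The top-left pixel maps to (-1, 1), the window center to (0, 0), and the bottom-right pixel to (1, -1).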


Because TransformControls in Three.js can only be attached to one object at a time, this implementation attaches the controls to the first selected object. When the user moves this object, the same movement offset (dx, dy, dz) is applied to the other selected objects so that their relative positions are preserved. This approach allows users to interactively manipulate multiple vertices at once, even though TransformControls is attached to only one object.


    function on_object_change(event) {
        if (selected_shape.length === 0 || !first_selected_latent_shape) {
            return;
        }

        const dx = selected_shape[0].position.x - first_selected_latent_shape.position.x;
        const dy = selected_shape[0].position.y - first_selected_latent_shape.position.y;
        const dz = selected_shape[0].position.z - first_selected_latent_shape.position.z;

        for (let i = 1; i < selected_shape.length; i++) {
            const original_pos = first_selected_latent_shape.relative_positions[i];
            selected_shape[i].position.x = original_pos.x + dx;
            selected_shape[i].position.y = original_pos.y + dy;
            selected_shape[i].position.z = original_pos.z + dz;
        }

        const latent_shape_vertices = latent_shape.map(
            child => [child.position.x, child.position.y, child.position.z]
        );

        scene.remove(latent_shape_wireframe);

        latent_shape_wireframe = create_mesh(
            latent_shape_vertices,
            latent_shape_faces,
            LATENT_SHAPE_WIREFRAME_MATERIAL,
        );

        scene.add(latent_shape_wireframe);
    }
Updating LatentShapes
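
The offset logic is independent of Three.js and can be sketched on plain tuples. In the illustration below, written in Python rather than JavaScript, `apply_group_offset` is a hypothetical helper: it takes the positions recorded when the selection was made and the current position of the first (dragged) vertex, and returns the new positions for the whole selection.

```python
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


def apply_group_offset(start_positions: List[Vec3], dragged_position: Vec3) -> List[Vec3]:
    """Move every selected vertex by the offset of the first (dragged) one."""
    first = start_positions[0]
    offset = tuple(d - f for d, f in zip(dragged_position, first))
    return [tuple(p + o for p, o in zip(position, offset)) for position in start_positions]


moved = apply_group_offset([(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)], (0.5, 0.0, -0.5))
```

Computing the offset from the recorded start positions, rather than from the previous frame, avoids accumulating rounding drift during a long drag.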


This project can be explored in detail in the linked repository. I conclude this post by sharing a demo GIF illustrating the interactive latent shape manipulation.

Latent Shapes demo


References