parkcheolhee-lab

Polygon segmentations

Objective ✂️

In the early stages of architectural design, the architect decides on the axes along which the buildings will be placed. This decision plays an important role in determining the optimal layout, considering the functionality, aesthetics, and environmental conditions of the building. The goal of this project is to develop a segmenter that can determine how many axes a given 2D polygon should be segmented into and how to make those segmentations.

Segmented lands by a human architect


The figures above show imaginary dividing lines drawn by the architect. The apartments in the figures are placed based on the segmented polygons. Human architects intuitively know how many axes a given 2D polygon should be segmented into and how to create these segmentations, but explaining this intuition to a computer is difficult.

To achieve this, I will use a combination of deep learning and graph theory. In the graph, each point of the polygon will be a node and the connections between points will be edges. Based on this concept, I will implement a GNN-based model, which will learn how to optimally segment given polygons.


A simple understanding of graph and GNN

Before generating data, let's understand graphs and Graph Neural Networks. A graph is defined as \( G = (V, E) \). In this expression, \( V \) is the set of vertices (nodes) and \( E \) is the set of edges.

Graphs are mainly expressed as an adjacency matrix. When the number of nodes is \( n \), the size of the adjacency matrix \( A \) is \( n \times n \). When dealing with a graph in machine learning, the graph also carries a feature matrix describing the characteristics of the nodes. When the number of features is \( f \), the dimension of the feature matrix \( X \) is \( n \times f \).
Understanding of the graph expression
In this figure, \( n = 4 \) and \( f = 3 \), so \( A \) has size \( n \times n = 4 \times 4 \) and \( X \) has size \( n \times f = 4 \times 3 \).
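
As a minimal sketch of these two matrices, the snippet below builds \( A \) and \( X \) for a hypothetical unit square whose four vertices are connected in a ring. The three features per node (x, y, and node degree) are placeholders chosen only to give \( f = 3 \); they are not the features used in this project.

```python
# Hypothetical polygon: a unit square, so n = 4 nodes.
vertices = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
n = len(vertices)


def build_adjacency(num_nodes, edges):
    """Return a symmetric n x n adjacency matrix as nested lists."""
    matrix = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in edges:
        matrix[i][j] = 1
        matrix[j][i] = 1
    return matrix


# Consecutive polygon vertices are connected, and the last connects back to the first.
ring_edges = [(i, (i + 1) % n) for i in range(n)]
A = build_adjacency(n, ring_edges)

# Illustrative feature matrix with f = 3 columns: x, y, node degree.
X = [[x, y, float(sum(A[i]))] for i, (x, y) in enumerate(vertices)]

for row in A:
    print(row)
for row in X:
    print(row)
```

Each row of `A` marks the neighbors of one vertex, and each row of `X` holds the features of the same vertex, so the two matrices share their row order.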


Graphs are used in various fields to represent data, and they are useful when teaching geometry to deep learning models for the following reasons:

A Graph Neural Network (GNN) is a type of neural network designed to operate on graph structures. Unlike traditional neural networks, which work on fixed-size inputs like vectors or matrices, GNNs can handle graph-structured data with variable size and connectivity. This makes GNNs particularly suitable for tasks where relationships are important, such as social networks and geometries. A GNN primarily uses the connections and the states of neighboring nodes to update (learn) the state of each node (Message Passing). Predictions are then made based on the final states. The final state is generally referred to as the node embedding (I think "encoding" is also a fair name, because the raw node features are transformed into another representation).

There are various methods for Message Passing, but since this task deals with geometry, let's focus on models based on the Spatial Convolutional Network. This method is known to be suitable for representing data with important geometric and spatial structures. GNNs based on spatial convolution enable each node to integrate information from its neighboring nodes, allowing for a more accurate understanding of local characteristics. Through this, the model can better comprehend the complex shapes and features of geometry.
Convolution operations
From the left, 2D Convolution · Graph Convolution


The idea of a Spatial Convolutional Network (SCN) is similar to that of a Convolutional Neural Network (CNN). An image can be transformed into a grid-shaped graph.

CNN processes images by using filters to combine the surrounding pixels of a central pixel. SCN extends this idea to graph structures by combining the features of neighboring connected nodes instead of neighboring pixels. Specifically, CNNs are useful for processing images in a regular grid structure, where the filter considers the surrounding area of each pixel to extract features. In contrast, SCNs operate on general graph structures, combining the features of each node with those of its connected neighbors to generate embeddings.
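
To make the neighbor-combining step concrete, here is a minimal sketch of one message-passing step using plain mean aggregation. It is a simplification under my own assumptions: a real graph convolution layer also applies a learned weight matrix and a nonlinearity, and the function name `mean_aggregate` is hypothetical.

```python
def mean_aggregate(adjacency, features):
    """One message-passing step: each node averages its own features
    with the features of its directly connected neighbors."""
    num_nodes = len(adjacency)
    num_features = len(features[0])
    updated = []
    for i in range(num_nodes):
        # The node itself plus every neighbor marked in the adjacency matrix.
        group = [i] + [j for j in range(num_nodes) if adjacency[i][j]]
        updated.append(
            [sum(features[j][c] for j in group) / len(group) for c in range(num_features)]
        )
    return updated


# Path graph 0 - 1 - 2 with a single feature per node.
path_adjacency = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
path_features = [[0.0], [3.0], [6.0]]
smoothed = mean_aggregate(path_adjacency, path_features)
print(smoothed)
```

Stacking several such steps lets information travel further along the polygon boundary, which is how a node can come to reflect more than its immediate neighborhood.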


Data preparation

Now that I have briefly covered graphs and GNNs, let's prepare the data! There are already many raw polygons around us: land parcels. First, let's collect raw polygons from vworld. Below is a part of the raw polygons I collected from vworld.
Somewhere in Seoul.shp


Now that we have collected the raw polygons, let's define the characteristics of the polygons to be included in the feature matrix. They are as follows:

Then, I need to convert these data into graph form. Here is an example of hand labeling:
Land polygon in Gangbukgu, Seoul
From the left, raw polygon & labeled link · adjacency matrix · feature matrix


However, since it is impossible to label countless raw polygons by hand, I need to generate this data automatically. It would be nice if it could be fully automated with an algorithm, but if that were possible, there wouldn't be a need for deep learning 🤔. So, I created a naive algorithm that reduces the manual work at least a little. This algorithm is inspired by polygon triangulation. The process of this algorithm is as follows:

The scores (even_area_score, ombr_ratio_score, slope_similarity_score) used in the fifth step are computed as follows. Each score is multiplied by its weight, the weighted scores are summed, and the segmentation with the lowest total score is selected.

\[ even\_area\_score = \frac{(A_1 - \sum_{i=2}^{n} A_i)}{A_{polygon}} \times {w_1} \]


\[ ombr\_ratio\_score = \left| \left(1 - \frac{A_{split_1}}{A_{ombr_1}}\right) - \sum_{i=2}^{n} \left(1 - \frac{A_{split_i}}{A_{ombr_i}}\right) \right| \times {w_2} \]


\[ avg\_slope\_similarity_i = \frac{\sum_{j=1}^{k_i} |\text{slope}_j - \text{slope}_{\text{main}}|}{k_i} \]


\[ slope\_similarity\_score = \frac{\sum_{i=1}^{n} avg\_slope\_similarity_i}{n} \times {w_3} \]



\[ score = even\_area\_score + ombr\_ratio\_score + slope\_similarity\_score \]
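
As a worked example of these formulas, the sketch below computes the total score from precomputed numbers instead of real geometry. The function name `compute_score`, the dictionary keys, and the sample values are all hypothetical; \( A_1 \) is taken as the largest split area, which is how the author's implementation sorts the areas.

```python
def compute_score(splits, polygon_area, w1, w2, w3):
    """Weighted sum of the three scores for one candidate segmentation.

    Each split is a dict with 'area', 'ombr_area' (area of its oriented
    minimum bounding rectangle) and 'avg_slope_similarity'.
    """
    # even_area_score: largest area minus the sum of the remaining areas.
    areas = sorted((s["area"] for s in splits), reverse=True)
    even_area_score = (areas[0] - sum(areas[1:])) / polygon_area * w1

    # ombr_ratio_score: 1 - area / ombr_area measures how far a split is from a rectangle.
    inverted = [1 - s["area"] / s["ombr_area"] for s in splits]
    ombr_ratio_score = abs(inverted[0] - sum(inverted[1:])) * w2

    # slope_similarity_score: mean of the per-split average slope differences.
    slope_similarity_score = (
        sum(s["avg_slope_similarity"] for s in splits) / len(splits) * w3
    )
    return even_area_score + ombr_ratio_score + slope_similarity_score


example_splits = [
    {"area": 60.0, "ombr_area": 80.0, "avg_slope_similarity": 0.2},
    {"area": 40.0, "ombr_area": 50.0, "avg_slope_similarity": 0.4},
]
total = compute_score(example_splits, polygon_area=100.0, w1=1.0, w2=1.0, w3=1.0)
print(round(total, 6))
```

With all weights set to 1, the three terms are about 0.2, 0.05, and 0.3, so the total is about 0.55. Raising one weight makes the search favor candidates that do well on that particular criterion.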



The whole code for this algorithm can be found here, and the key part of the algorithm is as follows.

    def segment_polygon(
        self,
        polygon: Polygon,
        number_to_split: int,
        segment_threshold_length: float,
        even_area_weight: float,
        ombr_ratio_weight: float,
        slope_similarity_weight: float,
        return_splitters_only: bool = True,
    ):
        """Segment a given polygon

        Args:
            polygon (Polygon): polygon to segment
            number_to_split (int): number of splits to segment
            segment_threshold_length (float): segment threshold length
            even_area_weight (float): even area weight
            ombr_ratio_weight (float): ombr ratio weight
            slope_similarity_weight (float): slope similarity weight
            return_splitters_only (bool, optional): return splitters only. Defaults to True.

        Returns:
            The selected splitters only if return_splitters_only is True; otherwise a tuple of
            (splits, triangulations edges, splitters). The selected values are None when no
            candidate passes the filters.
        """

        _, triangulations_edges = self.triangulate_polygon(
            polygon=polygon,
            segment_threshold_length=segment_threshold_length,
        )

        splitters_selceted = None
        splits_selected = None
        splits_score = None

        for splitters in list(itertools.combinations(triangulations_edges, number_to_split - 1)):
            exterior_with_splitters = ops.unary_union(list(splitters) + self.explode_polygon(polygon))

            exterior_with_splitters = shapely.set_precision(
                exterior_with_splitters, self.TOLERANCE_LARGE, mode="valid_output"
            )

            exterior_with_splitters = ops.unary_union(exterior_with_splitters)

            splits = list(ops.polygonize(exterior_with_splitters))

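            # Skip candidates that do not yield exactly the requested number of pieces
            if len(splits) != number_to_split:
                continue

            # Skip candidates containing a piece smaller than 25% of the original area
            if any(split.area < polygon.area * 0.25 for split in splits):
                continue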

            is_acute_angle_in = False
            is_triangle_shape_in = False
            for split in splits:
                split_segments = self.explode_polygon(split)
                splitter_indices = []

                for ssi, split_segment in enumerate(split_segments):
                    if split_segment.length <= self.TOLERANCE_LARGE * 2:
                        continue

                    reduced_split_segment = DataCreatorHelper.extend_linestring(
                        split_segment, -self.TOLERANCE_LARGE, -self.TOLERANCE_LARGE
                    )
                    buffered_split_segment = reduced_split_segment.buffer(self.TOLERANCE, cap_style=CAP_STYLE.flat)

                    if buffered_split_segment.intersects(MultiLineString(splitters)):
                        splitter_indices.append(ssi)
                        splitter_indices.append(ssi + 1)

                degrees = self.compute_polyon_inner_degrees(split)
                degrees += [degrees[0]]

                if (np.array(degrees)[splitter_indices] < 20).sum():
                    is_acute_angle_in = True
                    break

                if len(self.explode_polygon(self.simplify_polygon(split))) == 3:
                    is_triangle_shape_in = True
                    break

            if is_acute_angle_in or is_triangle_shape_in:
                continue

            sorted_splits_area = sorted([split.area for split in splits], reverse=True)
            even_area_score = (sorted_splits_area[0] - sum(sorted_splits_area[1:])) / polygon.area * even_area_weight

            ombr_ratio_scores = []
            slope_similarity_scores = []

            for split in splits:
                ombr = split.minimum_rotated_rectangle
                each_ombr_ratio = split.area / ombr.area
                inverted_ombr_score = 1 - each_ombr_ratio
                ombr_ratio_scores.append(inverted_ombr_score)

                slopes = []
                for splitter in splitters:
                    if split.buffer(self.TOLERANCE_LARGE).intersects(splitter):
                        slopes.append(self.compute_slope(splitter.coords[0], splitter.coords[1]))

                splitter_main_slope = max(slopes, key=abs)

                split_slopes_similarity = []
                split_segments = self.explode_polygon(split)
                for split_seg in split_segments:
                    split_seg_slope = self.compute_slope(split_seg.coords[0], split_seg.coords[1])
                    split_slopes_similarity.append(abs(splitter_main_slope - split_seg_slope))

                avg_slope_similarity = sum(split_slopes_similarity) / len(split_slopes_similarity)
                slope_similarity_scores.append(avg_slope_similarity)

            ombr_ratio_score = abs(ombr_ratio_scores[0] - sum(ombr_ratio_scores[1:])) * ombr_ratio_weight
            slope_similarity_score = sum(slope_similarity_scores) / len(splits) * slope_similarity_weight

            score_sum = even_area_score + ombr_ratio_score + slope_similarity_score

            if splits_score is None or splits_score > score_sum:
                splits_score = score_sum
                splits_selected = splits
                splitters_selceted = splitters

        if return_splitters_only:
            return splitters_selceted

        return splits_selected, triangulations_edges, splitters_selceted


However, this algorithm is not perfect. Because it relies on weights to compute the scores, its results can be sensitive to those weights. Look at the figures below: the first is a good case, and the second is a bad case.
Results of the algorithm
From the left, triangulations · segmentations · oriented bounding boxes for segmentations


Since these problems cannot be handled by my naive algorithm, I first used the algorithm to process the raw polygon data and then corrected the labels manually, as shown below. Including augmented copies of the originals, I now have approximately 40,000 samples. The whole dataset can be found here.
Some data for training

Model for link prediction

Now, let's create a graph model using PyTorch Geometric and train it on the graph data. PyTorch Geometric is a library built on PyTorch for easily writing and training graph neural networks.

The role I need to assign to the model is to generate lines that will segment polygons. This can be translated into a task primarily used in GNNs, known as link prediction. Link prediction models usually use an encoder-decoder structure. The encoder creates node embeddings, which are vector representations of the nodes that extract their features. The decoder then uses these embeddings to predict the probability that a pair of nodes is connected.

At inference time, the model feeds all possible node pairs into the decoder and calculates the probability of each pair being connected. Only pairs with probabilities above a certain threshold are kept, indicating likely connections.
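
The inference step can be sketched roughly as follows. This is an illustration under my own assumptions, not the project's inference code: `score_fn` stands in for the trained decoder, and the probabilities in `toy_scores` are made up.

```python
from itertools import combinations


def predict_links(num_nodes, score_fn, threshold=0.5):
    """Score every unordered node pair and keep those at or above the threshold."""
    links = []
    for i, j in combinations(range(num_nodes), 2):
        probability = score_fn(i, j)
        if probability >= threshold:
            links.append((i, j, probability))
    # Strongest connections first.
    return sorted(links, key=lambda link: link[2], reverse=True)


# Stand-in for the decoder: a lookup table of made-up probabilities.
toy_scores = {(0, 2): 0.9, (1, 3): 0.4}


def toy_score_fn(i, j):
    return toy_scores.get((i, j), 0.1)


predicted = predict_links(4, toy_score_fn, threshold=0.5)
print(predicted)
```

Because every pair is scored, the number of decoder evaluations grows quadratically with the number of polygon vertices, which stays manageable for typical land polygons.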

Generally, a simple operation like the dot product is used to predict links based on the similarity of node pairs. However, I thought this approach was not suitable for tasks using geometric data, so I additionally trained a decoder. Below are the encode and decode methods of the model. The complete model code can be found here.

    class PolygonSegmenter(nn.Module):
        def __init__(
            self,
            conv_type: str,
            in_channels: int,
            hidden_channels: int,
            out_channels: int,
            encoder_activation: nn.Module,
            decoder_activation: nn.Module,
            predictor_activation: nn.Module,
            use_skip_connection: bool = Configuration.USE_SKIP_CONNECTION,
        ):
                
            ( ... )

        def encode(self, data: Batch) -> torch.Tensor:
            """Encode the features of polygon graphs

            Args:
                data (Batch): graph batch

            Returns:
                torch.Tensor: Encoded features
            """

            encoded = self.encoder(data.x, data.edge_index, edge_weight=data.edge_weight)

            return encoded

        def decode(self, data: Batch, encoded: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor:
            """Decode the encoded features of the nodes to predict whether the edges are connected

            Args:
                data (Batch): graph batch
                encoded (torch.Tensor): Encoded features
                edge_label_index (torch.Tensor): node index pairs to score, shape [2, num_pairs]

            Returns:
                torch.Tensor: whether the edges are connected
            """

            # Merge raw features and encoded features to inject geometric informations
            if self.use_skip_connection:
                encoded = torch.cat([data.x, encoded], dim=1)

            decoded = self.decoder(torch.cat([encoded[edge_label_index[0]], encoded[edge_label_index[1]]], dim=1)).squeeze()

            return decoded

The encoding process transforms the original feature matrix \( F \) of size \( n \times m \) into a new matrix \( E \) of size \( n \times c \), where the embedding width \( c \) may differ from the number of raw features \( m \). The following is an expression of the feature matrix before and after the encode function.
\[ F = \begin{bmatrix} f_{11} & f_{12} & \cdots & f_{1m} \\ f_{21} & f_{22} & \cdots & f_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ f_{n1} & f_{n2} & \cdots & f_{nm} \end{bmatrix} \]
\[ E = \begin{bmatrix} e_{11} & e_{12} & \cdots & e_{1c} \\ e_{21} & e_{22} & \cdots & e_{2c} \\ \vdots & \vdots & \ddots & \vdots \\ e_{n1} & e_{n2} & \cdots & e_{nc} \end{bmatrix} \]


In the decode method, the encoded features of the nodes and raw features are used to predict the connections (links) between them. This skip connection merges raw features and encoded features to inject geometric information. Using skip connections helps preserve original features, and enhances overall model performance by combining low-level and high-level information.
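
To show what the decoder actually receives for one candidate link, here is a small sketch that mirrors the two concatenations in the decode method using plain lists. The helper name `decoder_input` and the numbers are hypothetical; with \( m \) raw features and \( c \) encoded features per node, the input for a pair has \( 2(m + c) \) values when the skip connection is on.

```python
def decoder_input(raw, encoded, pair, use_skip_connection=True):
    """Build the decoder input vector for one node pair."""
    i, j = pair

    def node_vector(k):
        # With the skip connection, raw features are prepended to the embedding.
        if use_skip_connection:
            return list(raw[k]) + list(encoded[k])
        return list(encoded[k])

    # The vectors of both endpoints are concatenated into a single input.
    return node_vector(i) + node_vector(j)


raw_features = [[0.0, 0.0], [1.0, 0.0]]                   # m = 2
encoded_features = [[0.5, 0.1, 0.2], [0.3, 0.4, 0.9]]     # c = 3

with_skip = decoder_input(raw_features, encoded_features, (0, 1))
without_skip = decoder_input(raw_features, encoded_features, (0, 1), use_skip_connection=False)
print(len(with_skip), len(without_skip))
```

Note that the concatenation is ordered, so the pair (0, 1) and the pair (1, 0) produce different input vectors.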

In the forward pass, the decoder is trained on both positive labels (node pairs that should be connected) and negative labels (node pairs that should not be connected). Generating the negative labels is a technique called negative sampling, which is used to improve model performance in link prediction tasks. By providing examples of what should not be linked, negative sampling helps the model better distinguish between actual links and non-links, leading to improved accuracy in predicting future or missing links.

In most networks, actual links are significantly fewer than non-links, which can bias the model towards predicting no link. Negative sampling allows for controlled selection of negative examples, balancing the training data and enhancing the learning process.

        def forward(self, data: Batch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Forward method of the models, segmenter and predictor

            Args:
                data (Batch): graph batch

            Returns:
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: whether the edges are connected, predicted k and target k
            """

            # Encode the features of polygon graphs
            encoded = self.encode(data)

            ( ... )

            # Sample negative edges
            negative_edge_index = negative_sampling(
                edge_index=data.edge_label_index_only,
                num_nodes=data.num_nodes,
                num_neg_samples=int(data.edge_label_index_only.shape[1] * Configuration.NEGATIVE_SAMPLE_MULTIPLIER),
                method="sparse",
            )

            # Decode the encoded features of the nodes to predict whether the edges are connected
            decoded = self.decode(data, encoded, torch.hstack([data.edge_label_index_only, negative_edge_index]).int())

            return decoded, k_predictions, k_targets
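
For intuition about what negative sampling produces, here is a simplified standard-library sketch. It is not the PyTorch Geometric implementation: it treats pairs as undirected, enumerates every non-link before sampling, and the function name `sample_negative_edges` is my own.

```python
import random


def sample_negative_edges(positive_edges, num_nodes, num_samples, seed=0):
    """Randomly pick node pairs that are NOT among the positive edges."""
    rng = random.Random(seed)
    positives = {frozenset(edge) for edge in positive_edges}
    # Every unordered pair that is not a labeled link is a candidate negative.
    candidates = [
        (i, j)
        for i in range(num_nodes)
        for j in range(i + 1, num_nodes)
        if frozenset((i, j)) not in positives
    ]
    return rng.sample(candidates, min(num_samples, len(candidates)))


# One labeled link (0, 2) in a 4-node polygon, three negatives requested.
negatives = sample_negative_edges([(0, 2)], num_nodes=4, num_samples=3)
print(negatives)
```

The ratio of negatives to positives plays the same role as `Configuration.NEGATIVE_SAMPLE_MULTIPLIER` in the forward method: it controls how strongly the non-link class is represented during training.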


Up to this point, I have defined the encoder and decoder. Since the encoder and decoder operate on nodes in batches, they do not seem to recognize each graph separately. Because I wanted the model to learn how many segments each input graph should be divided into, I defined a predictor within the segmenter class that is trained separately to predict k.

The segmenter generates candidate segmentations using topk through the encoder and decoder processes described above. The generated links are sorted in order of connection strength, and the predictor decides how many of them to use.
Inference process
From the top, topk segmentations · segmentation selected by predictor
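
The two-stage selection can be sketched as below. The names `select_links`, `k_max`, and `num_links` are hypothetical, and I am assuming here that the predictor's output is used directly as the number of links to keep; the actual relationship between the predicted k and the link count may differ.

```python
def select_links(link_probabilities, k_max, num_links):
    """Keep the k_max strongest candidate links, then the first num_links of those."""
    # Sort by connection strength, strongest first, and cut to the top k_max.
    ranked = sorted(link_probabilities, key=lambda item: item[1], reverse=True)[:k_max]
    # The predictor decides how many of the ranked links are actually used.
    return ranked[:num_links]


# Made-up decoder outputs: ((node_i, node_j), probability).
candidate_links = [((1, 3), 0.7), ((0, 2), 0.9), ((0, 3), 0.2)]
chosen = select_links(candidate_links, k_max=2, num_links=1)
print(chosen)
```

Separating ranking from counting means the decoder only has to order the links well, while the decision of how many to draw is left to the per-graph predictor.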


Training and evaluating

It's time to train the model. The model was trained for 500 epochs, during which both the training loss and validation loss were recorded to monitor training progress and convergence, as shown in the plots:
Losses and metrics for 500 epochs


All the metrics look good, but can they be trusted 100%? They may be inflated by the impact of negative sampling.

Based on the visualization of some of the test data, it is evident that the model accurately predicts polygons that do not require segmenting. This suggests that the model may be overfitting on negative samples. High performance on negative samples might not accurately reflect the model's ability to identify positive cases correctly.
Evaluate some test data qualitatively


Therefore, taking inspiration from IoU loss, I defined GeometricLoss to evaluate segmentation quality and provide a reliable evaluation metric. The GeometricLoss aims to quantify the discrepancy between predicted and ground-truth geometric structures within a graph. The process of the geometric loss calculation is as follows:
An example for the geometric loss calculation
From the left, loss: -0.000478528 · loss: -0.999999994
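
For readers unfamiliar with IoU, the sketch below shows the general idea of a negative-IoU score on two axis-aligned rectangles. This is only an illustration of the IoU concept under my own assumptions and is not the GeometricLoss implementation; the function names and the rectangle format `(x_min, y_min, x_max, y_max)` are hypothetical.

```python
def rect_area(rect):
    x_min, y_min, x_max, y_max = rect
    return max(0.0, x_max - x_min) * max(0.0, y_max - y_min)


def rect_iou(a, b):
    """Intersection over union of two axis-aligned rectangles."""
    overlap_width = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_height = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    intersection = overlap_width * overlap_height
    union = rect_area(a) + rect_area(b) - intersection
    return intersection / union if union > 0 else 0.0


def negative_iou_loss(predicted, target):
    """Lower is better: -1 for a perfect match, 0 for no overlap."""
    return -rect_iou(predicted, target)


perfect = negative_iou_loss((0.0, 0.0, 2.0, 1.0), (0.0, 0.0, 2.0, 1.0))
partial = negative_iou_loss((0.0, 0.0, 2.0, 1.0), (1.0, 0.0, 3.0, 1.0))
disjoint = negative_iou_loss((0.0, 0.0, 1.0, 1.0), (2.0, 2.0, 3.0, 3.0))
print(perfect, partial, disjoint)
```

Under this convention a value near -1 indicates that the predicted and ground-truth regions nearly coincide, and a value near 0 indicates almost no overlap.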


The geometric loss serves solely as an evaluation metric for the model, so it has not been added to the BCE loss. Because the model's training batches are based on nodes rather than graphs, calculating this loss for every graph in every epoch would significantly slow down training. I have therefore calculated this loss for only 4 samples per batch. The results are as follows:
Geometric losses for 500 epochs
From the left, train geometric loss · validation geometric loss


The GeometricLoss class is defined as follows and the code can be found here.


Limitations and future works

GNNs are node embedding-based models, so they seem to recognize the characteristics of individual nodes rather than the overall shape of the graph. While GNNs have a good ability to generalize shapes during training, it has been challenging to overfit them accurately to the intended labels.


References