In [1]:
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from shapely import geometry, affinity
from IPython.display import HTML
In [2]:
class DistanceOverUnionWithShapely(torch.autograd.Function):
    """Distance-IoU (DIoU) loss between a parametrised rotated box and a fixed polygon.

    ``parameters`` is a 5-element tensor ``(x, y, w, h, t)`` describing a
    ``w`` x ``h`` box centred at ``(x, y)`` and rotated by ``t`` radians about
    its centre. The loss is ``1 - IoU + d**2 / c**2``, where ``d`` is the
    distance between the two centroids and ``c`` is the diagonal of the
    axis-aligned envelope of both shapes. The geometry is evaluated with
    shapely, so the gradient is obtained by central finite differences rather
    than by autograd.
    """

    # Step size used for central finite differences in ``backward``.
    EPS = 1e-4

    @staticmethod
    def _diou_loss(values, target_rect):
        """Return the DIoU loss as a Python float.

        Args:
            values: sequence of five floats ``(x, y, w, h, t)``.
            target_rect: shapely polygon holding the ground-truth shape.

        The computation runs in double precision so that finite differences
        taken with a small step are not swamped by float32 rounding.
        """
        x, y, w, h, t = values

        rect = geometry.box(-w / 2, -h / 2, w / 2, h / 2)
        rect = affinity.translate(rect, xoff=x, yoff=y)
        # shapely's default rotation origin is the bounding-box centre, i.e. (x, y) here.
        rect = affinity.rotate(rect, t, use_radians=True)

        iou = rect.intersection(target_rect).area / rect.union(target_rect).area
        distance = rect.centroid.distance(target_rect.centroid)
        # Smallest axis-aligned box enclosing both shapes; its diagonal normalises the distance term.
        bbox = geometry.GeometryCollection([target_rect, rect]).envelope
        diag = geometry.Point(bbox.bounds[:2]).distance(geometry.Point(bbox.bounds[2:]))

        return 1.0 - iou + (distance ** 2) / (diag ** 2)

    @staticmethod
    def forward(ctx, parameters, target_rect, save=True):
        """Compute the DIoU loss of ``parameters`` against ``target_rect``.

        Args:
            parameters: 5-element tensor ``(x, y, w, h, t)`` describing the predicted box.
            target_rect: shapely polygon holding the ground-truth shape.
            save: when True, stash the inputs on ``ctx`` for ``backward``.

        Returns:
            0-dim tensor on ``parameters``' device. Its dtype is that of
            ``parameters`` when floating point, float32 otherwise.
        """
        # .tolist() yields Python floats and works for tensors on any device.
        loss = DistanceOverUnionWithShapely._diou_loss(parameters.detach().tolist(), target_rect)

        if save:
            ctx.save_for_backward(parameters)
            ctx.target_rect = target_rect

        dtype = parameters.dtype if parameters.is_floating_point() else torch.float32
        return torch.tensor(loss, dtype=dtype, device=parameters.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. ``parameters`` via central finite differences.

        Returns one gradient per ``forward`` input. ``target_rect`` and
        ``save`` are not differentiable, so their entries are ``None``.
        PyTorch accepts surplus trailing ``None`` gradients when ``save`` was
        not passed to ``apply``.
        """
        parameters, = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None, None, None

        target_rect = ctx.target_rect
        eps = DistanceOverUnionWithShapely.EPS
        # Perturb in float64 (Python floats) rather than in the tensor's own dtype.
        base = parameters.detach().tolist()

        grads = []
        for i in range(len(base)):
            plus = list(base)
            plus[i] += eps
            minus = list(base)
            minus[i] -= eps
            loss_plus = DistanceOverUnionWithShapely._diou_loss(plus, target_rect)
            loss_minus = DistanceOverUnionWithShapely._diou_loss(minus, target_rect)
            grads.append((loss_plus - loss_minus) / (2 * eps))

        grad = torch.tensor(grads, dtype=parameters.dtype, device=parameters.device)
        return grad * grad_output, None, None
In [3]:
# Predicted box, parametrised as (x, y, w, h, theta) with theta in radians.
parameters = torch.tensor([-3, -3, 2.0, 1.0, 0], requires_grad=True)
optimizer = torch.optim.Adam([parameters], lr=0.01)

# Ground truth: box(-2, -2, 1, 2), shifted by (2, 2), then rotated 45 degrees.
target_rect = affinity.rotate(
    affinity.translate(geometry.box(-2, -2, 1, 2), xoff=2, yoff=2),
    45,
    use_radians=False,
)

losses = []
rects = []

for epoch in range(1000):
    # Polygon for the current (pre-step) prediction, kept for the animation.
    cx, cy, width, height, theta = parameters.detach().numpy()
    snapshot = affinity.rotate(
        affinity.translate(
            geometry.box(-width / 2, -height / 2, width / 2, height / 2),
            xoff=cx,
            yoff=cy,
        ),
        theta,
        use_radians=True,
    )

    optimizer.zero_grad()
    loss = DistanceOverUnionWithShapely.apply(parameters, target_rect)
    loss.backward()
    optimizer.step()

    current_loss = loss.item()
    losses.append(current_loss)
    rects.append(snapshot)

    # Stop once the prediction essentially matches the target.
    if current_loss < 1e-3:
        break
In [4]:
fig, ax = plt.subplots()

def animate(frame):
    """Redraw the axes for optimisation step ``frame``.

    Draws the ground-truth box, the predicted box recorded at that step and
    the segment joining their centroids, titled with that step's loss. The
    axes are cleared and fully redrawn every frame.
    """
    rect = rects[frame]
    distance = geometry.LineString([rect.centroid, target_rect.centroid])

    ax.clear()
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.grid(True, alpha=0.3)
    ax.plot(*target_rect.exterior.xy, color="green", label="ground-truth", linewidth=1)
    ax.plot(*rect.exterior.xy, color="black", label="predicted", linewidth=1)
    ax.plot(*distance.xy, color="red", label="distance", linewidth=1, linestyle="dotted")
    ax.set_aspect("equal")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend(loc="upper left")
    # ``losses`` holds 1 - IoU + d^2/c^2, i.e. the DIoU *loss*, not the DIoU score.
    ax.set_title(f"Epoch: {frame + 1}, DIoU loss: {losses[frame]:.5f} \n", fontsize=9)

# blit=False: ``animate`` clears and redraws the whole axes, so there is no
# fixed set of artists that blitting could update incrementally.
anim = animation.FuncAnimation(fig, animate, frames=len(rects), interval=50, blit=False, repeat=False)
plt.close(fig)  # close the figure so only the HTML animation below is displayed
HTML(anim.to_jshtml())
Out[4]: