import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from shapely import geometry, affinity
from IPython.display import HTML
class DistanceOverUnionWithShapely(torch.autograd.Function):
    """DIoU loss between a parameterised rotated rectangle and a shapely polygon.

    ``parameters`` is ``(x, y, w, h, t)``: a ``w`` by ``h`` box centred at
    ``(x, y)`` and rotated by ``t`` radians. The loss is
    ``1 - IoU + d**2 / c**2``, where ``d`` is the distance between the two
    centroids and ``c`` is the diagonal of the axis-aligned box enclosing both
    shapes.

    Shapely operates on plain floats, so there is no analytic gradient;
    ``backward`` estimates it with central finite differences.
    """

    # Step size for the central finite-difference gradient estimate.
    EPS = 1e-4

    @staticmethod
    def _rect_from(values):
        """Build the predicted shapely rectangle from five floats (x, y, w, h, t)."""
        x, y, w, h, t = values
        rect = geometry.box(-w / 2, -h / 2, w / 2, h / 2)
        rect = affinity.translate(rect, xoff=x, yoff=y)
        # shapely's default rotation origin is the geometry's centre, i.e. (x, y).
        return affinity.rotate(rect, t, use_radians=True)

    @staticmethod
    def _loss(values, target_rect):
        """Return the DIoU loss as a Python float for parameter values ``values``."""
        rect = DistanceOverUnionWithShapely._rect_from(values)
        union_area = rect.union(target_rect).area
        # Two zero-area shapes would otherwise raise ZeroDivisionError.
        iou = rect.intersection(target_rect).area / union_area if union_area > 0 else 0.0
        distance = rect.centroid.distance(target_rect.centroid)
        minx, miny, maxx, maxy = geometry.GeometryCollection([target_rect, rect]).bounds
        diag_sq = (maxx - minx) ** 2 + (maxy - miny) ** 2
        # A zero-size enclosing box (both shapes collapsed to one point) would divide by zero.
        center_term = distance ** 2 / diag_sq if diag_sq > 0 else 0.0
        return 1 - iou + center_term

    @staticmethod
    def forward(ctx, parameters, target_rect, save=True):
        """Compute the DIoU loss using shapely geometric operations.

        Args:
            ctx: autograd context.
            parameters: 1-D tensor ``(x, y, w, h, t)`` describing the predicted rectangle.
            target_rect: shapely polygon of the ground-truth rectangle.
            save: when True (the default, as used by ``apply``), store
                ``parameters`` and ``target_rect`` on ``ctx`` for ``backward``.

        Returns:
            0-dim float32 tensor holding the loss, on ``parameters.device``.
        """
        # .cpu().tolist() yields float64 Python floats and also works for non-CPU tensors.
        values = parameters.detach().cpu().tolist()
        loss = DistanceOverUnionWithShapely._loss(values, target_rect)
        if save:
            ctx.save_for_backward(parameters)
            ctx.target_rect = target_rect
        return torch.tensor(loss, dtype=torch.float32, device=parameters.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Estimate d(loss)/d(parameters) by central numerical differentiation.

        Perturbations and loss differences are evaluated in float64, so the
        estimate is not dominated by float32 rounding of the loss.

        Returns:
            A gradient shaped like ``parameters``, then ``None`` for
            ``target_rect`` and ``save``. PyTorch trims trailing ``None``
            gradients beyond the inputs actually passed to ``apply``.
        """
        parameters, = ctx.saved_tensors
        target_rect = ctx.target_rect
        eps = DistanceOverUnionWithShapely.EPS
        loss_fn = DistanceOverUnionWithShapely._loss
        base = parameters.detach().cpu().tolist()
        grads = []
        for i, value in enumerate(base):
            plus = base[:i] + [value + eps] + base[i + 1:]
            minus = base[:i] + [value - eps] + base[i + 1:]
            grads.append((loss_fn(plus, target_rect) - loss_fn(minus, target_rect)) / (2 * eps))
        grad = torch.tensor(grads, dtype=parameters.dtype, device=parameters.device)
        return grad * grad_output, None, None
# Learnable rectangle (x, y, w, h, t): centre, size and rotation in radians.
parameters = torch.tensor([-3, -3, 2.0, 1.0, 0], requires_grad=True)
optimizer = torch.optim.Adam([parameters], lr=0.01)

# Ground truth: box(-2, -2, 1, 2) shifted by (2, 2), then rotated by 45 degrees.
target_rect = affinity.rotate(
    affinity.translate(geometry.box(-2, -2, 1, 2), xoff=2, yoff=2),
    45,
    use_radians=False,
)

losses = []
rects = []
for epoch in range(1000):
    # Snapshot the prediction *before* this step's update, so that rects[i]
    # corresponds to losses[i] when the run is animated.
    px, py, pw, ph, pt = parameters.detach().numpy()
    rect = affinity.rotate(
        affinity.translate(geometry.box(-pw / 2, -ph / 2, pw / 2, ph / 2), xoff=px, yoff=py),
        pt,
        use_radians=True,
    )

    optimizer.zero_grad()
    loss = DistanceOverUnionWithShapely.apply(parameters, target_rect)
    loss.backward()
    optimizer.step()

    current = loss.item()
    losses.append(current)
    rects.append(rect)
    # Stop early once the DIoU loss drops below 1e-3.
    if current < 1e-3:
        break

fig, ax = plt.subplots()
def animate(frame):
    """Redraw the axes for one recorded optimisation step.

    Draws the ground-truth rectangle, the predicted rectangle recorded at step
    ``frame``, and a dotted segment joining their centroids. The title shows
    the step number and that step's DIoU loss.

    Returns an empty tuple: the axes are cleared and redrawn in full every
    frame, so there are no individual artists to hand back.
    """
    predicted = rects[frame]
    center_link = geometry.LineString([predicted.centroid, target_rect.centroid])
    ax.clear()  # start each frame from a blank axes
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.grid(True, alpha=0.3)
    ax.plot(*target_rect.exterior.xy, color="green", label="ground-truth", linewidth=1)
    ax.plot(*predicted.exterior.xy, color="black", label="predicted", linewidth=1)
    ax.plot(*center_link.xy, color="red", label="distance", linewidth=1, linestyle="dotted")
    ax.set_aspect("equal")
    ax.legend(loc="upper left")
    # losses[] holds the DIoU *loss* (1 - IoU + normalised squared centre
    # distance), which training minimises. Label it as a loss, not as the DIoU score.
    ax.set_title(f"Epoch: {frame + 1}, DIoU loss: {losses[frame]:.5f} \n", fontsize=9)
    return ()
# animate() clears and redraws the whole axes and returns no artists, so
# blitting would have nothing to update; draw full frames instead.
anim = animation.FuncAnimation(fig, animate, frames=len(rects), interval=50, blit=False, repeat=False)
# Presumably closed so the notebook shows only the HTML animation rather than
# an extra static copy of the figure.
plt.close(fig)
HTML(anim.to_jshtml())