This commit is contained in:
Rick Sprague 2024-08-27 18:26:26 -04:00
parent 714b912d3a
commit 27bcd2adac

73
umeyama.py Normal file
View File

@ -0,0 +1,73 @@
import numpy as np
def umeyama(src, dst, estimate_scale=True):
"""Umeyama algorithm to estimate similarity transformation."""
assert src.shape == dst.shape
# Compute the mean of the source and destination points
src_mean = np.mean(src, axis=0)
dst_mean = np.mean(dst, axis=0)
# Subtract the means from the points
src_centered = src - src_mean
dst_centered = dst - dst_mean
# Compute the covariance matrix
cov_matrix = np.dot(dst_centered.T, src_centered) / src.shape[0]
# Singular Value Decomposition
U, D, Vt = np.linalg.svd(cov_matrix)
# Compute the rotation matrix
R = np.dot(U, Vt)
if np.linalg.det(R) < 0:
Vt[-1, :] *= -1
R = np.dot(U, Vt)
# Compute the scale factor
if estimate_scale:
var_src = np.var(src_centered, axis=0).sum()
scale = 1.0 / var_src * np.sum(D)
else:
scale = 1.0
# Compute the translation vector
t = dst_mean - scale * np.dot(R, src_mean)
# Create the transformation matrix
T = np.identity(3)
T[:2, :2] = scale * R
T[:2, 2] = t
return T
# Generate 20 random 2D points
np.random.seed(42) # For reproducibility
src_points = np.random.rand(20, 2)
# Define a known rotation matrix R and translation vector t
theta = np.pi / 4 # 45 degrees rotation
R = np.array([
[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]
])
t = np.array([1.0, 2.0])
# Apply the transformation to generate the destination points
dst_points = np.dot(src_points, R.T) + t
# Perform Umeyama to estimate the transformation
T = umeyama(src_points, dst_points)
# Apply the resulting transformation to the source points
src_points_hom = np.hstack((src_points, np.ones((src_points.shape[0], 1))))
aligned_points = np.dot(T, src_points_hom.T).T[:, :2]
# Calculate the difference between the destination points and the aligned points
difference = np.linalg.norm(dst_points - aligned_points)
print("Original Source Points:\n", src_points)
print("Transformed Destination Points:\n", dst_points)
print("Recovered Aligned Points:\n", aligned_points)
print("\nDifference between destination and aligned points:", difference)