Simplified compare algo

This commit is contained in:
Maximilian Giller 2025-09-06 19:11:46 +02:00
parent 337592f5d8
commit 1bd489afc1

View file

@@ -7,41 +7,9 @@ def normalize_image(image: np.ndarray) -> np.ndarray:
return cv2.resize(image, (160*2,90*2)) return cv2.resize(image, (160*2,90*2))
def compare_images(image: np.ndarray, target_image: np.ndarray) -> float: def compare_images(image: np.ndarray, target_image: np.ndarray) -> float:
""" err = np.mean((image.astype(float) - target_image.astype(float)) ** 2)
Compare two images pixel-wise, threshold differences, and return a similarity score (0-1). max_val = 255.0 if image.dtype == np.uint8 else 1.0
return 1 - (err / (max_val**2))
Parameters:
image, target_image : np.ndarray
Input images (must be same shape, HxWxC).
Returns:
float : similarity score (1 = identical, 0 = completely different)
"""
THRESHOLD = 0.2 * 255
# Ensure images are uint8
if image.dtype != np.uint8:
image = (255 * image).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
if target_image.dtype != np.uint8:
target_image = (255 * target_image).astype(np.uint8) if target_image.max() <= 1.0 else target_image.astype(np.uint8)
# Ensure same shape
if image.shape != target_image.shape:
raise ValueError(f"Images must have the same shape: {image.shape} vs {target_image.shape}")
# Compute absolute difference per pixel
delta = cv2.absdiff(image, target_image) # shape HxWxC
# If multi-channel, take max across channels
if len(delta.shape) == 3:
delta = delta.max(axis=2) # HxW
# Threshold to binary
binary_diff = (delta > THRESHOLD).astype(np.uint8)
# Score = 1 - fraction of different pixels
score = 1.0 - binary_diff.sum() / binary_diff.size
return score
def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float: def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float:
video = cv2.VideoCapture(path_to_episode) video = cv2.VideoCapture(path_to_episode)
@@ -120,20 +88,9 @@ def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) ->
def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]: def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]:
results = {} results = {}
# Normalize references
print("Normalizing reference images ...")
for ref in references.keys():
for i in range(len(references[ref])):
references[ref][i] = normalize_image(references[ref][i])
# cv2.imwrite(f"{i}.png", references[ref][i])
# Compare to episodes # Compare to episodes
for episode in episodes: for episode in episodes:
print(f"Processing: {episode}") results[episode] = efficient_episode_matching(episode, references)
results[episode] = {}
for ref, images in references.items():
print(f" -- Reference {ref}")
results[episode][ref] = compare_episode_to_images(episode, images)
print(results[episode]) print(results[episode])
@@ -274,33 +231,33 @@ if __name__ == "__main__":
references = { references = {
"1" : [ "1" : [
cv2.imread("./1/9hYDr40VIWbAHGjmlpSkyqlDaVQ.webp"), cv2.imread("./1/9hYDr40VIWbAHGjmlpSkyqlDaVQ.webp"),
# cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"), cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"),
# cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"), cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"),
],
"2" : [
cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"),
cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"),
cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"),
],
"3" : [
cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"),
cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"),
cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"),
],
"4" : [
cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"),
cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"),
],
"5" : [
cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"),
cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"),
cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"),
],
"6" : [
cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"),
cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"),
cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"),
], ],
# "2" : [
# cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"),
# cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"),
# cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"),
# ],
# "3" : [
# cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"),
# cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"),
# cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"),
# ],
# "4" : [
# cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"),
# cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"),
# ],
# "5" : [
# cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"),
# cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"),
# cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"),
# ],
# "6" : [
# cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"),
# cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"),
# cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"),
# ],
} }
results = match_episodes_to_references(episodes, references) results = match_episodes_to_references(episodes, references)