From 1bd489afc1c4fa50bafd3165c160ad4fa8c6345a Mon Sep 17 00:00:00 2001 From: Max Date: Sat, 6 Sep 2025 19:11:46 +0200 Subject: [PATCH] Simplified compare algo --- src/label_episodes.py | 103 ++++++++++++------------------------------ 1 file changed, 30 insertions(+), 73 deletions(-) diff --git a/src/label_episodes.py b/src/label_episodes.py index 907fd8a..f8289da 100644 --- a/src/label_episodes.py +++ b/src/label_episodes.py @@ -7,41 +7,9 @@ def normalize_image(image: np.ndarray) -> np.ndarray: return cv2.resize(image, (160*2,90*2)) def compare_images(image: np.ndarray, target_image: np.ndarray) -> float: - """ - Compare two images pixel-wise, threshold differences, and return a similarity score (0-1). - - Parameters: - image, target_image : np.ndarray - Input images (must be same shape, HxWxC). - - Returns: - float : similarity score (1 = identical, 0 = completely different) - """ - THRESHOLD = 0.2 * 255 - - # Ensure images are uint8 - if image.dtype != np.uint8: - image = (255 * image).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8) - if target_image.dtype != np.uint8: - target_image = (255 * target_image).astype(np.uint8) if target_image.max() <= 1.0 else target_image.astype(np.uint8) - - # Ensure same shape - if image.shape != target_image.shape: - raise ValueError(f"Images must have the same shape: {image.shape} vs {target_image.shape}") - - # Compute absolute difference per pixel - delta = cv2.absdiff(image, target_image) # shape HxWxC - - # If multi-channel, take max across channels - if len(delta.shape) == 3: - delta = delta.max(axis=2) # HxW - - # Threshold to binary - binary_diff = (delta > THRESHOLD).astype(np.uint8) - - # Score = 1 - fraction of different pixels - score = 1.0 - binary_diff.sum() / binary_diff.size - return score + err = np.mean((image.astype(float) - target_image.astype(float)) ** 2) + max_val = 255.0 if image.dtype == np.uint8 else 1.0 + return 1 - (err / (max_val**2)) def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float: video = cv2.VideoCapture(path_to_episode) @@ -120,20 +88,9 @@ def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) -> def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]: results = {} - # Normalize references - print("Normalizing reference images ...") - for ref in references.keys(): - for i in range(len(references[ref])): - references[ref][i] = normalize_image(references[ref][i]) - # cv2.imwrite(f"{i}.png", references[ref][i]) - # Compare to episodes for episode in episodes: - print(f"Processing: {episode}") - results[episode] = {} - for ref, images in references.items(): - print(f" -- Reference {ref}") - results[episode][ref] = compare_episode_to_images(episode, images) + results[episode] = efficient_episode_matching(episode, references) print(results[episode]) @@ -274,33 +231,33 @@ if __name__ == "__main__": references = { "1" : [ cv2.imread("./1/9hYDr40VIWbAHGjmlpSkyqlDaVQ.webp"), - # cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"), - # cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"), + cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"), + cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"), + ], + "2" : [ + cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"), + cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"), + cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"), + ], + "3" : [ + cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"), + cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"), + cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"), + ], + "4" : [ + cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"), + cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"), + ], + "5" : [ + cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"), + cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"), + cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"), + ], + "6" : [ + cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"), + cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"), + cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"), ], - # "2" : [ - # cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"), - # cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"), - # cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"), - # ], - # "3" : [ - # cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"), - # cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"), - # cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"), - # ], - # "4" : [ - # cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"), - # cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"), - # ], - # "5" : [ - # cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"), - # cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"), - # cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"), - # ], - # "6" : [ - # cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"), - # cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"), - # cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"), - # ], } results = match_episodes_to_references(episodes, references)