Simplified compare algo
This commit is contained in:
parent
337592f5d8
commit
1bd489afc1
1 changed files with 30 additions and 73 deletions
|
@ -7,41 +7,9 @@ def normalize_image(image: np.ndarray) -> np.ndarray:
|
|||
return cv2.resize(image, (160*2,90*2))
|
||||
|
||||
def compare_images(image: np.ndarray, target_image: np.ndarray) -> float:
|
||||
"""
|
||||
Compare two images pixel-wise, threshold differences, and return a similarity score (0-1).
|
||||
|
||||
Parameters:
|
||||
image, target_image : np.ndarray
|
||||
Input images (must be same shape, HxWxC).
|
||||
|
||||
Returns:
|
||||
float : similarity score (1 = identical, 0 = completely different)
|
||||
"""
|
||||
THRESHOLD = 0.2 * 255
|
||||
|
||||
# Ensure images are uint8
|
||||
if image.dtype != np.uint8:
|
||||
image = (255 * image).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
|
||||
if target_image.dtype != np.uint8:
|
||||
target_image = (255 * target_image).astype(np.uint8) if target_image.max() <= 1.0 else target_image.astype(np.uint8)
|
||||
|
||||
# Ensure same shape
|
||||
if image.shape != target_image.shape:
|
||||
raise ValueError(f"Images must have the same shape: {image.shape} vs {target_image.shape}")
|
||||
|
||||
# Compute absolute difference per pixel
|
||||
delta = cv2.absdiff(image, target_image) # shape HxWxC
|
||||
|
||||
# If multi-channel, take max across channels
|
||||
if len(delta.shape) == 3:
|
||||
delta = delta.max(axis=2) # HxW
|
||||
|
||||
# Threshold to binary
|
||||
binary_diff = (delta > THRESHOLD).astype(np.uint8)
|
||||
|
||||
# Score = 1 - fraction of different pixels
|
||||
score = 1.0 - binary_diff.sum() / binary_diff.size
|
||||
return score
|
||||
err = np.mean((image.astype(float) - target_image.astype(float)) ** 2)
|
||||
max_val = 255.0 if image.dtype == np.uint8 else 1.0
|
||||
return 1 - (err / (max_val**2))
|
||||
|
||||
def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float:
|
||||
video = cv2.VideoCapture(path_to_episode)
|
||||
|
@ -120,20 +88,9 @@ def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) ->
|
|||
def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
# Normalize references
|
||||
print("Normalizing reference images ...")
|
||||
for ref in references.keys():
|
||||
for i in range(len(references[ref])):
|
||||
references[ref][i] = normalize_image(references[ref][i])
|
||||
# cv2.imwrite(f"{i}.png", references[ref][i])
|
||||
|
||||
# Compare to episodes
|
||||
for episode in episodes:
|
||||
print(f"Processing: {episode}")
|
||||
results[episode] = {}
|
||||
for ref, images in references.items():
|
||||
print(f" -- Reference {ref}")
|
||||
results[episode][ref] = compare_episode_to_images(episode, images)
|
||||
results[episode] = efficient_episode_matching(episode, references)
|
||||
|
||||
print(results[episode])
|
||||
|
||||
|
@ -274,33 +231,33 @@ if __name__ == "__main__":
|
|||
references = {
|
||||
"1" : [
|
||||
cv2.imread("./1/9hYDr40VIWbAHGjmlpSkyqlDaVQ.webp"),
|
||||
# cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"),
|
||||
# cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"),
|
||||
cv2.imread("./1/cmYUHroQ9zRhOYiTf0ozMWe1Bzs.webp"),
|
||||
cv2.imread("./1/yLwaP7Y7O3Wyrht9lPmTDk7LSuG.webp"),
|
||||
],
|
||||
"2" : [
|
||||
cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"),
|
||||
cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"),
|
||||
cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"),
|
||||
],
|
||||
"3" : [
|
||||
cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"),
|
||||
cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"),
|
||||
cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"),
|
||||
],
|
||||
"4" : [
|
||||
cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"),
|
||||
cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"),
|
||||
],
|
||||
"5" : [
|
||||
cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"),
|
||||
cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"),
|
||||
cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"),
|
||||
],
|
||||
"6" : [
|
||||
cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"),
|
||||
cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"),
|
||||
cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"),
|
||||
],
|
||||
# "2" : [
|
||||
# cv2.imread("./2/g2pyyFRFBS18jYByRdiunWr7nap.webp"),
|
||||
# cv2.imread("./2/wKCm3I7efXfsWGNrG28d3cIsL9B.webp"),
|
||||
# cv2.imread("./2/yhL1JBnTDK30BtPbr8gZ5ZXCASL.webp"),
|
||||
# ],
|
||||
# "3" : [
|
||||
# cv2.imread("./3/d9F2x01XNhi65GTXwQjrdDJKgqm.webp"),
|
||||
# cv2.imread("./3/kV7VAvLgjjhPTI56OnNvcNaHjDM.webp"),
|
||||
# cv2.imread("./3/mLxkumagHSYg4KNLcEIgHpiJrx.webp"),
|
||||
# ],
|
||||
# "4" : [
|
||||
# cv2.imread("./4/9q62zw89SJiRD3kjNuhUIswWHTa.webp"),
|
||||
# cv2.imread("./4/lyi4XUUssCacaEQs2FCGSDCKDUK.webp"),
|
||||
# ],
|
||||
# "5" : [
|
||||
# cv2.imread("./5/uIztge10KNlRU4KUFNXWPGCiG8e.webp"),
|
||||
# cv2.imread("./5/vuhdflodFfPkvekVWbS4JZMYHa7.webp"),
|
||||
# cv2.imread("./5/xmvi4bBk5TQyuL6tpiYvNoIPft.webp"),
|
||||
# ],
|
||||
# "6" : [
|
||||
# cv2.imread("./6/2I7WgvRwnYDNjbiztvbOAldYGde.webp"),
|
||||
# cv2.imread("./6/dIaBagmUbP1AO6xFQzfQud5lk4I.webp"),
|
||||
# cv2.imread("./6/ekDZrFRDJDyx78uttMyVT9mPhjL.webp"),
|
||||
# ],
|
||||
}
|
||||
|
||||
results = match_episodes_to_references(episodes, references)
|
||||
|
|
Loading…
Reference in a new issue