Refactored and implemented match result object with simplified results

This commit is contained in:
Maximilian Giller 2025-09-07 02:54:19 +02:00
parent bb5092f700
commit 426b36c985

View file

@ -1,103 +1,18 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
import numpy as np
import cv2
import logging
# Fetch the root logger (no name given) and lower its threshold to DEBUG.
# NOTE(review): this runs at import time, so it affects every importer of this module.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
def normalize_image(image: np.ndarray) -> np.ndarray:
    """Resize an image to the fixed resolution used for comparisons.

    All frames and reference images are brought to the same size so that they
    can be compared element-wise.

    Args:
        image (np.ndarray): Image to resize.

    Returns:
        np.ndarray: The image resized with ``cv2.resize`` to a target size of
            ``(320, 180)`` (OpenCV takes the size as width, height).
    """
    # A single return: the duplicated second statement was unreachable.
    return cv2.resize(image, (160 * 2, 90 * 2))
def compare_images(image: np.ndarray, target_image: np.ndarray) -> float:
    """Score the similarity of two images based on their mean squared error.

    Args:
        image (np.ndarray): Image to score.
        target_image (np.ndarray): Image to compare against. Expected to have
            the same shape as ``image``.

    Returns:
        float: ``1 - MSE / max_val**2``, so ``1.0`` means identical images and
            lower values mean larger differences.
    """
    # Cast to float first so uint8 subtraction cannot wrap around.
    err = np.mean((image.astype(float) - target_image.astype(float)) ** 2)

    # uint8 images span 0..255; anything else is treated as spanning 0..1.
    # NOTE(review): presumably non-uint8 input is normalized to [0, 1] - confirm.
    max_val = 255.0 if image.dtype == np.uint8 else 1.0

    # Convert the numpy scalar to a builtin float to honour the annotation.
    return float(1 - (err / (max_val**2)))
def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float:
    """Return the best similarity score between sampled frames and an image.

    Every 24th frame of the video is decoded, normalized and compared to
    ``image``; the highest score wins.

    Args:
        path_to_episode (str): Path to the video file.
        image (np.ndarray): Reference image, already in the normalized size.

    Returns:
        float: Highest score of all sampled frames.

    Raises:
        ValueError: If no frame could be read from the video.
    """
    frame_step = 24
    scores = []

    video = cv2.VideoCapture(path_to_episode)
    try:
        frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
        frame_number = 0
        while frame_number < frame_count:
            # Seek directly to the sampled frame instead of decoding all frames.
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            success, frame = video.read()
            if not success:
                break

            scores.append(compare_images(normalize_image(frame), image))
            frame_number += frame_step
    finally:
        # Always free the capture handle, even if decoding or comparing fails.
        video.release()

    if not scores:
        raise ValueError(f"No frames could be read from [{path_to_episode}]")

    return float(max(scores))
def plot_value(values):
    """Plot values against their index and save the chart as ``plot.png``.

    Args:
        values: Sequence of numbers to plot; the x-axis is the element index.
    """
    # Imported lazily so matplotlib is only needed when plotting is used.
    import matplotlib.pyplot as plt

    y_values = values
    x_values = list(range(len(y_values)))

    figure = plt.figure(figsize=(60, 20))
    try:
        plt.plot(x_values, y_values, marker='o', linestyle='-', color='blue', label='Series A')
        plt.title('Simple Plot')
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.grid(True)
        plt.legend()

        plt.savefig("plot.png")
    finally:
        # pyplot keeps figures alive until closed; close to avoid piling them up.
        plt.close(figure)
def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) -> float:
    """Return the best similarity score between any frame and any reference.

    Every frame of the video is read sequentially, normalized and compared to
    each reference image. As debugging output, the best matching frame is
    written to ``match.png``, the reference it matched to ``ref.png`` and all
    scores are plotted via ``plot_value``.

    Args:
        path_to_episode (str): Path to the video file.
        images (list[np.ndarray]): Reference images to compare against.

    Returns:
        float: Highest score over all frames and references.

    Raises:
        ValueError: If ``images`` is empty or no frame could be read.
    """
    if not images:
        raise ValueError("At least one reference image is required")

    scores = []
    video = cv2.VideoCapture(path_to_episode)
    try:
        while True:
            success, frame = video.read()
            if not success:
                break

            frame = normalize_image(frame)
            # One score per reference, so scores is laid out frame-major.
            scores.extend(compare_images(frame, reference) for reference in images)

        if not scores:
            raise ValueError(f"No frames could be read from [{path_to_episode}]")

        best_score = max(scores)
        best_index = scores.index(best_score)

        # scores holds len(images) entries per frame, so translate the flat
        # index back into a frame number and a reference index.
        best_frame_number, best_reference_index = divmod(best_index, len(images))

        video.set(cv2.CAP_PROP_POS_FRAMES, best_frame_number)
        success, frame = video.read()
        if success:
            cv2.imwrite("match.png", normalize_image(frame))
            cv2.imwrite("ref.png", normalize_image(images[best_reference_index]))
    finally:
        # Release the capture handle on success and on failure alike.
        video.release()

    plot_value(scores)

    return float(best_score)
def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]:
    """Score every episode against all references.

    Args:
        episodes (list[str]): Paths to the episode video files.
        references (dict[str, list[np.ndarray]]): Reference images, grouped by
            reference name.

    Returns:
        dict[str, dict[str, float]]: For each episode path, the result of
            ``efficient_episode_matching`` (score per reference name).
    """
    scores_by_episode = {}

    for episode_path in episodes:
        episode_scores = efficient_episode_matching(episode_path, references)
        scores_by_episode[episode_path] = episode_scores
        # Echo the scores of each episode as soon as they are available.
        print(episode_scores)

    return scores_by_episode
return float(1 - (err / (max_val**2)))
def _match_episode_batch(
@ -156,8 +71,10 @@ def _match_episode_batch(
def efficient_episode_matching(
path_to_episode: str,
references: dict[str, list[np.ndarray]],
*,
batch_size: int = 1440,
step_size: int = 24,
significance_threshold: float = 0.8,
) -> dict[str, float]:
"""_summary_
@ -166,6 +83,7 @@ def efficient_episode_matching(
references (dict[str, list[np.ndarray]]): _description_
batch_size (int, optional): Batch size in frame count for multicore frame matching. At least step_size+1. Defaults to 1440, 24fps for 1 minute.
step_size (int, optional): Number of frames to skip on initial match-step. Defaults to 24.
significance_threshold (float, optional): Matching threshold for coarse matching; above it, granular matching is performed. Defaults to 0.8.
Returns:
dict[str, float]: _description_
@ -175,7 +93,7 @@ def efficient_episode_matching(
# Load and prepare episode
video = cv2.VideoCapture(path_to_episode)
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
logging.info(f"Loading and preprocessing [{frame_count}] frames ...")
normalized_frames = []
@ -200,7 +118,13 @@ def efficient_episode_matching(
results: dict[str, float] = {ref: 0.0 for ref in references.keys()}
with ProcessPoolExecutor() as executor:
futures = [
executor.submit(_match_episode_batch, batch, references, step_size)
executor.submit(
_match_episode_batch,
batch,
references,
step_size,
significance_threshold,
)
for batch in batches
]
for i, future in enumerate(as_completed(futures)):
@ -215,20 +139,113 @@ def efficient_episode_matching(
logging.info(f"Finished [{path_to_episode}]")
results_msg = "Results"
for ref, score in results.items():
results_msg += f"\n{ref}\t{score:.4f}"
results_msg += f"\n{ref}\t{score:.4f}{'\tWinner' if score == max(results.values()) else ''}"
logging.info(results_msg)
return results
@dataclass
class EpisodesMatchResult:
scores: dict[str, dict[str, float]] = field(default_factory=dict)
"""Raw scores from matching. For each episode, the score to each reference."""
perfect_match: bool = False
"""If true, scores were unambiguous amd simplified to reference_by_episode and episode_by_reference."""
reference_by_episode: dict[str, str] | None = None
"""Simplified form of results, if unambiguous (see perfect_match). Episode as key, reference as value."""
@property
def episode_by_reference(self) -> dict[str, str] | None:
"""Simplified form of results, if unambiguous (see perfect_match). Refernce as key, episode as value."""
if self.reference_by_episode is None:
return None
return {ref: episode for episode, ref in self.reference_by_episode}
def simplify_match(self) -> None:
"""Simplifies score if perfect match given."""
ref_by_ep = {}
# Match for episodes
for episode, ref_scores in self.scores.items():
winner_score = max(ref_scores.values())
winner_ref = [
ref for ref, score in ref_scores.items() if score == winner_score
][0]
ref_by_ep[episode] = winner_ref
# Validate winners with matches for references
references = list(list(self.scores.values())[0].keys())
for ref in references:
episode_scores = {
episode: ref_scores[ref] for episode, ref_scores in self.scores.items()
}
winner_score = max(episode_scores.values())
winner_episode = [
epi for epi, score in episode_scores.items() if score == winner_score
][0]
# Make sure that no other episode as this ref as winner
# Winner episode might have a higher scoring ref, in case there are more refs than episodes
# => So ref might not be assigned at all
for episode, ep_ref in ref_by_ep.items():
if episode == winner_episode:
continue
elif ep_ref == ref:
return # Not a perfect match
self.perfect_match = True
self.reference_by_episode = ref_by_ep
def __str__(self) -> str:
msg = f"Perfect Match: {self.perfect_match}"
if self.perfect_match and self.reference_by_episode:
for ep, ref in self.reference_by_episode.items():
msg += f"\n{ep}\t{ref}"
else:
msg += "\nBy Episode:"
for ep, scores in self.scores.items():
msg += f"{ep}"
for ref, value in scores.items():
msg += f"\t{ref}\t{value}"
msg += "\nBy Reference:"
for ref in list(self.scores.values())[0].keys():
msg += f"{ref}"
for episode, scores in self.scores.items():
msg += f"\t{episode}\t{scores[ref]}"
return msg
def match_episodes_to_references(
    episodes: list[str],
    references: dict[str, list[np.ndarray]],
    *,
    significance_threshold: float = 0.9,
) -> EpisodesMatchResult:
    """Match every episode against all references and simplify the result.

    Args:
        episodes (list[str]): Paths to the episode video files.
        references (dict[str, list[np.ndarray]]): Reference images, grouped by
            reference name.
        significance_threshold (float, optional): Threshold forwarded to
            ``efficient_episode_matching``. Defaults to 0.9.

    Returns:
        EpisodesMatchResult: Raw scores per episode, with ``simplify_match``
            already applied.
    """
    results = EpisodesMatchResult()

    # Compare refs to episode
    for episode in episodes:
        results.scores[episode] = efficient_episode_matching(
            episode, references, significance_threshold=significance_threshold
        )

    results.simplify_match()

    return results
if __name__ == "__main__":
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
episodes = [
"./episodes/B1_t05.mkv",
# "./episodes/B1_t06.mkv",
# "./episodes/B1_t07.mkv",
# "./episodes/B1_t08.mkv",
# "./episodes/B1_t09.mkv",
# "./episodes/B1_t10.mkv",
"./episodes/B1_t06.mkv",
"./episodes/B1_t07.mkv",
"./episodes/B1_t08.mkv",
"./episodes/B1_t09.mkv",
"./episodes/B1_t10.mkv",
]
references = {
@ -266,17 +283,4 @@ if __name__ == "__main__":
results = match_episodes_to_references(episodes, references)
print(results)
print("\nBy Episode:")
for ep, scores in results.items():
print(f"{ep}")
for ref, value in scores.items():
print(f"\t{ref}\t{value}")
print("\nBy Reference:")
for ref in list(results.values())[0].keys():
print(f"{ref}")
for episode, scores in results.items():
print(f"\t{episode}\t{scores[ref]}")
print("Done.")