From 426b36c9853e05e660f2ddb1ccf9b1130ce9c5cd Mon Sep 17 00:00:00 2001
From: Maximilian Giller
Date: Sun, 7 Sep 2025 02:54:19 +0200
Subject: [PATCH] Refactored and implemented match result object with
 simplified results

---
 src/label_episodes.py | 224 +++++++++++++++++++++---------------------
 1 file changed, 114 insertions(+), 110 deletions(-)

diff --git a/src/label_episodes.py b/src/label_episodes.py
index 5016e49..231b583 100644
--- a/src/label_episodes.py
+++ b/src/label_episodes.py
@@ -1,103 +1,18 @@
 from concurrent.futures import ProcessPoolExecutor, as_completed
+from dataclasses import dataclass, field
 import numpy as np
 import cv2
 import logging
-logger = logging.getLogger()
-logger.setLevel(logging.DEBUG)
 
 
 def normalize_image(image: np.ndarray) -> np.ndarray:
-    return cv2.resize(image, (160*2,90*2))
+    return cv2.resize(image, (160 * 2, 90 * 2))
+
 
 def compare_images(image: np.ndarray, target_image: np.ndarray) -> float:
     err = np.mean((image.astype(float) - target_image.astype(float)) ** 2)
     max_val = 255.0 if image.dtype == np.uint8 else 1.0
-    return 1 - (err / (max_val**2))
-
-def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float:
-    video = cv2.VideoCapture(path_to_episode)
-    frame_number = 0
-    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
-    scores = []
-    while frame_number < frame_count:
-        video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-        success, frame = video.read()
-        if not success:
-            break
-        frame = normalize_image(frame)
-
-        scores.append(compare_images(frame, image))
-
-        frame_number += 24
-
-    return np.max(scores) # np.percentile(scores, 90)
-
-def plot_value(values):
-    import matplotlib.pyplot as plt
-
-    # Your data array
-    y_values = values
-
-    # x-axis is just the index
-    x_values = list(range(len(y_values)))
-
-    # Create the plot
-    plt.figure(figsize=(60, 20))
-    plt.plot(x_values, y_values, marker='o', linestyle='-', color='blue', label='Series A')
-    plt.title('Simple Plot')
-    plt.xlabel('Index')
-    plt.ylabel('Value')
-    plt.grid(True)
-    plt.legend()
-
-    # Show the plot
-    plt.savefig("plot.png")
-
-def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) -> float:
-    # return sum(compare_episode_to_image(path_to_episode, image) for image in images) / len(images)
-
-    video = cv2.VideoCapture(path_to_episode)
-    scores = []
-    frame_number = 0
-    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
-    while True:
-        # while frame_number < frame_count:
-        # video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-        success, frame = video.read()
-        if not success:
-            break
-        frame = normalize_image(frame)
-
-        for ref_index in range(len(images)):
-            score = compare_images(frame, images[ref_index])
-            scores.append(score)
-
-        frame_number += 24
-
-
-    # print(scores)
-
-    video.set(cv2.CAP_PROP_POS_FRAMES, scores.index(max(scores)))
-    success, frame = video.read()
-    if success:
-        cv2.imwrite("match.png", normalize_image(frame))
-        cv2.imwrite("ref.png", normalize_image(images[0]))
-
-    video.release()
-
-    plot_value(scores)
-    return np.max(scores) # np.percentile(scores, 90)
-
-def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]:
-    results = {}
-
-    # Compare to episodes
-    for episode in episodes:
-        results[episode] = efficient_episode_matching(episode, references)
-
-        print(results[episode])
-
-    return results
+    return float(1 - (err / (max_val**2)))
 
 
 def _match_episode_batch(
@@ -156,8 +71,10 @@ def _match_episode_batch(
 def efficient_episode_matching(
     path_to_episode: str,
     references: dict[str, list[np.ndarray]],
+    *,
     batch_size: int = 1440,
     step_size: int = 24,
+    significance_threshold: float = 0.8,
 ) -> dict[str, float]:
     """_summary_
 
@@ -166,6 +83,7 @@ def efficient_episode_matching(
         references (dict[str, list[np.ndarray]]): _description_
         batch_size (int, optional): Batch size in frame count for multicore frame matching. At least step_size+1. Defaults to 1440, 24fps for 1 minute.
         step_size (int, optional): Number of frames to skip on intial match-step. Defaults to 24.
+        significance_threshold (float, optional): Score threshold for the coarse matching step; above it, granular per-frame matching is performed. Defaults to 0.8.
 
     Returns:
         dict[str, float]: _description_
@@ -175,7 +93,7 @@ def efficient_episode_matching(
 
     # Load and prepare episode
    video = cv2.VideoCapture(path_to_episode)
-    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
+    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
 
     logging.info(f"Loading and preprocessing [{frame_count}] frames ...")
     normalized_frames = []
@@ -200,7 +118,13 @@ def efficient_episode_matching(
     results: dict[str, float] = {ref: 0.0 for ref in references.keys()}
     with ProcessPoolExecutor() as executor:
         futures = [
-            executor.submit(_match_episode_batch, batch, references, step_size)
+            executor.submit(
+                _match_episode_batch,
+                batch,
+                references,
+                step_size,
+                significance_threshold,
+            )
             for batch in batches
         ]
         for i, future in enumerate(as_completed(futures)):
@@ -215,20 +139,113 @@ def efficient_episode_matching(
     logging.info(f"Finished [{path_to_episode}]")
     results_msg = "Results"
     for ref, score in results.items():
-        results_msg += f"\n{ref}\t{score:.4f}"
+        results_msg += f"\n{ref}\t{score:.4f}" + ("\tWinner" if score == max(results.values()) else "")
     logging.info(results_msg)
 
     return results
 
 
+@dataclass
+class EpisodesMatchResult:
+    scores: dict[str, dict[str, float]] = field(default_factory=dict)
+    """Raw scores from matching. For each episode, the score to each reference."""
+
+    perfect_match: bool = False
+    """If true, scores were unambiguous and simplified to reference_by_episode and episode_by_reference."""
+
+    reference_by_episode: dict[str, str] | None = None
+    """Simplified form of results, if unambiguous (see perfect_match). Episode as key, reference as value."""
+
+    @property
+    def episode_by_reference(self) -> dict[str, str] | None:
+        """Simplified form of results, if unambiguous (see perfect_match).
+        Reference as key, episode as value."""
+        if self.reference_by_episode is None:
+            return None
+        return {ref: episode for episode, ref in self.reference_by_episode.items()}
+
+    def simplify_match(self) -> None:
+        """Simplifies the scores to a one-to-one mapping if they allow a perfect match."""
+        ref_by_ep = {}
+
+        # Match for episodes
+        for episode, ref_scores in self.scores.items():
+            winner_score = max(ref_scores.values())
+            winner_ref = [
+                ref for ref, score in ref_scores.items() if score == winner_score
+            ][0]
+            ref_by_ep[episode] = winner_ref
+
+        # Validate winners with matches for references
+        references = list(list(self.scores.values())[0].keys())
+        for ref in references:
+            episode_scores = {
+                episode: ref_scores[ref] for episode, ref_scores in self.scores.items()
+            }
+            winner_score = max(episode_scores.values())
+            winner_episode = [
+                epi for epi, score in episode_scores.items() if score == winner_score
+            ][0]
+
+            # Make sure that no other episode has this ref as its winner
+            # The winner episode might have a higher-scoring ref, in case there are more refs than episodes
+            # => so the ref might not be assigned at all
+            for episode, ep_ref in ref_by_ep.items():
+                if episode == winner_episode:
+                    continue
+                elif ep_ref == ref:
+                    return  # Not a perfect match
+
+        self.perfect_match = True
+        self.reference_by_episode = ref_by_ep
+
+    def __str__(self) -> str:
+        msg = f"Perfect Match: {self.perfect_match}"
+
+        if self.perfect_match and self.reference_by_episode:
+            for ep, ref in self.reference_by_episode.items():
+                msg += f"\n{ep}\t{ref}"
+        else:
+            msg += "\nBy Episode:"
+            for ep, scores in self.scores.items():
+                msg += f"\n{ep}"
+                for ref, value in scores.items():
+                    msg += f"\n\t{ref}\t{value}"
+
+            msg += "\nBy Reference:"
+            for ref in list(self.scores.values())[0].keys():
+                msg += f"\n{ref}"
+                for episode, scores in self.scores.items():
+                    msg += f"\n\t{episode}\t{scores[ref]}"
+
+        return msg
+
+
+def match_episodes_to_references(
+    episodes: list[str], references: dict[str, list[np.ndarray]]
+) -> EpisodesMatchResult:
+    results = EpisodesMatchResult()
+
+    # Compare references to each episode
+    for episode in episodes:
+        results.scores[episode] = efficient_episode_matching(
+            episode, references, significance_threshold=0.9
+        )
+
+    results.simplify_match()
+    return results
+
+
 if __name__ == "__main__":
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG)
+
     episodes = [
         "./episodes/B1_t05.mkv",
-        # "./episodes/B1_t06.mkv",
-        # "./episodes/B1_t07.mkv",
-        # "./episodes/B1_t08.mkv",
-        # "./episodes/B1_t09.mkv",
-        # "./episodes/B1_t10.mkv",
+        "./episodes/B1_t06.mkv",
+        "./episodes/B1_t07.mkv",
+        "./episodes/B1_t08.mkv",
+        "./episodes/B1_t09.mkv",
+        "./episodes/B1_t10.mkv",
     ]
 
     references = {
@@ -266,17 +283,4 @@ if __name__ == "__main__":
 
     results = match_episodes_to_references(episodes, references)
     print(results)
-
-    print("\nBy Episode:")
-    for ep, scores in results.items():
-        print(f"{ep}")
-        for ref, value in scores.items():
-            print(f"\t{ref}\t{value}")
-
-    print("\nBy Reference:")
-    for ref in list(results.values())[0].keys():
-        print(f"{ref}")
-        for episode, scores in results.items():
-            print(f"\t{episode}\t{scores[ref]}")
-
     print("Done.")
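
Usage note (not part of the patch): a minimal sketch of how the new EpisodesMatchResult would be consumed. The import path, reference file paths, and loading via cv2.imread are illustrative assumptions; the actual reference images are assembled in the module's __main__ block, which the diff elides.

    import cv2
    from label_episodes import match_episodes_to_references, normalize_image

    # Hypothetical inputs: episode files plus one reference frame per title.
    # References are resized with the module's normalize_image so their shape
    # matches the normalized episode frames used during scoring.
    episodes = ["./episodes/B1_t05.mkv", "./episodes/B1_t06.mkv"]
    references = {
        "S01E01": [normalize_image(cv2.imread("./references/s01e01.png"))],
        "S01E02": [normalize_image(cv2.imread("./references/s01e02.png"))],
    }

    result = match_episodes_to_references(episodes, references)

    if result.perfect_match:
        # Unambiguous one-to-one assignment: episode path -> reference name.
        for episode, reference in result.reference_by_episode.items():
            print(f"{episode} -> {reference}")
    else:
        # Ambiguous scores: fall back to the raw per-reference values.
        print(result.scores)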