Refactored and implemented match result object with simplified results
This commit is contained in:
parent
bb5092f700
commit
426b36c985
1 changed files with 114 additions and 110 deletions
|
@ -1,103 +1,18 @@
|
|||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
import cv2
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
def normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
return cv2.resize(image, (160 * 2, 90 * 2))
|
||||
|
||||
|
||||
def compare_images(image: np.ndarray, target_image: np.ndarray) -> float:
|
||||
err = np.mean((image.astype(float) - target_image.astype(float)) ** 2)
|
||||
max_val = 255.0 if image.dtype == np.uint8 else 1.0
|
||||
return 1 - (err / (max_val**2))
|
||||
|
||||
def compare_episode_to_image(path_to_episode: str, image: np.ndarray) -> float:
|
||||
video = cv2.VideoCapture(path_to_episode)
|
||||
frame_number = 0
|
||||
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
scores = []
|
||||
while frame_number < frame_count:
|
||||
video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
frame = normalize_image(frame)
|
||||
|
||||
scores.append(compare_images(frame, image))
|
||||
|
||||
frame_number += 24
|
||||
|
||||
return np.max(scores) # np.percentile(scores, 90)
|
||||
|
||||
def plot_value(values):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Your data array
|
||||
y_values = values
|
||||
|
||||
# x-axis is just the index
|
||||
x_values = list(range(len(y_values)))
|
||||
|
||||
# Create the plot
|
||||
plt.figure(figsize=(60, 20))
|
||||
plt.plot(x_values, y_values, marker='o', linestyle='-', color='blue', label='Series A')
|
||||
plt.title('Simple Plot')
|
||||
plt.xlabel('Index')
|
||||
plt.ylabel('Value')
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
# Show the plot
|
||||
plt.savefig("plot.png")
|
||||
|
||||
def compare_episode_to_images(path_to_episode: str, images: list[np.ndarray]) -> float:
|
||||
# return sum(compare_episode_to_image(path_to_episode, image) for image in images) / len(images)
|
||||
|
||||
video = cv2.VideoCapture(path_to_episode)
|
||||
scores = []
|
||||
frame_number = 0
|
||||
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
while True:
|
||||
# while frame_number < frame_count:
|
||||
# video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
frame = normalize_image(frame)
|
||||
|
||||
for ref_index in range(len(images)):
|
||||
score = compare_images(frame, images[ref_index])
|
||||
scores.append(score)
|
||||
|
||||
frame_number += 24
|
||||
|
||||
|
||||
# print(scores)
|
||||
|
||||
video.set(cv2.CAP_PROP_POS_FRAMES, scores.index(max(scores)))
|
||||
success, frame = video.read()
|
||||
if success:
|
||||
cv2.imwrite("match.png", normalize_image(frame))
|
||||
cv2.imwrite("ref.png", normalize_image(images[0]))
|
||||
|
||||
video.release()
|
||||
|
||||
plot_value(scores)
|
||||
return np.max(scores) # np.percentile(scores, 90)
|
||||
|
||||
def match_episodes_to_references(episodes: list[str], references: dict[str, list[np.ndarray]]) -> dict[str, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
# Compare to episodes
|
||||
for episode in episodes:
|
||||
results[episode] = efficient_episode_matching(episode, references)
|
||||
|
||||
print(results[episode])
|
||||
|
||||
return results
|
||||
return float(1 - (err / (max_val**2)))
|
||||
|
||||
|
||||
def _match_episode_batch(
|
||||
|
@ -156,8 +71,10 @@ def _match_episode_batch(
|
|||
def efficient_episode_matching(
|
||||
path_to_episode: str,
|
||||
references: dict[str, list[np.ndarray]],
|
||||
*,
|
||||
batch_size: int = 1440,
|
||||
step_size: int = 24,
|
||||
significance_threshold: float = 0.8,
|
||||
) -> dict[str, float]:
|
||||
"""_summary_
|
||||
|
||||
|
@ -166,6 +83,7 @@ def efficient_episode_matching(
|
|||
references (dict[str, list[np.ndarray]]): _description_
|
||||
batch_size (int, optional): Batch size in frame count for multicore frame matching. At least step_size+1. Defaults to 1440, 24fps for 1 minute.
|
||||
step_size (int, optional): Number of frames to skip on intial match-step. Defaults to 24.
|
||||
significance_threshold (float, optional): Matching threshold for coarse matching, above granular matching is performed. Defaults to 0.8.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: _description_
|
||||
|
@ -175,7 +93,7 @@ def efficient_episode_matching(
|
|||
|
||||
# Load and prepare episode
|
||||
video = cv2.VideoCapture(path_to_episode)
|
||||
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
logging.info(f"Loading and preprocessing [{frame_count}] frames ...")
|
||||
|
||||
normalized_frames = []
|
||||
|
@ -200,7 +118,13 @@ def efficient_episode_matching(
|
|||
results: dict[str, float] = {ref: 0.0 for ref in references.keys()}
|
||||
with ProcessPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(_match_episode_batch, batch, references, step_size)
|
||||
executor.submit(
|
||||
_match_episode_batch,
|
||||
batch,
|
||||
references,
|
||||
step_size,
|
||||
significance_threshold,
|
||||
)
|
||||
for batch in batches
|
||||
]
|
||||
for i, future in enumerate(as_completed(futures)):
|
||||
|
@ -215,20 +139,113 @@ def efficient_episode_matching(
|
|||
logging.info(f"Finished [{path_to_episode}]")
|
||||
results_msg = "Results"
|
||||
for ref, score in results.items():
|
||||
results_msg += f"\n{ref}\t{score:.4f}"
|
||||
results_msg += f"\n{ref}\t{score:.4f}{'\tWinner' if score == max(results.values()) else ''}"
|
||||
logging.info(results_msg)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodesMatchResult:
|
||||
scores: dict[str, dict[str, float]] = field(default_factory=dict)
|
||||
"""Raw scores from matching. For each episode, the score to each reference."""
|
||||
|
||||
perfect_match: bool = False
|
||||
"""If true, scores were unambiguous amd simplified to reference_by_episode and episode_by_reference."""
|
||||
|
||||
reference_by_episode: dict[str, str] | None = None
|
||||
"""Simplified form of results, if unambiguous (see perfect_match). Episode as key, reference as value."""
|
||||
|
||||
@property
|
||||
def episode_by_reference(self) -> dict[str, str] | None:
|
||||
"""Simplified form of results, if unambiguous (see perfect_match). Refernce as key, episode as value."""
|
||||
if self.reference_by_episode is None:
|
||||
return None
|
||||
return {ref: episode for episode, ref in self.reference_by_episode}
|
||||
|
||||
def simplify_match(self) -> None:
|
||||
"""Simplifies score if perfect match given."""
|
||||
ref_by_ep = {}
|
||||
|
||||
# Match for episodes
|
||||
for episode, ref_scores in self.scores.items():
|
||||
winner_score = max(ref_scores.values())
|
||||
winner_ref = [
|
||||
ref for ref, score in ref_scores.items() if score == winner_score
|
||||
][0]
|
||||
ref_by_ep[episode] = winner_ref
|
||||
|
||||
# Validate winners with matches for references
|
||||
references = list(list(self.scores.values())[0].keys())
|
||||
for ref in references:
|
||||
episode_scores = {
|
||||
episode: ref_scores[ref] for episode, ref_scores in self.scores.items()
|
||||
}
|
||||
winner_score = max(episode_scores.values())
|
||||
winner_episode = [
|
||||
epi for epi, score in episode_scores.items() if score == winner_score
|
||||
][0]
|
||||
|
||||
# Make sure that no other episode as this ref as winner
|
||||
# Winner episode might have a higher scoring ref, in case there are more refs than episodes
|
||||
# => So ref might not be assigned at all
|
||||
for episode, ep_ref in ref_by_ep.items():
|
||||
if episode == winner_episode:
|
||||
continue
|
||||
elif ep_ref == ref:
|
||||
return # Not a perfect match
|
||||
|
||||
self.perfect_match = True
|
||||
self.reference_by_episode = ref_by_ep
|
||||
|
||||
def __str__(self) -> str:
|
||||
msg = f"Perfect Match: {self.perfect_match}"
|
||||
|
||||
if self.perfect_match and self.reference_by_episode:
|
||||
for ep, ref in self.reference_by_episode.items():
|
||||
msg += f"\n{ep}\t{ref}"
|
||||
else:
|
||||
msg += "\nBy Episode:"
|
||||
for ep, scores in self.scores.items():
|
||||
msg += f"{ep}"
|
||||
for ref, value in scores.items():
|
||||
msg += f"\t{ref}\t{value}"
|
||||
|
||||
msg += "\nBy Reference:"
|
||||
for ref in list(self.scores.values())[0].keys():
|
||||
msg += f"{ref}"
|
||||
for episode, scores in self.scores.items():
|
||||
msg += f"\t{episode}\t{scores[ref]}"
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def match_episodes_to_references(
|
||||
episodes: list[str], references: dict[str, list[np.ndarray]]
|
||||
) -> EpisodesMatchResult:
|
||||
results = EpisodesMatchResult()
|
||||
|
||||
# Compare refs to episode
|
||||
for episode in episodes:
|
||||
results.scores[episode] = efficient_episode_matching(
|
||||
episode, references, significance_threshold=0.9
|
||||
)
|
||||
|
||||
results.simplify_match()
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
episodes = [
|
||||
"./episodes/B1_t05.mkv",
|
||||
# "./episodes/B1_t06.mkv",
|
||||
# "./episodes/B1_t07.mkv",
|
||||
# "./episodes/B1_t08.mkv",
|
||||
# "./episodes/B1_t09.mkv",
|
||||
# "./episodes/B1_t10.mkv",
|
||||
"./episodes/B1_t06.mkv",
|
||||
"./episodes/B1_t07.mkv",
|
||||
"./episodes/B1_t08.mkv",
|
||||
"./episodes/B1_t09.mkv",
|
||||
"./episodes/B1_t10.mkv",
|
||||
]
|
||||
|
||||
references = {
|
||||
|
@ -266,17 +283,4 @@ if __name__ == "__main__":
|
|||
results = match_episodes_to_references(episodes, references)
|
||||
|
||||
print(results)
|
||||
|
||||
print("\nBy Episode:")
|
||||
for ep, scores in results.items():
|
||||
print(f"{ep}")
|
||||
for ref, value in scores.items():
|
||||
print(f"\t{ref}\t{value}")
|
||||
|
||||
print("\nBy Reference:")
|
||||
for ref in list(results.values())[0].keys():
|
||||
print(f"{ref}")
|
||||
for episode, scores in results.items():
|
||||
print(f"\t{episode}\t{scores[ref]}")
|
||||
|
||||
print("Done.")
|
||||
|
|
Loading…
Reference in a new issue