Implemented coarse and granular matching

This commit is contained in:
Maximilian Giller 2025-09-06 13:09:59 +02:00
parent f8ee93a9c2
commit 337592f5d8

View file

@ -141,15 +141,55 @@ def match_episodes_to_references(episodes: list[str], references: dict[str, list
def _match_episode_batch(
frames: list[np.ndarray], references: dict[str, list[np.ndarray]], step_size: int
frames: list[np.ndarray],
references: dict[str, list[np.ndarray]],
step_size: int,
significance_threshold: float = 0.8,
) -> dict[str, float]:
"""Worker function: compare a batch of frames to all references, return local max per reference."""
"""
Worker function: compare a batch of frames to all references, return local max per reference.
Strategy:
- Compare every `step_size` frame.
- If a score >= significance_threshold is found, check the +/- (step_size - 1) neighborhood
to refine the max score.
"""
local_results = {ref: 0.0 for ref in references.keys()}
for frame in frames:
n_frames = len(frames)
significant_coarse_match: int | None = None
for ref, ref_images in references.items():
for ref_img in ref_images:
score = compare_images(frame, ref_img)
local_results[ref] = max(local_results[ref], score)
frame_index = 0
while frame_index < n_frames:
if significant_coarse_match == frame_index:
frame_index += 1 # Skip previously matched frame
continue
# Score current frame
score = compare_images(frames[frame_index], ref_img)
if score > local_results[ref]:
local_results[ref] = score
# Handle granular matching
if significant_coarse_match:
if frame_index >= significant_coarse_match + step_size:
if score < significance_threshold:
# End of granular matching
significant_coarse_match = None
frame_index += step_size
continue
else:
# Continue granular matching, but don't recheck previous section
significant_coarse_match = frame_index
frame_index += 1 # Continue granular matching
elif score >= significance_threshold:
frame_index -= step_size + 1 # Check previous frames
significant_coarse_match = frame_index # Mark significant match
else:
frame_index += step_size
return local_results