diff --git a/inference/transnetv2.py b/inference/transnetv2.py
index 92be82c..ee541fa 100644
--- a/inference/transnetv2.py
+++ b/inference/transnetv2.py
@@ -32,7 +32,7 @@ def predict_raw(self, frames: np.ndarray):
 
         return single_frame_pred, all_frames_pred
 
-    def predict_frames(self, frames: np.ndarray):
+    def predict_frames(self, frames: np.ndarray, silent: bool = False):
         assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \
             "[TransNetV2] Input shape must be [frames, height, width, 3]."
 
@@ -61,17 +61,19 @@ def input_iterator():
             predictions.append((single_frame_pred.numpy()[0, 25:75, 0],
                                 all_frames_pred.numpy()[0, 25:75, 0]))
 
-            print("\r[TransNetV2] Processing video frames {}/{}".format(
-                min(len(predictions) * 50, len(frames)), len(frames)
-            ), end="")
-        print("")
+            if not silent:
+                print("\r[TransNetV2] Processing video frames {}/{}".format(
+                    min(len(predictions) * 50, len(frames)), len(frames)
+                ), end="")
+        if not silent:
+            print("")
 
         single_frame_pred = np.concatenate([single_ for single_, all_ in predictions])
         all_frames_pred = np.concatenate([all_ for single_, all_ in predictions])
 
         return single_frame_pred[:len(frames)], all_frames_pred[:len(frames)]  # remove extra padded frames
 
-    def predict_video(self, video_fn: str):
+    def predict_video(self, video_fn: str, silent: bool = False):
         try:
             import ffmpeg
         except ModuleNotFoundError:
@@ -79,13 +81,15 @@ def predict_video(self, video_fn: str):
                                       "individual frames from video file. Install `ffmpeg` command line tool and then "
                                       "install python wrapper by `pip install ffmpeg-python`.")
 
-        print("[TransNetV2] Extracting frames from {}".format(video_fn))
+        if not silent:
+            print("[TransNetV2] Extracting frames from {}".format(video_fn))
+
         video_stream, err = ffmpeg.input(video_fn).output(
             "pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27"
         ).run(capture_stdout=True, capture_stderr=True)
 
         video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3])
-        return (video, *self.predict_frames(video))
+        return (video, *self.predict_frames(video, silent=silent))
 
     @staticmethod
     def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
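
For reference, a minimal usage sketch of the new `silent` flag, not part of the patch itself. It assumes a standard TransNetV2 inference setup: `transnetv2.py` and the model weights available locally, plus `ffmpeg` and `ffmpeg-python` installed; `video.mp4` is a placeholder path.

```python
from transnetv2 import TransNetV2

model = TransNetV2()

# silent=True suppresses the frame-extraction and progress messages guarded above.
video_frames, single_frame_pred, all_frames_pred = model.predict_video("video.mp4", silent=True)

# Existing helper: convert per-frame transition probabilities into [start, end] scene ranges.
scenes = model.predictions_to_scenes(single_frame_pred)
print(scenes)
```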