[FFmpeg-devel] [PATCH 07/10] lavfi/dnn_backend_tf: Separate function for filling RequestItem and callback

Shubhanshu Saxena shubhanshu.e01 at gmail.com
Fri May 28 12:24:51 EEST 2021


This commit rearranges the existing code into two separate functions:
one that fills the request with execution data, and the completion
callback. It also fixes the "Could not find" error message to use
task->input_name.

Signed-off-by: Shubhanshu Saxena <shubhanshu.e01 at gmail.com>
---
 libavfilter/dnn/dnn_backend_tf.c | 81 ++++++++++++++++++++++----------
 1 file changed, 57 insertions(+), 24 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 793b108e55..5d34da5db1 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -826,20 +826,16 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_
     return model;
 }
 
-static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_queue)
-{
-    TFModel *tf_model;
-    TFContext *ctx;
-    tf_infer_request *infer_request;
+static DNNReturnType fill_model_input_tf(TFModel *tf_model, RequestItem *request) {
+    DNNData input;
     InferenceItem *inference;
     TaskItem *task;
-    DNNData input, *outputs;
+    tf_infer_request *infer_request;
+    TFContext *ctx = &tf_model->ctx;
 
-    inference = ff_queue_pop_front(inference_queue);
+    inference = ff_queue_pop_front(tf_model->inference_queue);
     av_assert0(inference);
     task = inference->task;
-    tf_model = task->model;
-    ctx = &tf_model->ctx;
     request->inference = inference;
 
     if (get_input_tf(tf_model, &input, task->input_name) != DNN_SUCCESS)
@@ -852,7 +848,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     infer_request->tf_input = av_malloc(sizeof(TF_Output));
     infer_request->tf_input->oper = TF_GraphOperationByName(tf_model->graph, task->input_name);
     if (!infer_request->tf_input->oper){
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", task->input_name);
         return DNN_ERROR;
     }
     infer_request->tf_input->index = 0;
@@ -902,22 +898,23 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
         infer_request->tf_outputs[i].index = 0;
     }
 
-    TF_SessionRun(tf_model->session, NULL,
-                    infer_request->tf_input, &infer_request->input_tensor, 1,
-                    infer_request->tf_outputs, infer_request->output_tensors,
-                    task->nb_output, NULL, 0, NULL,
-                    tf_model->status);
-    if (TF_GetCode(tf_model->status) != TF_OK) {
-            tf_free_request(infer_request);
-            av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
-            return DNN_ERROR;
-    }
+    return DNN_SUCCESS;
+}
+
+static void infer_completion_callback(void *args) {
+    RequestItem *request = args;
+    InferenceItem *inference = request->inference;
+    TaskItem *task = inference->task;
+    DNNData *outputs;
+    tf_infer_request *infer_request = request->infer_request;
+    TFModel *tf_model = task->model;
+    TFContext *ctx = &tf_model->ctx;
 
     outputs = av_malloc_array(task->nb_output, sizeof(*outputs));
     if (!outputs) {
         tf_free_request(infer_request);
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *outputs\n");
-        return DNN_ERROR;
+        return;
     }
 
     for (uint32_t i = 0; i < task->nb_output; ++i) {
@@ -944,7 +941,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     case DFT_ANALYTICS_DETECT:
         if (!tf_model->model->detect_post_proc) {
             av_log(ctx, AV_LOG_ERROR, "Detect filter needs provide post proc\n");
-            return DNN_ERROR;
+            return;
         }
         tf_model->model->detect_post_proc(task->out_frame, outputs, task->nb_output, tf_model->model->filter_ctx);
         break;
@@ -955,7 +952,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
             }
         }
         av_log(ctx, AV_LOG_ERROR, "Tensorflow backend does not support this kind of dnn filter now\n");
-        return DNN_ERROR;
+        return;
     }
     for (uint32_t i = 0; i < task->nb_output; ++i) {
         if (infer_request->output_tensors[i]) {
@@ -966,7 +963,43 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     tf_free_request(infer_request);
     av_freep(&outputs);
     ff_safe_queue_push_back(tf_model->request_queue, request);
-    return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR;
+}
+
+/* Synchronously run the inference at the head of inference_queue; async mode is not implemented yet. */
+static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_queue)
+{
+    TFModel *tf_model;
+    TFContext *ctx;
+    tf_infer_request *infer_request;
+    TaskItem *task;
+    InferenceItem *inference = ff_queue_peek_front(inference_queue);
+    av_assert0(inference); // check before the dereference below; the item is only peeked here, not popped
+    task = inference->task;
+    tf_model = task->model;
+    ctx = &tf_model->ctx;
+
+    if (task->async) {
+        avpriv_report_missing_feature(ctx, "Async execution not supported");
+        return DNN_ERROR;
+    } else {
+        if (fill_model_input_tf(tf_model, request) != DNN_SUCCESS) { // pops tf_model->inference_queue; assumed to be inference_queue
+            return DNN_ERROR;
+        }
+
+        infer_request = request->infer_request;
+        TF_SessionRun(tf_model->session, NULL,
+                      infer_request->tf_input, &infer_request->input_tensor, 1,
+                      infer_request->tf_outputs, infer_request->output_tensors,
+                      task->nb_output, NULL, 0, NULL,
+                      tf_model->status);
+        if (TF_GetCode(tf_model->status) != TF_OK) {
+            tf_free_request(infer_request);
+            av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
+            return DNN_ERROR;
+        }
+        infer_completion_callback(request); // synchronous mode: run the completion step directly
+        return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR; // success only once all inferences of the task are done
+    }
 }
 
 DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNExecBaseParams *exec_params)
-- 
2.25.1



More information about the ffmpeg-devel mailing list