[FFmpeg-cvslog] dnn: put DNNModel.set_input and DNNModule.execute_model together

Guo, Yejun git at videolan.org
Mon Sep 21 16:44:34 EEST 2020


ffmpeg | branch: master | Guo, Yejun <yejun.guo at intel.com> | Thu Sep 10 22:29:57 2020 +0800| [fce3e3e137843d86411f8868f18e1c3f472de0e5] | committer: Guo, Yejun

dnn: put DNNModel.set_input and DNNModule.execute_model together

Suppose we have detect and classify filters in the future: the
detect filter generates some bounding boxes (BBox) as AVFrame side data,
and the classify filter executes the DNN model for each BBox. For each
BBox, we need to crop the AVFrame, copy the data to the DNN model input,
and do the model execution. So we would have to save the in_frame at
DNNModel.set_input and use it at DNNModule.execute_model; such saving
is not feasible once we support async execute_model.

This patch passes the in_frame as a parameter of execute_model, so
all the information for each inference is put together within the
same function. It also makes it easy to support async BBox inference.

> http://git.videolan.org/gitweb.cgi/ffmpeg.git/?a=commit;h=fce3e3e137843d86411f8868f18e1c3f472de0e5
---

 libavfilter/dnn/dnn_backend_native.c   | 119 ++++++++++++++-------------------
 libavfilter/dnn/dnn_backend_native.h   |   3 +-
 libavfilter/dnn/dnn_backend_openvino.c |  85 ++++++++++-------------
 libavfilter/dnn/dnn_backend_openvino.h |   3 +-
 libavfilter/dnn/dnn_backend_tf.c       | 118 ++++++++++++++------------------
 libavfilter/dnn/dnn_backend_tf.h       |   3 +-
 libavfilter/dnn_interface.h            |   8 +--
 libavfilter/vf_derain.c                |   9 +--
 libavfilter/vf_dnn_processing.c        |  18 ++---
 libavfilter/vf_sr.c                    |  25 ++-----
 10 files changed, 156 insertions(+), 235 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index 14e878b6b8..dc47c9b542 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -70,64 +70,6 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name)
-{
-    NativeModel *native_model = (NativeModel *)model;
-    NativeContext *ctx = &native_model->ctx;
-    DnnOperand *oprd = NULL;
-    DNNData input;
-
-    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
-        return DNN_ERROR;
-    }
-
-    /* inputs */
-    for (int i = 0; i < native_model->operands_num; ++i) {
-        oprd = &native_model->operands[i];
-        if (strcmp(oprd->name, input_name) == 0) {
-            if (oprd->type != DOT_INPUT) {
-                av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
-                return DNN_ERROR;
-            }
-            break;
-        }
-        oprd = NULL;
-    }
-    if (!oprd) {
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
-        return DNN_ERROR;
-    }
-
-    oprd->dims[1] = frame->height;
-    oprd->dims[2] = frame->width;
-
-    av_freep(&oprd->data);
-    oprd->length = calculate_operand_data_length(oprd);
-    if (oprd->length <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
-        return DNN_ERROR;
-    }
-    oprd->data = av_malloc(oprd->length);
-    if (!oprd->data) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
-        return DNN_ERROR;
-    }
-
-    input.height = oprd->dims[1];
-    input.width = oprd->dims[2];
-    input.channels = oprd->dims[3];
-    input.data = oprd->data;
-    input.dt = oprd->data_type;
-    if (native_model->model->pre_proc != NULL) {
-        native_model->model->pre_proc(frame, &input, native_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    return DNN_SUCCESS;
-}
-
 // Loads model and its parameters that are stored in a binary file with following structure:
 // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
 // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
@@ -273,7 +215,6 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
         return NULL;
     }
 
-    model->set_input = &set_input_native;
     model->get_input = &get_input_native;
     model->userdata = userdata;
 
@@ -285,26 +226,66 @@ fail:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                          const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     NativeModel *native_model = (NativeModel *)model->model;
     NativeContext *ctx = &native_model->ctx;
     int32_t layer;
-    DNNData output;
+    DNNData input, output;
+    DnnOperand *oprd = NULL;
 
-    if (nb_output != 1) {
-        // currently, the filter does not need multiple outputs,
-        // so we just pending the support until we really need it.
-        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
+        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
         return DNN_ERROR;
     }
 
-    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
+    for (int i = 0; i < native_model->operands_num; ++i) {
+        oprd = &native_model->operands[i];
+        if (strcmp(oprd->name, input_name) == 0) {
+            if (oprd->type != DOT_INPUT) {
+                av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
+                return DNN_ERROR;
+            }
+            break;
+        }
+        oprd = NULL;
+    }
+    if (!oprd) {
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        return DNN_ERROR;
+    }
+
+    oprd->dims[1] = in_frame->height;
+    oprd->dims[2] = in_frame->width;
+
+    av_freep(&oprd->data);
+    oprd->length = calculate_operand_data_length(oprd);
+    if (oprd->length <= 0) {
+        av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
         return DNN_ERROR;
     }
-    if (!native_model->operands[0].data) {
-        av_log(ctx, AV_LOG_ERROR, "Empty model input data\n");
+    oprd->data = av_malloc(oprd->length);
+    if (!oprd->data) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
+        return DNN_ERROR;
+    }
+
+    input.height = oprd->dims[1];
+    input.width = oprd->dims[2];
+    input.channels = oprd->dims[3];
+    input.data = oprd->data;
+    input.dt = oprd->data_type;
+    if (native_model->model->pre_proc != NULL) {
+        native_model->model->pre_proc(in_frame, &input, native_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just pending the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
         return DNN_ERROR;
     }
 
diff --git a/libavfilter/dnn/dnn_backend_native.h b/libavfilter/dnn/dnn_backend_native.h
index 553438bd22..2f8d73fcf6 100644
--- a/libavfilter/dnn/dnn_backend_native.h
+++ b/libavfilter/dnn/dnn_backend_native.h
@@ -128,7 +128,8 @@ typedef struct NativeModel{
 
 DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                          const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_native(DNNModel **model);
 
diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index b1bad3f659..0dba1c1adc 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -48,7 +48,6 @@ typedef struct OVModel{
     ie_network_t *network;
     ie_executable_network_t *exe_network;
     ie_infer_request_t *infer_request;
-    ie_blob_t *input_blob;
 } OVModel;
 
 #define APPEND_STRING(generated_string, iterate_string)                                            \
@@ -133,49 +132,6 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name)
-{
-    OVModel *ov_model = (OVModel *)model;
-    OVContext *ctx = &ov_model->ctx;
-    IEStatusCode status;
-    dimensions_t dims;
-    precision_e precision;
-    ie_blob_buffer_t blob_buffer;
-    DNNData input;
-
-    status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
-    if (status != OK)
-        goto err;
-
-    status |= ie_blob_get_dims(ov_model->input_blob, &dims);
-    status |= ie_blob_get_precision(ov_model->input_blob, &precision);
-    if (status != OK)
-        goto err;
-
-    status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
-    if (status != OK)
-        goto err;
-
-    input.height = dims.dims[2];
-    input.width = dims.dims[3];
-    input.channels = dims.dims[1];
-    input.data = blob_buffer.buffer;
-    input.dt = precision_to_datatype(precision);
-    if (ov_model->model->pre_proc != NULL) {
-        ov_model->model->pre_proc(frame, &input, ov_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    return DNN_SUCCESS;
-
-err:
-    if (ov_model->input_blob)
-        ie_blob_free(&ov_model->input_blob);
-    av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
-    return DNN_ERROR;
-}
-
 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata)
 {
     char *all_dev_names = NULL;
@@ -234,7 +190,6 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
         goto err;
 
     model->model = (void *)ov_model;
-    model->set_input = &set_input_ov;
     model->get_input = &get_input_ov;
     model->options = options;
     model->userdata = userdata;
@@ -258,7 +213,8 @@ err:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     char *model_output_name = NULL;
     char *all_output_names = NULL;
@@ -269,7 +225,39 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
     OVContext *ctx = &ov_model->ctx;
     IEStatusCode status;
     size_t model_output_count = 0;
-    DNNData output;
+    DNNData input, output;
+    ie_blob_t *input_blob = NULL;
+
+    status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &input_blob);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob\n");
+        return DNN_ERROR;
+    }
+
+    status |= ie_blob_get_dims(input_blob, &dims);
+    status |= ie_blob_get_precision(input_blob, &precision);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob dims/precision\n");
+        return DNN_ERROR;
+    }
+
+    status = ie_blob_get_buffer(input_blob, &blob_buffer);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n");
+        return DNN_ERROR;
+    }
+
+    input.height = dims.dims[2];
+    input.width = dims.dims[3];
+    input.channels = dims.dims[1];
+    input.data = blob_buffer.buffer;
+    input.dt = precision_to_datatype(precision);
+    if (ov_model->model->pre_proc != NULL) {
+        ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
+    ie_blob_free(&input_blob);
 
     if (nb_output != 1) {
         // currently, the filter does not need multiple outputs,
@@ -330,6 +318,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
                 proc_from_dnn_to_frame(out_frame, &output, ctx);
             }
         }
+        ie_blob_free(&output_blob);
     }
 
     return DNN_SUCCESS;
@@ -339,8 +328,6 @@ void ff_dnn_free_model_ov(DNNModel **model)
 {
     if (*model){
         OVModel *ov_model = (OVModel *)(*model)->model;
-        if (ov_model->input_blob)
-            ie_blob_free(&ov_model->input_blob);
         if (ov_model->infer_request)
             ie_infer_request_free(&ov_model->infer_request);
         if (ov_model->exe_network)
diff --git a/libavfilter/dnn/dnn_backend_openvino.h b/libavfilter/dnn/dnn_backend_openvino.h
index efb349cb49..3f8f01da60 100644
--- a/libavfilter/dnn/dnn_backend_openvino.h
+++ b/libavfilter/dnn/dnn_backend_openvino.h
@@ -31,7 +31,8 @@
 
 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_ov(DNNModel **model);
 
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index c2d8c06931..8467f8a459 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -45,8 +45,6 @@ typedef struct TFModel{
     TF_Graph *graph;
     TF_Session *session;
     TF_Status *status;
-    TF_Output input;
-    TF_Tensor *input_tensor;
 } TFModel;
 
 static const AVClass dnn_tensorflow_class = {
@@ -152,48 +150,33 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     return DNN_SUCCESS;
 }
 
-static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name)
+static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
 {
-    TFModel *tf_model = (TFModel *)model;
     TFContext *ctx = &tf_model->ctx;
-    DNNData input;
+    TF_Buffer *graph_def;
+    TF_ImportGraphDefOptions *graph_opts;
     TF_SessionOptions *sess_opts;
-    const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
-
-    if (get_input_tf(model, &input, input_name) != DNN_SUCCESS)
-        return DNN_ERROR;
-    input.height = frame->height;
-    input.width = frame->width;
+    const TF_Operation *init_op;
 
-    // Input operation
-    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
-    if (!tf_model->input.oper){
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+    graph_def = read_graph(model_filename);
+    if (!graph_def){
+        av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
         return DNN_ERROR;
     }
-    tf_model->input.index = 0;
-    if (tf_model->input_tensor){
-        TF_DeleteTensor(tf_model->input_tensor);
-    }
-    tf_model->input_tensor = allocate_input_tensor(&input);
-    if (!tf_model->input_tensor){
-        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
+    tf_model->graph = TF_NewGraph();
+    tf_model->status = TF_NewStatus();
+    graph_opts = TF_NewImportGraphDefOptions();
+    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
+    TF_DeleteImportGraphDefOptions(graph_opts);
+    TF_DeleteBuffer(graph_def);
+    if (TF_GetCode(tf_model->status) != TF_OK){
+        TF_DeleteGraph(tf_model->graph);
+        TF_DeleteStatus(tf_model->status);
+        av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
         return DNN_ERROR;
     }
-    input.data = (float *)TF_TensorData(tf_model->input_tensor);
-
-    if (tf_model->model->pre_proc != NULL) {
-        tf_model->model->pre_proc(frame, &input, tf_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    // session
-    if (tf_model->session){
-        TF_CloseSession(tf_model->session, tf_model->status);
-        TF_DeleteSession(tf_model->session, tf_model->status);
-    }
 
+    init_op = TF_GraphOperationByName(tf_model->graph, "init");
     sess_opts = TF_NewSessionOptions();
     tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
     TF_DeleteSessionOptions(sess_opts);
@@ -219,33 +202,6 @@ static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input
     return DNN_SUCCESS;
 }
 
-static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
-{
-    TFContext *ctx = &tf_model->ctx;
-    TF_Buffer *graph_def;
-    TF_ImportGraphDefOptions *graph_opts;
-
-    graph_def = read_graph(model_filename);
-    if (!graph_def){
-        av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
-        return DNN_ERROR;
-    }
-    tf_model->graph = TF_NewGraph();
-    tf_model->status = TF_NewStatus();
-    graph_opts = TF_NewImportGraphDefOptions();
-    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
-    TF_DeleteImportGraphDefOptions(graph_opts);
-    TF_DeleteBuffer(graph_def);
-    if (TF_GetCode(tf_model->status) != TF_OK){
-        TF_DeleteGraph(tf_model->graph);
-        TF_DeleteStatus(tf_model->status);
-        av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
-        return DNN_ERROR;
-    }
-
-    return DNN_SUCCESS;
-}
-
 #define NAME_BUFFER_SIZE 256
 
 static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op,
@@ -626,7 +582,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     }
 
     model->model = (void *)tf_model;
-    model->set_input = &set_input_tf;
     model->get_input = &get_input_tf;
     model->options = options;
     model->userdata = userdata;
@@ -634,13 +589,40 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     return model;
 }
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     TF_Output *tf_outputs;
     TFModel *tf_model = (TFModel *)model->model;
     TFContext *ctx = &tf_model->ctx;
-    DNNData output;
+    DNNData input, output;
     TF_Tensor **output_tensors;
+    TF_Output tf_input;
+    TF_Tensor *input_tensor;
+
+    if (get_input_tf(tf_model, &input, input_name) != DNN_SUCCESS)
+        return DNN_ERROR;
+    input.height = in_frame->height;
+    input.width = in_frame->width;
+
+    tf_input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+    if (!tf_input.oper){
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        return DNN_ERROR;
+    }
+    tf_input.index = 0;
+    input_tensor = allocate_input_tensor(&input);
+    if (!input_tensor){
+        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
+        return DNN_ERROR;
+    }
+    input.data = (float *)TF_TensorData(input_tensor);
+
+    if (tf_model->model->pre_proc != NULL) {
+        tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
 
     if (nb_output != 1) {
         // currently, the filter does not need multiple outputs,
@@ -674,7 +656,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
     }
 
     TF_SessionRun(tf_model->session, NULL,
-                  &tf_model->input, &tf_model->input_tensor, 1,
+                  &tf_input, &input_tensor, 1,
                   tf_outputs, output_tensors, nb_output,
                   NULL, 0, NULL, tf_model->status);
     if (TF_GetCode(tf_model->status) != TF_OK) {
@@ -708,6 +690,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
             TF_DeleteTensor(output_tensors[i]);
         }
     }
+    TF_DeleteTensor(input_tensor);
     av_freep(&output_tensors);
     av_freep(&tf_outputs);
     return DNN_SUCCESS;
@@ -729,9 +712,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
         if (tf_model->status){
             TF_DeleteStatus(tf_model->status);
         }
-        if (tf_model->input_tensor){
-            TF_DeleteTensor(tf_model->input_tensor);
-        }
         av_freep(&tf_model);
         av_freep(model);
     }
diff --git a/libavfilter/dnn/dnn_backend_tf.h b/libavfilter/dnn/dnn_backend_tf.h
index f379e83d8d..1e00669736 100644
--- a/libavfilter/dnn/dnn_backend_tf.h
+++ b/libavfilter/dnn/dnn_backend_tf.h
@@ -31,7 +31,8 @@
 
 DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_tf(DNNModel **model);
 
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 6debc50607..0369ee4f71 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -51,9 +51,6 @@ typedef struct DNNModel{
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
-    // Sets model input.
-    // Should be called every time before model execution.
-    DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name);
     // set the pre process to transfer data from AVFrame to DNNData
     // the default implementation within DNN is used if it is not provided by the filter
     int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);
@@ -66,8 +63,9 @@ typedef struct DNNModel{
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
     DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata);
-    // Executes model with specified output. Returns DNN_ERROR otherwise.
-    DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+    // Executes model with specified input and output. Returns DNN_ERROR otherwise.
+    DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                   const char **output_names, uint32_t nb_output, AVFrame *out_frame);
     // Frees memory allocated for model.
     void (*free_model)(DNNModel **model);
 } DNNModule;
diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c
index a59cd6e941..77dd401263 100644
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -80,13 +80,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     const char *model_output_name = "y";
     AVFrame *out;
 
-    dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x");
-    if (dnn_result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        av_frame_free(&in);
-        return AVERROR(EIO);
-    }
-
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n");
@@ -95,7 +88,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     }
     av_frame_copy_props(out, in);
 
-    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out);
+    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, "x", in, &model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         av_frame_free(&in);
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index d7462bc828..2c8578c9b0 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -236,15 +236,11 @@ static int config_output(AVFilterLink *outlink)
     AVFrame *out = NULL;
 
     AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
-    }
 
     // have a try run in case that the dnn model resize the frame
     out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, fake_in,
+                                              (const char **)&ctx->model_outputname, 1, out);
     if (result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         return AVERROR(EIO);
@@ -293,13 +289,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     DNNReturnType dnn_result;
     AVFrame *out;
 
-    dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname);
-    if (dnn_result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        av_frame_free(&in);
-        return AVERROR(EIO);
-    }
-
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_frame_free(&in);
@@ -307,7 +296,8 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     }
     av_frame_copy_props(out, in);
 
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, in,
+                                                  (const char **)&ctx->model_outputname, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         av_frame_free(&in);
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 2eda8c3219..72a3137262 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -114,16 +114,11 @@ static int config_output(AVFilterLink *outlink)
     AVFrame *out = NULL;
     const char *model_output_name = "y";
 
-    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
-    if (result != DNN_SUCCESS) {
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
-    }
-
     // have a try run in case that the dnn model resize the frame
+    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
     out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
+    result = (ctx->dnn_module->execute_model)(ctx->model, "x", fake_in,
+                                              (const char **)&model_output_name, 1, out);
     if (result != DNN_SUCCESS){
         av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
         return AVERROR(EIO);
@@ -178,19 +173,13 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         sws_scale(ctx->sws_pre_scale,
                     (const uint8_t **)in->data, in->linesize, 0, in->height,
                     out->data, out->linesize);
-        dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x");
+        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", out,
+                                                      (const char **)&model_output_name, 1, out);
     } else {
-        dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x");
-    }
-
-    if (dnn_result != DNN_SUCCESS) {
-        av_frame_free(&in);
-        av_frame_free(&out);
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
+        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", in,
+                                                      (const char **)&model_output_name, 1, out);
     }
 
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
         av_frame_free(&in);



More information about the ffmpeg-cvslog mailing list