[FFmpeg-devel] [PATCH 3/4] avfilter/vf_dnn_processing: add format GRAY8 and GRAYF32 support

Guo, Yejun yejun.guo at intel.com
Fri Nov 22 09:50:11 EET 2019


Signed-off-by: Guo, Yejun <yejun.guo at intel.com>
---
 doc/filters.texi                |   8 ++-
 libavfilter/vf_dnn_processing.c | 147 ++++++++++++++++++++++++++++++----------
 2 files changed, 118 insertions(+), 37 deletions(-)

diff --git a/doc/filters.texi b/doc/filters.texi
index 1f86ae1..c3f7997 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -8992,7 +8992,13 @@ Set the input name of the dnn network.
 Set the output name of the dnn network.
 
 @item fmt
-Set the pixel format for the Frame. Allowed values are @code{AV_PIX_FMT_RGB24}, and @code{AV_PIX_FMT_BGR24}.
+Set the pixel format for the Frame; it must match the input expected by the dnn network model.
+
+If the model handles RGB (or BGR) images and its input data type is uint8, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}).
+If the model handles RGB (or BGR) images and its input data type is float, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}), and this filter converts the data type internally (8-bit pixel values are scaled to the range [0, 1] on input and back to [0, 255] on output).
+If the model handles grayscale images and its input data type is uint8, fmt must be @code{AV_PIX_FMT_GRAY8}.
+If the model handles grayscale images and its input data type is float, fmt must be @code{AV_PIX_FMT_GRAYF32}.
+
 Default value is @code{AV_PIX_FMT_RGB24}.
 
 @end table
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index ce976ec..963dd5e 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -70,10 +70,12 @@ static av_cold int init(AVFilterContext *context)
 {
     DnnProcessingContext *ctx = context->priv;
     int supported = 0;
-    // as the first step, only rgb24 and bgr24 are supported
+    // to support more formats
     const enum AVPixelFormat supported_pixel_fmts[] = {
         AV_PIX_FMT_RGB24,
         AV_PIX_FMT_BGR24,
+        AV_PIX_FMT_GRAY8,
+        AV_PIX_FMT_GRAYF32,
     };
     for (int i = 0; i < sizeof(supported_pixel_fmts) / sizeof(enum AVPixelFormat); ++i) {
         if (supported_pixel_fmts[i] == ctx->fmt) {
@@ -156,14 +158,38 @@ static int config_input(AVFilterLink *inlink)
         return AVERROR(EIO);
     }
 
-    if (model_input.channels != 3) {
-        av_log(ctx, AV_LOG_ERROR, "the model requires input channels %d\n",
-                                   model_input.channels);
-        return AVERROR(EIO);
-    }
-    if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) {
-        av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
-        return AVERROR(EIO);
+    if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) {
+        if (model_input.channels != 3) {
+            av_log(ctx, AV_LOG_ERROR, "channel number 3 is required, but the actual channel number is %d\n",
+                                       model_input.channels);
+            return AVERROR(EIO);
+        }
+        if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) {
+            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
+            return AVERROR(EIO);
+        }
+    } else if (ctx->fmt == AV_PIX_FMT_GRAY8) {
+        if (model_input.channels != 1) {
+            av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but the actual channel number is %d\n",
+                                       model_input.channels);
+            return AVERROR(EIO);
+        }
+        if (model_input.dt != DNN_UINT8) {
+            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as uint8.\n");
+            return AVERROR(EIO);
+        }
+    } else if (ctx->fmt == AV_PIX_FMT_GRAYF32) {
+        if (model_input.channels != 1) {
+            av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but the actual channel number is %d\n",
+                                       model_input.channels);
+            return AVERROR(EIO);
+        }
+        if (model_input.dt != DNN_FLOAT) {
+            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float.\n");
+            return AVERROR(EIO);
+        }
+    } else {
+        av_assert0(!"should not reach here.");
     }
 
     ctx->input.width    = inlink->w;
@@ -203,28 +229,49 @@ static int config_output(AVFilterLink *outlink)
 
 static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
 {
-    // extend this function to support more formats
-    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
-    if (dnn_input->dt == DNN_FLOAT) {
-        float *dnn_input_data = dnn_input->data;
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                dnn_input_data[t] = frame->data[0][k] / 255.0f;
+    if (frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24) {
+        if (dnn_input->dt == DNN_FLOAT) {
+            float *dnn_input_data = dnn_input->data;
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    dnn_input_data[t] = frame->data[0][k] / 255.0f;
+                }
+            }
+        } else {
+            uint8_t *dnn_input_data = dnn_input->data;
+            av_assert0(dnn_input->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    dnn_input_data[t] = frame->data[0][k];
+                }
             }
         }
-    } else {
+    } else if (frame->format == AV_PIX_FMT_GRAY8) {
         uint8_t *dnn_input_data = dnn_input->data;
         av_assert0(dnn_input->dt == DNN_UINT8);
         for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
+            for(int j = 0; j < frame->width; j++) {
                 int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
+                int t = i * frame->width + j;
                 dnn_input_data[t] = frame->data[0][k];
             }
         }
+    } else if (frame->format == AV_PIX_FMT_GRAYF32) {
+        float *dnn_input_data = dnn_input->data;
+        av_assert0(dnn_input->dt == DNN_FLOAT);
+        for (int i = 0; i < frame->height; i++) {
+            for(int j = 0; j < frame->width; j++) {
+                int k = i * frame->linesize[0] + j * sizeof(float);
+                int t = i * frame->width + j;
+                dnn_input_data[t] = *(float*)(frame->data[0] + k);
+            }
+        }
+    } else {
+        av_assert0(!"should not reach here.");
     }
 
     return 0;
@@ -232,28 +279,49 @@ static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
 
 static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
 {
-    // extend this function to support more formats
-    av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
-    if (dnn_output->dt == DNN_FLOAT) {
-        float *dnn_output_data = dnn_output->data;
-        for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
-                int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
-                frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+    if (frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24) {
+        if (dnn_output->dt == DNN_FLOAT) {
+            float *dnn_output_data = dnn_output->data;
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+                }
+            }
+        } else {
+            uint8_t *dnn_output_data = dnn_output->data;
+            av_assert0(dnn_output->dt == DNN_UINT8);
+            for (int i = 0; i < frame->height; i++) {
+                for(int j = 0; j < frame->width * 3; j++) {
+                    int k = i * frame->linesize[0] + j;
+                    int t = i * frame->width * 3 + j;
+                    frame->data[0][k] = dnn_output_data[t];
+                }
             }
         }
-    } else {
+    } else if (frame->format == AV_PIX_FMT_GRAY8) {
         uint8_t *dnn_output_data = dnn_output->data;
         av_assert0(dnn_output->dt == DNN_UINT8);
         for (int i = 0; i < frame->height; i++) {
-            for(int j = 0; j < frame->width * 3; j++) {
+            for(int j = 0; j < frame->width; j++) {
                 int k = i * frame->linesize[0] + j;
-                int t = i * frame->width * 3 + j;
+                int t = i * frame->width + j;
                 frame->data[0][k] = dnn_output_data[t];
             }
         }
+    } else if (frame->format == AV_PIX_FMT_GRAYF32) {
+        float *dnn_output_data = dnn_output->data;
+        av_assert0(dnn_output->dt == DNN_FLOAT);
+        for (int i = 0; i < frame->height; i++) {
+            for(int j = 0; j < frame->width; j++) {
+                int k = i * frame->linesize[0] + j * sizeof(float);
+                int t = i * frame->width + j;
+                *(float*)(frame->data[0] + k) = dnn_output_data[t];
+            }
+        }
+    } else {
+        av_assert0(!"should not reach here.");
     }
 
     return 0;
@@ -275,7 +343,14 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         av_frame_free(&in);
         return AVERROR(EIO);
     }
-    av_assert0(ctx->output.channels == 3);
+
+    if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) {
+        av_assert0(ctx->output.channels == 3);
+    } else if (ctx->fmt == AV_PIX_FMT_GRAY8 || ctx->fmt == AV_PIX_FMT_GRAYF32) {
+        av_assert0(ctx->output.channels == 1);
+    } else {
+        av_assert0(!"should not reach here");
+    }
 
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
-- 
2.7.4



More information about the ffmpeg-devel mailing list