[FFmpeg-devel] [PATCH 2/2] libavfi/dnn: add LibTorch as one of DNN backend

Jean-Baptiste Kempf jb at videolan.org
Mon May 23 12:51:06 EEST 2022


Hello,

Are we seriously going to add all backends for ML in FFmpeg? Next one is ONNNX?

jb

On Mon, 23 May 2022, at 11:29, Ting Fu wrote:
> PyTorch is an open source machine learning framework that accelerates
> the path from research prototyping to production deployment. Official
> websit: https://pytorch.org/. We call the C++ library of PyTorch as
> LibTorch, the same below.
>
> To build FFmpeg with LibTorch, please take following steps as reference:
> 1. download LibTorch C++ library in 
> https://pytorch.org/get-started/locally/,
> please select C++/Java for language, and other options as your need.
> 2. unzip the file to your own dir, with command
> unzip libtorch-shared-with-deps-latest.zip -d your_dir
> 3. export libtorch_root/libtorch/include and
> libtorch_root/libtorch/include/torch/csrc/api/include to $PATH
> export libtorch_root/libtorch/lib/ to $LD_LIBRARY_PATH
> 4. config FFmpeg with ../configure --enable-libtorch 
> --extra-cflag=-I/libtorch_root/libtorch/include 
> --extra-cflag=-I/libtorch_root/libtorch/include/torch/csrc/api/include 
> --extra-ldflags=-L/libtorch_root/libtorch/lib/
> 5. make
>
> To run FFmpeg DNN inference with LibTorch backend:
> ./ffmpeg -i input.jpg -vf 
> dnn_processing=dnn_backend=torch:model=LibTorch_model.pt -y output.jpg
> The LibTorch_model.pt can be generated by Python with 
> torch.jit.script() api. Please note, torch.jit.trace() is not 
> recommanded, since it does not support ambiguous input size.
>
> Signed-off-by: Ting Fu <ting.fu at intel.com>
> ---
>  configure                             |   7 +-
>  libavfilter/dnn/Makefile              |   1 +
>  libavfilter/dnn/dnn_backend_torch.cpp | 567 ++++++++++++++++++++++++++
>  libavfilter/dnn/dnn_backend_torch.h   |  47 +++
>  libavfilter/dnn/dnn_interface.c       |  12 +
>  libavfilter/dnn/dnn_io_proc.c         | 117 +++++-
>  libavfilter/dnn_filter_common.c       |  31 +-
>  libavfilter/dnn_interface.h           |   3 +-
>  libavfilter/vf_dnn_processing.c       |   3 +
>  9 files changed, 774 insertions(+), 14 deletions(-)
>  create mode 100644 libavfilter/dnn/dnn_backend_torch.cpp
>  create mode 100644 libavfilter/dnn/dnn_backend_torch.h
>
> diff --git a/configure b/configure
> index f115b21064..85ce3e67a3 100755
> --- a/configure
> +++ b/configure
> @@ -279,6 +279,7 @@ External library support:
>    --enable-libtheora       enable Theora encoding via libtheora [no]
>    --enable-libtls          enable LibreSSL (via libtls), needed for 
> https support
>                             if openssl, gnutls or mbedtls is not used 
> [no]
> +  --enable-libtorch        enable Torch as one DNN backend
>    --enable-libtwolame      enable MP2 encoding via libtwolame [no]
>    --enable-libuavs3d       enable AVS3 decoding via libuavs3d [no]
>    --enable-libv4l2         enable libv4l2/v4l-utils [no]
> @@ -1850,6 +1851,7 @@ EXTERNAL_LIBRARY_LIST="
>      libopus
>      libplacebo
>      libpulse
> +    libtorch
>      librabbitmq
>      librav1e
>      librist
> @@ -2719,7 +2721,7 @@ dct_select="rdft"
>  deflate_wrapper_deps="zlib"
>  dirac_parse_select="golomb"
>  dovi_rpu_select="golomb"
> -dnn_suggest="libtensorflow libopenvino"
> +dnn_suggest="libtensorflow libopenvino libtorch"
>  dnn_deps="avformat swscale"
>  error_resilience_select="me_cmp"
>  faandct_deps="faan"
> @@ -6600,6 +6602,7 @@ enabled libopus           && {
>  }
>  enabled libplacebo        && require_pkg_config libplacebo "libplacebo 
> >= 4.192.0" libplacebo/vulkan.h pl_vulkan_create
>  enabled libpulse          && require_pkg_config libpulse libpulse 
> pulse/pulseaudio.h pa_context_new
> +enabled libtorch          && add_cppflags -D_GLIBCXX_USE_CXX11_ABI=0 
> && check_cxxflags -std=c++14 && require_cpp libtorch torch/torch.h 
> "torch::Tensor" -ltorch -lc10 -ltorch_cpu -lstdc++ -lpthread
>  enabled librabbitmq       && require_pkg_config librabbitmq 
> "librabbitmq >= 0.7.1" amqp.h amqp_new_connection
>  enabled librav1e          && require_pkg_config librav1e "rav1e >= 
> 0.4.0" rav1e.h rav1e_context_new
>  enabled librist           && require_pkg_config librist "librist >= 
> 0.2" librist/librist.h rist_receiver_create
> @@ -7025,6 +7028,8 @@ check_disable_warning -Wno-pointer-sign
>  check_disable_warning -Wno-unused-const-variable
>  check_disable_warning -Wno-bool-operation
>  check_disable_warning -Wno-char-subscripts
> +#this option is for supress redundant-decls warning in compile libtorch
> +check_disable_warning -Wno-redundant-decls
> 
>  check_disable_warning_headers(){
>      warning_flag=-W${1#-Wno-}
> diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
> index 4cfbce0efc..d44dcb847e 100644
> --- a/libavfilter/dnn/Makefile
> +++ b/libavfilter/dnn/Makefile
> @@ -16,5 +16,6 @@ OBJS-$(CONFIG_DNN)                           += 
> dnn/dnn_backend_native_layer_mat
> 
>  DNN-OBJS-$(CONFIG_LIBTENSORFLOW)             += dnn/dnn_backend_tf.o
>  DNN-OBJS-$(CONFIG_LIBOPENVINO)               += dnn/dnn_backend_openvino.o
> +DNN-OBJS-$(CONFIG_LIBTORCH)                  += dnn/dnn_backend_torch.o
> 
>  OBJS-$(CONFIG_DNN)                           += $(DNN-OBJS-yes)
> diff --git a/libavfilter/dnn/dnn_backend_torch.cpp 
> b/libavfilter/dnn/dnn_backend_torch.cpp
> new file mode 100644
> index 0000000000..86cc018fbc
> --- /dev/null
> +++ b/libavfilter/dnn/dnn_backend_torch.cpp
> @@ -0,0 +1,567 @@
> +/*
> + * Copyright (c) 2022
> + *
> + * This file is part of FFmpeg.
> + *
> + * FFmpeg is free software; you can redistribute it and/or
> + * modify it under the terms of the GNU Lesser General Public
> + * License as published by the Free Software Foundation; either
> + * version 2.1 of the License, or (at your option) any later version.
> + *
> + * FFmpeg is distributed in the hope that it will be useful,
> + * but WITHOUT ANY WARRANTY; without even the implied warranty of
> + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> + * Lesser General Public License for more details.
> + *
> + * You should have received a copy of the GNU Lesser General Public
> + * License along with FFmpeg; if not, write to the Free Software
> + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 
> 02110-1301 USA
> + */
> +
> +/**
> + * @file
> + * DNN Torch backend implementation.
> + */
> +
> +#include <torch/torch.h>
> +#include <torch/script.h>
> +#include "dnn_backend_torch.h"
> +
> +extern "C" {
> +#include "dnn_io_proc.h"
> +#include "../internal.h"
> +#include "dnn_backend_common.h"
> +#include "libavutil/opt.h"
> +#include "queue.h"
> +#include "safe_queue.h"
> +}
> +
> +typedef struct THOptions{
> +    char *device_name;
> +    c10::DeviceType device_type;
> +} THOptions;
> +
> +typedef struct THContext {
> +    const AVClass *c_class;
> +    THOptions options;
> +} THContext;
> +
> +typedef struct THModel {
> +    THContext ctx;
> +    DNNModel *model;
> +    torch::jit::Module jit_model;
> +    SafeQueue *request_queue;
> +    Queue *task_queue;
> +    Queue *lltask_queue;
> +} THModel;
> +
> +typedef struct THInferRequest {
> +    torch::Tensor *output;
> +    torch::Tensor *input_tensor;
> +} THInferRequest;
> +
> +typedef struct THRequestItem {
> +    THInferRequest *infer_request;
> +    LastLevelTaskItem *lltask;
> +    DNNAsyncExecModule exec_module;
> +} THRequestItem;
> +
> +
> +#define OFFSET(x) offsetof(THContext, x)
> +#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
> +static const AVOption dnn_th_options[] = {
> +    { "device", "device to run model", OFFSET(options.device_name), 
> AV_OPT_TYPE_STRING, { .str = "cpu" }, 0, 0, FLAGS },
> +    { NULL }
> +};
> +
> +AVFILTER_DEFINE_CLASS(dnn_th);
> +
> +static int execute_model_th(THRequestItem *request, Queue 
> *lltask_queue);
> +static int th_start_inference(void *args);
> +static void infer_completion_callback(void *args);
> +
> +static int extract_lltask_from_task(TaskItem *task, Queue 
> *lltask_queue)
> +{
> +    THModel *th_model = (THModel *)task->model;
> +    THContext *ctx = &th_model->ctx;
> +    LastLevelTaskItem *lltask = (LastLevelTaskItem 
> *)av_malloc(sizeof(*lltask));
> +    if (!lltask) {
> +        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for 
> LastLevelTaskItem\n");
> +        return AVERROR(ENOMEM);
> +    }
> +    task->inference_todo = 1;
> +    task->inference_done = 0;
> +    lltask->task = task;
> +    if (ff_queue_push_back(lltask_queue, lltask) < 0) {
> +        av_log(ctx, AV_LOG_ERROR, "Failed to push back 
> lltask_queue.\n");
> +        av_freep(&lltask);
> +        return AVERROR(ENOMEM);
> +    }
> +    return 0;
> +}
> +
> +static int get_input_th(void *model, DNNData *input, const char 
> *input_name)
> +{
> +    input->dt = DNN_FLOAT;
> +    input->order = DCO_RGB_PLANAR;
> +    input->height = -1;
> +    input->width = -1;
> +    input->channels = 3;
> +    return 0;
> +}
> +
> +static int get_output_th(void *model, const char *input_name, int 
> input_width, int input_height,
> +                                   const char *output_name, int 
> *output_width, int *output_height)
> +{
> +    int ret = 0;
> +    THModel *th_model = (THModel*) model;
> +    THContext *ctx = &th_model->ctx;
> +    TaskItem task;
> +    THRequestItem *request;
> +    DNNExecBaseParams exec_params = {
> +        .input_name     = input_name,
> +        .output_names   = &output_name,
> +        .nb_output      = 1,
> +        .in_frame       = NULL,
> +        .out_frame      = NULL,
> +    };
> +    ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, 
> th_model, input_height, input_width, ctx);
> +    if ( ret != 0) {
> +        goto err;
> +    }
> +
> +    ret = extract_lltask_from_task(&task, th_model->lltask_queue);
> +    if ( ret != 0) {
> +        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task 
> from task.\n");
> +        goto err;
> +    }
> +
> +    request = (THRequestItem*) 
> ff_safe_queue_pop_front(th_model->request_queue);
> +    if (!request) {
> +        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
> +        ret = AVERROR(EINVAL);
> +        goto err;
> +    }
> +
> +    ret = execute_model_th(request, th_model->lltask_queue);
> +    *output_width = task.out_frame->width;
> +    *output_height = task.out_frame->height;
> +
> +err:
> +    av_frame_free(&task.out_frame);
> +    av_frame_free(&task.in_frame);
> +    return ret;
> +}
> +
> +static void th_free_request(THInferRequest *request)
> +{
> +    if (!request)
> +        return;
> +    if (request->output) {
> +        delete(request->output);
> +        request->output = NULL;
> +    }
> +    if (request->input_tensor) {
> +        delete(request->input_tensor);
> +        request->input_tensor = NULL;
> +    }
> +    return;
> +}
> +
> +static inline void destroy_request_item(THRequestItem **arg)
> +{
> +    THRequestItem *item;
> +    if (!arg || !*arg) {
> +        return;
> +    }
> +    item = *arg;
> +    th_free_request(item->infer_request);
> +    av_freep(&item->infer_request);
> +    av_freep(&item->lltask);
> +    ff_dnn_async_module_cleanup(&item->exec_module);
> +    av_freep(arg);
> +}
> +
> +static THInferRequest *th_create_inference_request(void)
> +{
> +    THInferRequest *request = (THInferRequest 
> *)av_malloc(sizeof(THInferRequest));
> +    if (!request) {
> +        return NULL;
> +    }
> +    request->input_tensor = NULL;
> +    request->output = NULL;
> +    return request;
> +}
> +
> +DNNModel *ff_dnn_load_model_th(const char *model_filename, 
> DNNFunctionType func_type, const char *options, AVFilterContext 
> *filter_ctx)
> +{
> +    DNNModel *model = NULL;
> +    THModel *th_model = NULL;
> +    THRequestItem *item = NULL;
> +    THContext *ctx;
> +
> +    model = (DNNModel *)av_mallocz(sizeof(DNNModel));
> +    if (!model) {
> +        return NULL;
> +    }
> +
> +    th_model = (THModel *)av_mallocz(sizeof(THModel));
> +    if (!th_model) {
> +        av_freep(&model);
> +        return NULL;
> +    }
> +
> +    th_model->ctx.c_class = &dnn_th_class;
> +    ctx = &th_model->ctx;
> +    //parse options
> +    av_opt_set_defaults(ctx);
> +    if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) {
> +        av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", 
> options);
> +        return NULL;
> +    }
> +
> +    c10::Device device = c10::Device(ctx->options.device_name);
> +    if (device.is_cpu()) {
> +        ctx->options.device_type = torch::kCPU;
> +    } else {
> +        av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", 
> ctx->options.device_name);
> +        goto fail;
> +    }
> +
> +    try {
> +        th_model->jit_model = torch::jit::load(model_filename, device);
> +    } catch (const c10::Error& e) {
> +        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
> +        goto fail;
> +    }
> +
> +    th_model->request_queue = ff_safe_queue_create();
> +    if (!th_model->request_queue) {
> +        goto fail;
> +    }
> +
> +    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
> +    if (!item) {
> +        goto fail;
> +    }
> +    item->lltask = NULL;
> +    item->infer_request = th_create_inference_request();
> +    if (!item->infer_request) {
> +        av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for 
> Torch inference request\n");
> +        goto fail;
> +    }
> +    item->exec_module.start_inference = &th_start_inference;
> +    item->exec_module.callback = &infer_completion_callback;
> +    item->exec_module.args = item;
> +
> +    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
> +        goto fail;
> +    }
> +
> +    th_model->task_queue = ff_queue_create();
> +    if (!th_model->task_queue) {
> +        goto fail;
> +    }
> +
> +    th_model->lltask_queue = ff_queue_create();
> +    if (!th_model->lltask_queue) {
> +        goto fail;
> +    }
> +
> +    th_model->model = model;
> +    model->model = th_model;
> +    model->get_input = &get_input_th;
> +    model->get_output = &get_output_th;
> +    model->options = NULL;
> +    model->filter_ctx = filter_ctx;
> +    model->func_type = func_type;
> +    return model;
> +
> +fail:
> +    destroy_request_item(&item);
> +    ff_queue_destroy(th_model->task_queue);
> +    ff_queue_destroy(th_model->lltask_queue);
> +    ff_safe_queue_destroy(th_model->request_queue);
> +    av_freep(&th_model);
> +    av_freep(&model);
> +    av_freep(&item);
> +    return NULL;
> +}
> +
> +static int fill_model_input_th(THModel *th_model, THRequestItem 
> *request)
> +{
> +    LastLevelTaskItem *lltask = NULL;
> +    TaskItem *task = NULL;
> +    THInferRequest *infer_request = NULL;
> +    DNNData input;
> +    THContext *ctx = &th_model->ctx;
> +    int ret;
> +
> +    lltask = (LastLevelTaskItem 
> *)ff_queue_pop_front(th_model->lltask_queue);
> +    if (!lltask) {
> +        ret = AVERROR(EINVAL);
> +        goto err;
> +    }
> +    request->lltask = lltask;
> +    task = lltask->task;
> +    infer_request = request->infer_request;
> +
> +    ret = get_input_th(th_model, &input, NULL);
> +    if ( ret != 0) {
> +        goto err;
> +    }
> +
> +    input.height = task->in_frame->height;
> +    input.width = task->in_frame->width;
> +    input.data = malloc(input.height * input.width * 3 * 
> sizeof(float));
> +    if (!input.data)
> +        return AVERROR(ENOMEM);
> +    infer_request->input_tensor = new torch::Tensor();
> +    infer_request->output = new torch::Tensor();
> +
> +    switch (th_model->model->func_type) {
> +    case DFT_PROCESS_FRAME:
> +        if (task->do_ioproc) {
> +            if (th_model->model->frame_pre_proc != NULL) {
> +                th_model->model->frame_pre_proc(task->in_frame, 
> &input, th_model->model->filter_ctx);
> +            } else {
> +                ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
> +            }
> +        }
> +        break;
> +    default:
> +        avpriv_report_missing_feature(NULL, "model function type %d", 
> th_model->model->func_type);
> +        break;
> +    }
> +    *infer_request->input_tensor = torch::from_blob(input.data, {1, 1, 
> 3, input.height, input.width},
> +                                                    torch::kFloat32);
> +    return 0;
> +
> +err:
> +    th_free_request(infer_request);
> +    return ret;
> +}
> +
> +static int th_start_inference(void *args)
> +{
> +    THRequestItem *request = (THRequestItem *)args;
> +    THInferRequest *infer_request = NULL;
> +    LastLevelTaskItem *lltask = NULL;
> +    TaskItem *task = NULL;
> +    THModel *th_model = NULL;
> +    THContext *ctx = NULL;
> +    std::vector<torch::jit::IValue> inputs;
> +
> +    if (!request) {
> +        av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
> +        return AVERROR(EINVAL);
> +    }
> +    infer_request = request->infer_request;
> +    lltask = request->lltask;
> +    task = lltask->task;
> +    th_model = (THModel *)task->model;
> +    ctx = &th_model->ctx;
> +
> +    if (!infer_request->input_tensor || !infer_request->output) {
> +        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
> +        return DNN_GENERIC_ERROR;
> +    }
> +    inputs.push_back(*infer_request->input_tensor);
> +
> +    auto parameters = th_model->jit_model.parameters();
> +    auto para = *(parameters.begin());
> +
> +    *infer_request->output = 
> th_model->jit_model.forward(inputs).toTensor();
> +
> +    return 0;
> +}
> +
> +static void infer_completion_callback(void *args) {
> +    THRequestItem *request = (THRequestItem*)args;
> +    LastLevelTaskItem *lltask = request->lltask;
> +    TaskItem *task = lltask->task;
> +    DNNData outputs;
> +    THInferRequest *infer_request = request->infer_request;
> +    THModel *th_model = (THModel *)task->model;
> +    torch::Tensor *output = infer_request->output;
> +
> +    c10::IntArrayRef sizes = output->sizes();
> +    assert(sizes.size == 5);
> +    outputs.order = DCO_RGB_PLANAR;
> +    outputs.height = sizes.at(3);
> +    outputs.width = sizes.at(4);
> +    outputs.dt = DNN_FLOAT;
> +    outputs.channels = 3;
> +
> +    switch (th_model->model->func_type) {
> +    case DFT_PROCESS_FRAME:
> +        if (task->do_ioproc) {
> +            outputs.data = output->data_ptr();
> +            if (th_model->model->frame_post_proc != NULL) {
> +                th_model->model->frame_post_proc(task->out_frame, 
> &outputs, th_model->model->filter_ctx);
> +            } else {
> +                ff_proc_from_dnn_to_frame(task->out_frame, &outputs, 
> &th_model->ctx);
> +            }
> +        } else {
> +            task->out_frame->width = outputs.width;
> +            task->out_frame->height = outputs.height;
> +        }
> +        break;
> +    default:
> +        avpriv_report_missing_feature(&th_model->ctx, "model function 
> type %d", th_model->model->func_type);
> +        goto err;
> +    }
> +    task->inference_done++;
> +err:
> +    th_free_request(infer_request);
> +
> +    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) 
> {
> +        destroy_request_item(&request);
> +        av_log(&th_model->ctx, AV_LOG_ERROR, "Unable to push back 
> request_queue when failed to start inference.\n");
> +    }
> +}
> +
> +static int execute_model_th(THRequestItem *request, Queue 
> *lltask_queue)
> +{
> +    THModel *th_model = NULL;
> +    LastLevelTaskItem *lltask;
> +    TaskItem *task = NULL;
> +    int ret = 0;
> +
> +    if (ff_queue_size(lltask_queue) == 0) {
> +        destroy_request_item(&request);
> +        return 0;
> +    }
> +
> +    lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
> +    if (lltask == NULL) {
> +        av_log(NULL, AV_LOG_ERROR, "Failed to get 
> LastLevelTaskItem\n");
> +        ret = AVERROR(EINVAL);
> +        goto err;
> +    }
> +    task = lltask->task;
> +    th_model = (THModel *)task->model;
> +
> +    ret = fill_model_input_th(th_model, request);
> +    if ( ret != 0) {
> +        goto err;
> +    }
> +    if (task->async) {
> +        avpriv_report_missing_feature(&th_model->ctx, "LibTorch 
> async");
> +    } else {
> +        ret = th_start_inference((void *)(request));
> +        if (ret != 0) {
> +            goto err;
> +        }
> +        infer_completion_callback(request);
> +        return (task->inference_done == task->inference_todo) ? 0 : 
> DNN_GENERIC_ERROR;
> +    }
> +
> +err:
> +    th_free_request(request->infer_request);
> +    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) 
> {
> +        destroy_request_item(&request);
> +    }
> +    return ret;
> +}
> +
> +int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams 
> *exec_params)
> +{
> +    THModel *th_model = (THModel *)model->model;
> +    THContext *ctx = &th_model->ctx;
> +    TaskItem *task;
> +    THRequestItem *request;
> +    int ret = 0;
> +
> +    ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, 
> exec_params);
> +    if (ret != 0) {
> +        return ret;
> +    }
> +
> +    task = (TaskItem *)av_malloc(sizeof(TaskItem));
> +    if (!task) {
> +        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task 
> item.\n");
> +        return AVERROR(ENOMEM);
> +    }
> +
> +    ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
> +    if (ret != 0) {
> +        av_freep(&task);
> +        av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
> +        return ret;
> +    }
> +
> +    ret = ff_queue_push_back(th_model->task_queue, task);
> +    if (ret < 0) {
> +        av_freep(&task);
> +        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
> +        return ret;
> +    }
> +
> +    ret = extract_lltask_from_task(task, th_model->lltask_queue);
> +    if (ret != 0) {
> +        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task 
> from task.\n");
> +        return ret;
> +    }
> +
> +    request = (THRequestItem 
> *)ff_safe_queue_pop_front(th_model->request_queue);
> +    if (!request) {
> +        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
> +        return AVERROR(EINVAL);
> +    }
> +
> +    return execute_model_th(request, th_model->lltask_queue);
> +}
> +
> +
> +int ff_dnn_flush_th(const DNNModel *model)
> +{
> +    THModel *th_model = (THModel *)model->model;
> +    THRequestItem *request;
> +
> +    if (ff_queue_size(th_model->lltask_queue) == 0) {
> +        // no pending task need to flush
> +        return 0;
> +    }
> +    request = (THRequestItem 
> *)ff_safe_queue_pop_front(th_model->request_queue);
> +    if (!request) {
> +        av_log(&th_model->ctx, AV_LOG_ERROR, "unable to get infer 
> request.\n");
> +        return AVERROR(EINVAL);
> +    }
> +
> +    return execute_model_th(request, th_model->lltask_queue);
> +}
> +
> +DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame 
> **in, AVFrame **out)
> +{
> +    THModel *th_model = (THModel *)model->model;
> +    return ff_dnn_get_result_common(th_model->task_queue, in, out);
> +}
> +
> +void ff_dnn_free_model_th(DNNModel **model)
> +{
> +    THModel *th_model;
> +    if(*model) {
> +        th_model = (THModel *) (*model)->model;
> +        while (ff_safe_queue_size(th_model->request_queue) != 0) {
> +            THRequestItem *item = (THRequestItem 
> *)ff_safe_queue_pop_front(th_model->request_queue);
> +            destroy_request_item(&item);
> +        }
> +        ff_safe_queue_destroy(th_model->request_queue);
> +
> +        while (ff_queue_size(th_model->lltask_queue) != 0) {
> +            LastLevelTaskItem *item = (LastLevelTaskItem 
> *)ff_queue_pop_front(th_model->lltask_queue);
> +            av_freep(&item);
> +        }
> +        ff_queue_destroy(th_model->lltask_queue);
> +
> +        while (ff_queue_size(th_model->task_queue) != 0) {
> +            TaskItem *item = (TaskItem 
> *)ff_queue_pop_front(th_model->task_queue);
> +            av_frame_free(&item->in_frame);
> +            av_frame_free(&item->out_frame);
> +            av_freep(&item);
> +        }
> +    }
> +    av_freep(&th_model);
> +    av_freep(model);
> +}
> diff --git a/libavfilter/dnn/dnn_backend_torch.h 
> b/libavfilter/dnn/dnn_backend_torch.h
> new file mode 100644
> index 0000000000..5d6a08f85f
> --- /dev/null
> +++ b/libavfilter/dnn/dnn_backend_torch.h
> @@ -0,0 +1,47 @@
> +/*
> + * Copyright (c) 2022
> + *
> + * This file is part of FFmpeg.
> + *
> + * FFmpeg is free software; you can redistribute it and/or
> + * modify it under the terms of the GNU Lesser General Public
> + * License as published by the Free Software Foundation; either
> + * version 2.1 of the License, or (at your option) any later version.
> + *
> + * FFmpeg is distributed in the hope that it will be useful,
> + * but WITHOUT ANY WARRANTY; without even the implied warranty of
> + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> + * Lesser General Public License for more details.
> + *
> + * You should have received a copy of the GNU Lesser General Public
> + * License along with FFmpeg; if not, write to the Free Software
> + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 
> 02110-1301 USA
> + */
> +
> +/**
> + * @file
> + * DNN inference functions interface for Torch backend.
> + */
> +
> +#ifndef AVFILTER_DNN_DNN_BACKEND_TORCH_H
> +#define AVFILTER_DNN_DNN_BACKEND_TORCH_H
> +
> +
> +#ifdef __cplusplus
> +extern "C" {
> +#endif
> +#include "../dnn_interface.h"
> +
> +DNNModel *ff_dnn_load_model_th(const char *model_filename, 
> DNNFunctionType func_type, const char *options, AVFilterContext 
> *filter_ctx);
> +
> +int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams 
> *exec_params);
> +DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame 
> **in, AVFrame **out);
> +int ff_dnn_flush_th(const DNNModel *model);
> +
> +void ff_dnn_free_model_th(DNNModel **model);
> +
> +#ifdef __cplusplus
> +}
> +#endif
> +
> +#endif
> diff --git a/libavfilter/dnn/dnn_interface.c 
> b/libavfilter/dnn/dnn_interface.c
> index 554a36b0dc..6f4e02b481 100644
> --- a/libavfilter/dnn/dnn_interface.c
> +++ b/libavfilter/dnn/dnn_interface.c
> @@ -27,6 +27,7 @@
>  #include "dnn_backend_native.h"
>  #include "dnn_backend_tf.h"
>  #include "dnn_backend_openvino.h"
> +#include "dnn_backend_torch.h"
>  #include "libavutil/mem.h"
> 
>  DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
> @@ -70,6 +71,17 @@ DNNModule *ff_get_dnn_module(DNNBackendType 
> backend_type)
>          return NULL;
>      #endif
>          break;
> +    case DNN_TH:
> +    #if (CONFIG_LIBTORCH == 1)
> +        dnn_module->load_model = &ff_dnn_load_model_th;
> +        dnn_module->execute_model = &ff_dnn_execute_model_th;
> +        dnn_module->get_result = &ff_dnn_get_result_th;
> +        dnn_module->flush = &ff_dnn_flush_th;
> +        dnn_module->free_model = &ff_dnn_free_model_th;
> +    #else
> +        av_freep(&dnn_module);
> +    #endif
> +        break;
>      default:
>          av_log(NULL, AV_LOG_ERROR, "Module backend_type is not native 
> or tensorflow\n");
>          av_freep(&dnn_module);
> diff --git a/libavfilter/dnn/dnn_io_proc.c 
> b/libavfilter/dnn/dnn_io_proc.c
> index 532b089002..cbaa1e601f 100644
> --- a/libavfilter/dnn/dnn_io_proc.c
> +++ b/libavfilter/dnn/dnn_io_proc.c
> @@ -24,10 +24,20 @@
>  #include "libavutil/avassert.h"
>  #include "libavutil/detection_bbox.h"
> 
> +static enum AVPixelFormat get_pixel_format(DNNData *data);
> +
>  int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void 
> *log_ctx)
>  {
>      struct SwsContext *sws_ctx;
> +    int frame_size = frame->height * frame->width;
> +    int linesize[3];
> +    void **dst_data, *middle_data;
> +    enum AVPixelFormat fmt;
>      int bytewidth = av_image_get_linesize(frame->format, frame->width, 
> 0);
> +    linesize[0] = frame->linesize[0];
> +    dst_data = (void **)frame->data;
> +    fmt = get_pixel_format(output);
> +
>      if (bytewidth < 0) {
>          return AVERROR(EINVAL);
>      }
> @@ -35,6 +45,18 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, 
> DNNData *output, void *log_ctx)
>          avpriv_report_missing_feature(log_ctx, "data type rather than 
> DNN_FLOAT");
>          return AVERROR(ENOSYS);
>      }
> +    if (fmt == AV_PIX_FMT_GBRP) {
> +        middle_data = malloc(frame_size * 3 * sizeof(uint8_t));
> +        if (!middle_data) {
> +            av_log(log_ctx, AV_LOG_ERROR, "Failed to malloc memory for 
> middle_data for "
> +                    "the conversion fmt:%s s:%dx%d -> fmt:%s 
> s:%dx%d\n",
> +                    av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32),  
> frame->width, frame->height,
> +                    
> av_get_pix_fmt_name(AV_PIX_FMT_GRAY8),frame->width, frame->height);
> +            return AVERROR(EINVAL);
> +        }
> +        dst_data = &middle_data;
> +        linesize[0] = frame->width * 3;
> +    }
> 
>      switch (frame->format) {
>      case AV_PIX_FMT_RGB24:
> @@ -51,12 +73,43 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, 
> DNNData *output, void *log_ctx)
>                  "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
>                  av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width 
> * 3, frame->height,
>                  av_get_pix_fmt_name(AV_PIX_FMT_GRAY8),   frame->width 
> * 3, frame->height);
> +            av_freep(&middle_data);
>              return AVERROR(EINVAL);
>          }
>          sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t 
> *)output->data, 0, 0, 0},
>                             (const int[4]){frame->width * 3 * 
> sizeof(float), 0, 0, 0}, 0, frame->height,
> -                           (uint8_t * const*)frame->data, 
> frame->linesize);
> +                           (uint8_t * const*)dst_data, linesize);
>          sws_freeContext(sws_ctx);
> +        switch (fmt) {
> +        case AV_PIX_FMT_GBRP:
> +            sws_ctx = sws_getContext(frame->width,
> +                                     frame->height,
> +                                     AV_PIX_FMT_GBRP,
> +                                     frame->width,
> +                                     frame->height,
> +                                     frame->format,
> +                                     0, NULL, NULL, NULL);
> +            if (!sws_ctx) {
> +                av_log(log_ctx, AV_LOG_ERROR, "Impossible to create 
> scale context for the conversion "
> +                       "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
> +                       av_get_pix_fmt_name(AV_PIX_FMT_GBRP),  
> frame->width, frame->height,
> +                       
> av_get_pix_fmt_name(frame->format),frame->width, frame->height);
> +                av_freep(&middle_data);
> +                return AVERROR(EINVAL);
> +            }
> +            sws_scale(sws_ctx, (const uint8_t * const[4]){(uint8_t 
> *)dst_data[0] + frame_size * sizeof(uint8_t),
> +                                                          (uint8_t 
> *)dst_data[0] + frame_size * sizeof(uint8_t) * 2,
> +                                                          (uint8_t 
> *)dst_data[0], 0},
> +                      (const int [4]){frame->width * sizeof(uint8_t),
> +                                      frame->width * sizeof(uint8_t),
> +                                      frame->width * sizeof(uint8_t), 
> 0}
> +                      , 0, frame->height,
> +                      (uint8_t * const*)frame->data, frame->linesize);
> +            break;
> +        default:
> +            break;
> +        }
> +        av_freep(&middle_data);
>          return 0;
>      case AV_PIX_FMT_GRAYF32:
>          av_image_copy_plane(frame->data[0], frame->linesize[0],
> @@ -101,6 +154,14 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, 
> DNNData *input, void *log_ctx)
>  {
>      struct SwsContext *sws_ctx;
>      int bytewidth = av_image_get_linesize(frame->format, frame->width, 
> 0);
> +    int frame_size = frame->height * frame->width;
> +    int linesize[3];
> +    void **src_data, *middle_data = NULL;
> +    enum AVPixelFormat fmt;
> +    linesize[0] = frame->linesize[0];
> +    src_data = (void **)frame->data;
> +    fmt = get_pixel_format(input);
> +
>      if (bytewidth < 0) {
>          return AVERROR(EINVAL);
>      }
> @@ -112,6 +173,46 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, 
> DNNData *input, void *log_ctx)
>      switch (frame->format) {
>      case AV_PIX_FMT_RGB24:
>      case AV_PIX_FMT_BGR24:
> +        switch (fmt) {
> +        case AV_PIX_FMT_GBRP:
> +            middle_data = av_malloc(frame_size * 3 * sizeof(uint8_t));
> +            if (!middle_data) {
> +                av_log(log_ctx, AV_LOG_ERROR, "Failed to malloc memory 
> for middle_data for "
> +                       "the conversion fmt:%s s:%dx%d -> fmt:%s 
> s:%dx%d\n",
> +                       av_get_pix_fmt_name(frame->format),  
> frame->width, frame->height,
> +                       
> av_get_pix_fmt_name(AV_PIX_FMT_GBRP),frame->width, frame->height);
> +                return AVERROR(EINVAL);
> +            }
> +            sws_ctx = sws_getContext(frame->width,
> +                                     frame->height,
> +                                     frame->format,
> +                                     frame->width,
> +                                     frame->height,
> +                                     AV_PIX_FMT_GBRP,
> +                                     0, NULL, NULL, NULL);
> +            if (!sws_ctx) {
> +                av_log(log_ctx, AV_LOG_ERROR, "Impossible to create 
> scale context for the conversion "
> +                       "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
> +                       av_get_pix_fmt_name(frame->format),  
> frame->width, frame->height,
> +                       
> av_get_pix_fmt_name(AV_PIX_FMT_GBRP),frame->width, frame->height);
> +                av_freep(&middle_data);
> +                return AVERROR(EINVAL);
> +            }
> +            sws_scale(sws_ctx, (const uint8_t **)frame->data,
> +                      frame->linesize, 0, frame->height,
> +                      (uint8_t * const [4]){(uint8_t *)middle_data + 
> frame_size * sizeof(uint8_t),
> +                                            (uint8_t *)middle_data + 
> frame_size * sizeof(uint8_t) * 2,
> +                                            (uint8_t *)middle_data, 0},
> +                      (const int [4]){frame->width * sizeof(uint8_t),
> +                                      frame->width * sizeof(uint8_t),
> +                                      frame->width * sizeof(uint8_t), 
> 0});
> +            sws_freeContext(sws_ctx);
> +            src_data = &middle_data;
> +            linesize[0] = frame->width * 3;
> +            break;
> +        default:
> +            break;
> +        }
>          sws_ctx = sws_getContext(frame->width * 3,
>                                   frame->height,
>                                   AV_PIX_FMT_GRAY8,
> @@ -124,13 +225,15 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, 
> DNNData *input, void *log_ctx)
>                  "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
>                  av_get_pix_fmt_name(AV_PIX_FMT_GRAY8),  frame->width * 
> 3, frame->height,
>                  av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32),frame->width * 
> 3, frame->height);
> +            av_freep(&middle_data);
>              return AVERROR(EINVAL);
>          }
> -        sws_scale(sws_ctx, (const uint8_t **)frame->data,
> -                           frame->linesize, 0, frame->height,
> +        sws_scale(sws_ctx, (const uint8_t **)src_data,
> +                           linesize, 0, frame->height,
>                             (uint8_t * const [4]){input->data, 0, 0, 0},
>                             (const int [4]){frame->width * 3 * 
> sizeof(float), 0, 0, 0});
>          sws_freeContext(sws_ctx);
> +        av_freep(&middle_data);
>          break;
>      case AV_PIX_FMT_GRAYF32:
>          av_image_copy_plane(input->data, bytewidth,
> @@ -184,6 +287,14 @@ static enum AVPixelFormat get_pixel_format(DNNData 
> *data)
>              av_assert0(!"unsupported data pixel format.\n");
>              return AV_PIX_FMT_BGR24;
>          }
> +    } else if (data->dt == DNN_FLOAT) {
> +        switch (data->order) {
> +        case DCO_RGB_PLANAR:
> +            return AV_PIX_FMT_GBRP;
> +        default:
> +            av_assert0(!"unsupported data pixel format.\n");
> +            return AV_PIX_FMT_GBRP;
> +        }
>      }
> 
>      av_assert0(!"unsupported data type.\n");
> diff --git a/libavfilter/dnn_filter_common.c 
> b/libavfilter/dnn_filter_common.c
> index 5083e3de19..a4e1147fb9 100644
> --- a/libavfilter/dnn_filter_common.c
> +++ b/libavfilter/dnn_filter_common.c
> @@ -53,19 +53,31 @@ static char **separate_output_names(const char 
> *expr, const char *val_sep, int *
> 
>  int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, 
> AVFilterContext *filter_ctx)
>  {
> +    DNNBackendType backend = ctx->backend_type;
> +
>      if (!ctx->model_filename) {
>          av_log(filter_ctx, AV_LOG_ERROR, "model file for network is 
> not specified\n");
>          return AVERROR(EINVAL);
>      }
> -    if (!ctx->model_inputname) {
> -        av_log(filter_ctx, AV_LOG_ERROR, "input name of the model 
> network is not specified\n");
> -        return AVERROR(EINVAL);
> -    }
> 
> -    ctx->model_outputnames = 
> separate_output_names(ctx->model_outputnames_string, "&", 
> &ctx->nb_outputs);
> -    if (!ctx->model_outputnames) {
> -        av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output 
> names\n");
> -        return AVERROR(EINVAL);
> +    if (backend == DNN_TH) {
> +        if (ctx->model_inputname)
> +            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do 
> not require inputname, "\
> +                                               "inputname will be 
> ignored.\n");
> +        if (ctx->model_outputnames)
> +            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do 
> not require outputname(s), "\
> +                                               "all outputname(s) will 
> be ignored.\n");
> +        ctx->nb_outputs = 1;
> +    } else {
> +        if (!ctx->model_inputname) {
> +            av_log(filter_ctx, AV_LOG_ERROR, "input name of the model 
> network is not specified\n");
> +            return AVERROR(EINVAL);
> +        }
> +        ctx->model_outputnames = 
> separate_output_names(ctx->model_outputnames_string, "&", 
> &ctx->nb_outputs);
> +        if (!ctx->model_outputnames) {
> +            av_log(filter_ctx, AV_LOG_ERROR, "could not parse model 
> output names\n");
> +            return AVERROR(EINVAL);
> +        }
>      }
> 
>      ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
> @@ -113,8 +125,9 @@ int ff_dnn_get_input(DnnContext *ctx, DNNData *input)
> 
>  int ff_dnn_get_output(DnnContext *ctx, int input_width, int 
> input_height, int *output_width, int *output_height)
>  {
> +    const char *model_outputnames = ctx->backend_type == DNN_TH ? NULL 
> : ctx->model_outputnames[0];
>      return ctx->model->get_output(ctx->model->model, 
> ctx->model_inputname, input_width, input_height,
> -                                    (const char 
> *)ctx->model_outputnames[0], output_width, output_height);
> +                                  model_outputnames, output_width, 
> output_height);
>  }
> 
>  int ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame 
> *out_frame)
> diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
> index d94baa90c4..32698f788b 100644
> --- a/libavfilter/dnn_interface.h
> +++ b/libavfilter/dnn_interface.h
> @@ -32,7 +32,7 @@
> 
>  #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
> 
> -typedef enum {DNN_NATIVE, DNN_TF, DNN_OV} DNNBackendType;
> +typedef enum {DNN_NATIVE, DNN_TF, DNN_OV, DNN_TH} DNNBackendType;
> 
>  typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
> 
> @@ -40,6 +40,7 @@ typedef enum {
>      DCO_NONE,
>      DCO_BGR_PACKED,
>      DCO_RGB_PACKED,
> +    DCO_RGB_PLANAR,
>  } DNNColorOrder;
> 
>  typedef enum {
> diff --git a/libavfilter/vf_dnn_processing.c 
> b/libavfilter/vf_dnn_processing.c
> index cac096a19f..ac1dc6e1d9 100644
> --- a/libavfilter/vf_dnn_processing.c
> +++ b/libavfilter/vf_dnn_processing.c
> @@ -52,6 +52,9 @@ static const AVOption dnn_processing_options[] = {
>  #endif
>  #if (CONFIG_LIBOPENVINO == 1)
>      { "openvino",    "openvino backend flag",      0,                  
>       AV_OPT_TYPE_CONST,     { .i64 = 2 },    0, 0, FLAGS, "backend" },
> +#endif
> +#if (CONFIG_LIBTORCH == 1)
> +    { "torch",       "torch backend flag",         0,                  
>       AV_OPT_TYPE_CONST,     { .i64 = 3 },    0, 0, FLAGS, "backend" },
>  #endif
>      DNN_COMMON_OPTIONS
>      { NULL }
> -- 
> 2.17.1
>
> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel at ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request at ffmpeg.org with subject "unsubscribe".

-- 
Jean-Baptiste Kempf -  President
+33 672 704 734


More information about the ffmpeg-devel mailing list