[FFmpeg-devel] [PATCH v3] libavfi/dnn: add LibTorch as one of DNN backend
Jean-Baptiste Kempf
jb at videolan.org
Wed Feb 21 03:44:11 EET 2024
Hello,
On Tue, 20 Feb 2024, at 05:48, wenbin.chen-at-intel.com at ffmpeg.org wrote:
> From: Wenbin Chen <wenbin.chen at intel.com>
>
> PyTorch is an open source machine learning framework that accelerates
OK for me
> the path from research prototyping to production deployment. Official
> websit: https://pytorch.org/. We call the C++ library of PyTorch as
websitE
> LibTorch, the same below.
>
> To build FFmpeg with LibTorch, please take the following steps as a reference:
> 1. download LibTorch C++ library in
> https://pytorch.org/get-started/locally/,
> please select C++/Java for language, and other options as your need.
> 2. unzip the file to your own dir, with command
> unzip libtorch-shared-with-deps-latest.zip -d your_dir
> 3. export libtorch_root/libtorch/include and
> libtorch_root/libtorch/include/torch/csrc/api/include to $PATH
> export libtorch_root/libtorch/lib/ to $LD_LIBRARY_PATH
> 4. configure FFmpeg with ../configure --enable-libtorch
> --extra-cflags=-I/libtorch_root/libtorch/include
> --extra-cflags=-I/libtorch_root/libtorch/include/torch/csrc/api/include
> --extra-ldflags=-L/libtorch_root/libtorch/lib/
> 5. make
>
> To run FFmpeg DNN inference with LibTorch backend:
> ./ffmpeg -i input.jpg -vf
> dnn_processing=dnn_backend=torch:model=LibTorch_model.pt -y output.jpg
> The LibTorch_model.pt can be generated by Python with
> torch.jit.script() api. Please note, torch.jit.trace() is not
> recommended, since it does not support variable input sizes.
>
> Signed-off-by: Ting Fu <ting.fu at intel.com>
> Signed-off-by: Wenbin Chen <wenbin.chen at intel.com>
> ---
> configure | 5 +-
> libavfilter/dnn/Makefile | 1 +
> libavfilter/dnn/dnn_backend_torch.cpp | 597 ++++++++++++++++++++++++++
> libavfilter/dnn/dnn_interface.c | 5 +
> libavfilter/dnn_filter_common.c | 15 +-
> libavfilter/dnn_interface.h | 2 +-
> libavfilter/vf_dnn_processing.c | 3 +
> 7 files changed, 624 insertions(+), 4 deletions(-)
> create mode 100644 libavfilter/dnn/dnn_backend_torch.cpp
>
> diff --git a/configure b/configure
> index 2c635043dd..450ef54a80 100755
> --- a/configure
> +++ b/configure
> @@ -279,6 +279,7 @@ External library support:
> --enable-libtheora enable Theora encoding via libtheora [no]
> --enable-libtls enable LibreSSL (via libtls), needed for
> https support
> if openssl, gnutls or mbedtls is not used
> [no]
> + --enable-libtorch enable Torch as one DNN backend [no]
> --enable-libtwolame enable MP2 encoding via libtwolame [no]
> --enable-libuavs3d enable AVS3 decoding via libuavs3d [no]
> --enable-libv4l2 enable libv4l2/v4l-utils [no]
> @@ -1901,6 +1902,7 @@ EXTERNAL_LIBRARY_LIST="
> libtensorflow
> libtesseract
> libtheora
> + libtorch
> libtwolame
> libuavs3d
> libv4l2
> @@ -2781,7 +2783,7 @@ cbs_vp9_select="cbs"
> deflate_wrapper_deps="zlib"
> dirac_parse_select="golomb"
> dovi_rpu_select="golomb"
> -dnn_suggest="libtensorflow libopenvino"
> +dnn_suggest="libtensorflow libopenvino libtorch"
> dnn_deps="avformat swscale"
> error_resilience_select="me_cmp"
> evcparse_select="golomb"
> @@ -6886,6 +6888,7 @@ enabled libtensorflow && require
> libtensorflow tensorflow/c/c_api.h TF_Versi
> enabled libtesseract && require_pkg_config libtesseract tesseract
> tesseract/capi.h TessBaseAPICreate
> enabled libtheora && require libtheora theora/theoraenc.h
> th_info_init -ltheoraenc -ltheoradec -logg
> enabled libtls && require_pkg_config libtls libtls tls.h
> tls_configure
> +enabled libtorch && check_cxxflags -std=c++14 && require_cpp
> libtorch torch/torch.h "torch::Tensor" -ltorch -lc10 -ltorch_cpu
> -lstdc++ -lpthread
> enabled libtwolame && require libtwolame twolame.h twolame_init
> -ltwolame &&
> { check_lib libtwolame twolame.h
> twolame_encode_buffer_float32_interleaved -ltwolame ||
> die "ERROR: libtwolame must be
> installed and version must be >= 0.3.10"; }
> diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
> index 5d5697ea42..3d09927c98 100644
> --- a/libavfilter/dnn/Makefile
> +++ b/libavfilter/dnn/Makefile
> @@ -6,5 +6,6 @@ OBJS-$(CONFIG_DNN) +=
> dnn/dnn_backend_common.o
>
> DNN-OBJS-$(CONFIG_LIBTENSORFLOW) += dnn/dnn_backend_tf.o
> DNN-OBJS-$(CONFIG_LIBOPENVINO) += dnn/dnn_backend_openvino.o
> +DNN-OBJS-$(CONFIG_LIBTORCH) += dnn/dnn_backend_torch.o
>
> OBJS-$(CONFIG_DNN) += $(DNN-OBJS-yes)
> diff --git a/libavfilter/dnn/dnn_backend_torch.cpp
> b/libavfilter/dnn/dnn_backend_torch.cpp
> new file mode 100644
> index 0000000000..54d3b309a1
> --- /dev/null
> +++ b/libavfilter/dnn/dnn_backend_torch.cpp
> @@ -0,0 +1,597 @@
> +/*
> + * Copyright (c) 2024
> + *
> + * This file is part of FFmpeg.
> + *
> + * FFmpeg is free software; you can redistribute it and/or
> + * modify it under the terms of the GNU Lesser General Public
> + * License as published by the Free Software Foundation; either
> + * version 2.1 of the License, or (at your option) any later version.
> + *
> + * FFmpeg is distributed in the hope that it will be useful,
> + * but WITHOUT ANY WARRANTY; without even the implied warranty of
> + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
> + * Lesser General Public License for more details.
> + *
> + * You should have received a copy of the GNU Lesser General Public
> + * License along with FFmpeg; if not, write to the Free Software
> + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
> 02110-1301 USA
> + */
> +
> +/**
> + * @file
> + * DNN Torch backend implementation.
> + */
> +
> +#include <torch/torch.h>
> +#include <torch/script.h>
> +
> +extern "C" {
> +#include "../internal.h"
> +#include "dnn_io_proc.h"
> +#include "dnn_backend_common.h"
> +#include "libavutil/opt.h"
> +#include "queue.h"
> +#include "safe_queue.h"
> +}
> +
> +typedef struct THOptions{
> + char *device_name;
> + int optimize;
> +} THOptions;
> +
> +typedef struct THContext {
> + const AVClass *c_class;
> + THOptions options;
> +} THContext;
> +
> +typedef struct THModel {
> + THContext ctx;
> + DNNModel *model;
> + torch::jit::Module *jit_model;
> + SafeQueue *request_queue;
> + Queue *task_queue;
> + Queue *lltask_queue;
> +} THModel;
> +
> +typedef struct THInferRequest {
> + torch::Tensor *output;
> + torch::Tensor *input_tensor;
> +} THInferRequest;
> +
> +typedef struct THRequestItem {
> + THInferRequest *infer_request;
> + LastLevelTaskItem *lltask;
> + DNNAsyncExecModule exec_module;
> +} THRequestItem;
> +
> +
#define OFFSET(x) offsetof(THContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
/*
 * Torch backend options, applied by dnn_load_model_th() via
 * av_opt_set_defaults() + av_opt_set_from_string() on the THContext.
 *   device   - string handed to c10::Device (default "cpu")
 *   optimize - 0/1, forwarded to torch::jit::setGraphExecutorOptimize()
 * NOTE(review): "{ .str = ... }" / "{ .i64 = ... }" are designated
 * initializers, which are only standard C++ from C++20; under -std=c++14
 * this relies on a compiler extension.
 */
static const AVOption dnn_th_options[] = {
    { "device", "device to run model", OFFSET(options.device_name), AV_OPT_TYPE_STRING, { .str = "cpu" }, 0, 0, FLAGS },
    { "optimize", "turn on graph executor optimization", OFFSET(options.optimize), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS},
    { NULL }
};

AVFILTER_DEFINE_CLASS(dnn_th);
> +
> +static int extract_lltask_from_task(TaskItem *task, Queue
> *lltask_queue)
> +{
> + THModel *th_model = (THModel *)task->model;
> + THContext *ctx = &th_model->ctx;
> + LastLevelTaskItem *lltask = (LastLevelTaskItem
> *)av_malloc(sizeof(*lltask));
> + if (!lltask) {
> + av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for
> LastLevelTaskItem\n");
> + return AVERROR(ENOMEM);
> + }
> + task->inference_todo = 1;
> + task->inference_done = 0;
> + lltask->task = task;
> + if (ff_queue_push_back(lltask_queue, lltask) < 0) {
> + av_log(ctx, AV_LOG_ERROR, "Failed to push back
> lltask_queue.\n");
> + av_freep(&lltask);
> + return AVERROR(ENOMEM);
> + }
> + return 0;
> +}
> +
> +static void th_free_request(THInferRequest *request)
> +{
> + if (!request)
> + return;
> + if (request->output) {
> + delete(request->output);
> + request->output = NULL;
> + }
> + if (request->input_tensor) {
> + delete(request->input_tensor);
> + request->input_tensor = NULL;
> + }
> + return;
> +}
> +
> +static inline void destroy_request_item(THRequestItem **arg)
> +{
> + THRequestItem *item;
> + if (!arg || !*arg) {
> + return;
> + }
> + item = *arg;
> + th_free_request(item->infer_request);
> + av_freep(&item->infer_request);
> + av_freep(&item->lltask);
> + ff_dnn_async_module_cleanup(&item->exec_module);
> + av_freep(arg);
> +}
> +
> +static void dnn_free_model_th(DNNModel **model)
> +{
> + THModel *th_model;
> + if (!model || !*model)
> + return;
> +
> + th_model = (THModel *) (*model)->model;
> + while (ff_safe_queue_size(th_model->request_queue) != 0) {
> + THRequestItem *item = (THRequestItem
> *)ff_safe_queue_pop_front(th_model->request_queue);
> + destroy_request_item(&item);
> + }
> + ff_safe_queue_destroy(th_model->request_queue);
> +
> + while (ff_queue_size(th_model->lltask_queue) != 0) {
> + LastLevelTaskItem *item = (LastLevelTaskItem
> *)ff_queue_pop_front(th_model->lltask_queue);
> + av_freep(&item);
> + }
> + ff_queue_destroy(th_model->lltask_queue);
> +
> + while (ff_queue_size(th_model->task_queue) != 0) {
> + TaskItem *item = (TaskItem
> *)ff_queue_pop_front(th_model->task_queue);
> + av_frame_free(&item->in_frame);
> + av_frame_free(&item->out_frame);
> + av_freep(&item);
> + }
> + ff_queue_destroy(th_model->task_queue);
> + delete th_model->jit_model;
> + av_opt_free(&th_model->ctx);
> + av_freep(&th_model);
> + av_freep(model);
> +}
> +
> +static int get_input_th(void *model, DNNData *input, const char
> *input_name)
> +{
> + input->dt = DNN_FLOAT;
> + input->order = DCO_RGB;
> + input->layout = DL_NCHW;
> + input->dims[0] = 1;
> + input->dims[1] = 3;
> + input->dims[2] = -1;
> + input->dims[3] = -1;
> + return 0;
> +}
> +
> +static void deleter(void *arg)
> +{
> + av_freep(&arg);
> +}
> +
> +static int fill_model_input_th(THModel *th_model, THRequestItem
> *request)
> +{
> + LastLevelTaskItem *lltask = NULL;
> + TaskItem *task = NULL;
> + THInferRequest *infer_request = NULL;
> + DNNData input = { 0 };
> + THContext *ctx = &th_model->ctx;
> + int ret, width_idx, height_idx, channel_idx;
> +
> + lltask = (LastLevelTaskItem
> *)ff_queue_pop_front(th_model->lltask_queue);
> + if (!lltask) {
> + ret = AVERROR(EINVAL);
> + goto err;
> + }
> + request->lltask = lltask;
> + task = lltask->task;
> + infer_request = request->infer_request;
> +
> + ret = get_input_th(th_model, &input, NULL);
> + if ( ret != 0) {
> + goto err;
> + }
> + width_idx = dnn_get_width_idx_by_layout(input.layout);
> + height_idx = dnn_get_height_idx_by_layout(input.layout);
> + channel_idx = dnn_get_channel_idx_by_layout(input.layout);
> + input.dims[height_idx] = task->in_frame->height;
> + input.dims[width_idx] = task->in_frame->width;
> + input.data = av_malloc(input.dims[height_idx] *
> input.dims[width_idx] *
> + input.dims[channel_idx] * sizeof(float));
> + if (!input.data)
> + return AVERROR(ENOMEM);
> + infer_request->input_tensor = new torch::Tensor();
> + infer_request->output = new torch::Tensor();
> +
> + switch (th_model->model->func_type) {
> + case DFT_PROCESS_FRAME:
> + input.scale = 255;
> + if (task->do_ioproc) {
> + if (th_model->model->frame_pre_proc != NULL) {
> + th_model->model->frame_pre_proc(task->in_frame,
> &input, th_model->model->filter_ctx);
> + } else {
> + ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
> + }
> + }
> + break;
> + default:
> + avpriv_report_missing_feature(NULL, "model function type %d",
> th_model->model->func_type);
> + break;
> + }
> + *infer_request->input_tensor = torch::from_blob(input.data,
> + {1, 1, input.dims[channel_idx], input.dims[height_idx],
> input.dims[width_idx]},
> + deleter, torch::kFloat32);
> + return 0;
> +
> +err:
> + th_free_request(infer_request);
> + return ret;
> +}
> +
> +static int th_start_inference(void *args)
> +{
> + THRequestItem *request = (THRequestItem *)args;
> + THInferRequest *infer_request = NULL;
> + LastLevelTaskItem *lltask = NULL;
> + TaskItem *task = NULL;
> + THModel *th_model = NULL;
> + THContext *ctx = NULL;
> + std::vector<torch::jit::IValue> inputs;
> + torch::NoGradGuard no_grad;
> +
> + if (!request) {
> + av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
> + return AVERROR(EINVAL);
> + }
> + infer_request = request->infer_request;
> + lltask = request->lltask;
> + task = lltask->task;
> + th_model = (THModel *)task->model;
> + ctx = &th_model->ctx;
> +
> + if (ctx->options.optimize)
> + torch::jit::setGraphExecutorOptimize(true);
> + else
> + torch::jit::setGraphExecutorOptimize(false);
> +
> + if (!infer_request->input_tensor || !infer_request->output) {
> + av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
> + return DNN_GENERIC_ERROR;
> + }
> + inputs.push_back(*infer_request->input_tensor);
> +
> + *infer_request->output =
> th_model->jit_model->forward(inputs).toTensor();
> +
> + return 0;
> +}
> +
> +static void infer_completion_callback(void *args) {
> + THRequestItem *request = (THRequestItem*)args;
> + LastLevelTaskItem *lltask = request->lltask;
> + TaskItem *task = lltask->task;
> + DNNData outputs = { 0 };
> + THInferRequest *infer_request = request->infer_request;
> + THModel *th_model = (THModel *)task->model;
> + torch::Tensor *output = infer_request->output;
> +
> + c10::IntArrayRef sizes = output->sizes();
> + outputs.order = DCO_RGB;
> + outputs.layout = DL_NCHW;
> + outputs.dt = DNN_FLOAT;
> + if (sizes.size() == 5) {
> + // 5 dimensions: [batch_size, frame_nubmer, channel, height,
> width]
> + // this format of data is normally used for video frame SR
> + outputs.dims[0] = sizes.at(0); // N
> + outputs.dims[1] = sizes.at(2); // C
> + outputs.dims[2] = sizes.at(3); // H
> + outputs.dims[3] = sizes.at(4); // W
> + } else {
> + avpriv_report_missing_feature(&th_model->ctx, "Support of this
> kind of model");
> + goto err;
> + }
> +
> + switch (th_model->model->func_type) {
> + case DFT_PROCESS_FRAME:
> + if (task->do_ioproc) {
> + outputs.scale = 255;
> + outputs.data = output->data_ptr();
> + if (th_model->model->frame_post_proc != NULL) {
> + th_model->model->frame_post_proc(task->out_frame,
> &outputs, th_model->model->filter_ctx);
> + } else {
> + ff_proc_from_dnn_to_frame(task->out_frame, &outputs,
> &th_model->ctx);
> + }
> + } else {
> + task->out_frame->width =
> outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
> + task->out_frame->height =
> outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)];
> + }
> + break;
> + default:
> + avpriv_report_missing_feature(&th_model->ctx, "model function
> type %d", th_model->model->func_type);
> + goto err;
> + }
> + task->inference_done++;
> + av_freep(&request->lltask);
> +err:
> + th_free_request(infer_request);
> +
> + if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
> {
> + destroy_request_item(&request);
> + av_log(&th_model->ctx, AV_LOG_ERROR, "Unable to push back
> request_queue when failed to start inference.\n");
> + }
> +}
> +
> +static int execute_model_th(THRequestItem *request, Queue
> *lltask_queue)
> +{
> + THModel *th_model = NULL;
> + LastLevelTaskItem *lltask;
> + TaskItem *task = NULL;
> + int ret = 0;
> +
> + if (ff_queue_size(lltask_queue) == 0) {
> + destroy_request_item(&request);
> + return 0;
> + }
> +
> + lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
> + if (lltask == NULL) {
> + av_log(NULL, AV_LOG_ERROR, "Failed to get
> LastLevelTaskItem\n");
> + ret = AVERROR(EINVAL);
> + goto err;
> + }
> + task = lltask->task;
> + th_model = (THModel *)task->model;
> +
> + ret = fill_model_input_th(th_model, request);
> + if ( ret != 0) {
> + goto err;
> + }
> + if (task->async) {
> + avpriv_report_missing_feature(&th_model->ctx, "LibTorch
> async");
> + } else {
> + ret = th_start_inference((void *)(request));
> + if (ret != 0) {
> + goto err;
> + }
> + infer_completion_callback(request);
> + return (task->inference_done == task->inference_todo) ? 0 :
> DNN_GENERIC_ERROR;
> + }
> +
> +err:
> + th_free_request(request->infer_request);
> + if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
> {
> + destroy_request_item(&request);
> + }
> + return ret;
> +}
> +
> +static int get_output_th(void *model, const char *input_name, int
> input_width, int input_height,
> + const char *output_name, int
> *output_width, int *output_height)
> +{
> + int ret = 0;
> + THModel *th_model = (THModel*) model;
> + THContext *ctx = &th_model->ctx;
> + TaskItem task = { 0 };
> + THRequestItem *request = NULL;
> + DNNExecBaseParams exec_params = {
> + .input_name = input_name,
> + .output_names = &output_name,
> + .nb_output = 1,
> + .in_frame = NULL,
> + .out_frame = NULL,
> + };
> + ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params,
> th_model, input_height, input_width, ctx);
> + if ( ret != 0) {
> + goto err;
> + }
> +
> + ret = extract_lltask_from_task(&task, th_model->lltask_queue);
> + if ( ret != 0) {
> + av_log(ctx, AV_LOG_ERROR, "unable to extract last level task
> from task.\n");
> + goto err;
> + }
> +
> + request = (THRequestItem*)
> ff_safe_queue_pop_front(th_model->request_queue);
> + if (!request) {
> + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
> + ret = AVERROR(EINVAL);
> + goto err;
> + }
> +
> + ret = execute_model_th(request, th_model->lltask_queue);
> + *output_width = task.out_frame->width;
> + *output_height = task.out_frame->height;
> +
> +err:
> + av_frame_free(&task.out_frame);
> + av_frame_free(&task.in_frame);
> + return ret;
> +}
> +
> +static THInferRequest *th_create_inference_request(void)
> +{
> + THInferRequest *request = (THInferRequest
> *)av_malloc(sizeof(THInferRequest));
> + if (!request) {
> + return NULL;
> + }
> + request->input_tensor = NULL;
> + request->output = NULL;
> + return request;
> +}
> +
> +static DNNModel *dnn_load_model_th(const char *model_filename,
> DNNFunctionType func_type, const char *options, AVFilterContext
> *filter_ctx)
> +{
> + DNNModel *model = NULL;
> + THModel *th_model = NULL;
> + THRequestItem *item = NULL;
> + THContext *ctx;
> +
> + model = (DNNModel *)av_mallocz(sizeof(DNNModel));
> + if (!model) {
> + return NULL;
> + }
> +
> + th_model = (THModel *)av_mallocz(sizeof(THModel));
> + if (!th_model) {
> + av_freep(&model);
> + return NULL;
> + }
> + th_model->model = model;
> + model->model = th_model;
> + th_model->ctx.c_class = &dnn_th_class;
> + ctx = &th_model->ctx;
> + //parse options
> + av_opt_set_defaults(ctx);
> + if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) {
> + av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n",
> options);
> + return NULL;
> + }
> +
> + c10::Device device = c10::Device(ctx->options.device_name);
> + if (!device.is_cpu()) {
> + av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n",
> ctx->options.device_name);
> + goto fail;
> + }
> +
> + try {
> + th_model->jit_model = new torch::jit::Module;
> + (*th_model->jit_model) = torch::jit::load(model_filename);
> + } catch (const c10::Error& e) {
> + av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
> + goto fail;
> + }
> +
> + th_model->request_queue = ff_safe_queue_create();
> + if (!th_model->request_queue) {
> + goto fail;
> + }
> +
> + item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
> + if (!item) {
> + goto fail;
> + }
> + item->lltask = NULL;
> + item->infer_request = th_create_inference_request();
> + if (!item->infer_request) {
> + av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for
> Torch inference request\n");
> + goto fail;
> + }
> + item->exec_module.start_inference = &th_start_inference;
> + item->exec_module.callback = &infer_completion_callback;
> + item->exec_module.args = item;
> +
> + if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
> + goto fail;
> + }
> + item = NULL;
> +
> + th_model->task_queue = ff_queue_create();
> + if (!th_model->task_queue) {
> + goto fail;
> + }
> +
> + th_model->lltask_queue = ff_queue_create();
> + if (!th_model->lltask_queue) {
> + goto fail;
> + }
> +
> + model->get_input = &get_input_th;
> + model->get_output = &get_output_th;
> + model->options = NULL;
> + model->filter_ctx = filter_ctx;
> + model->func_type = func_type;
> + return model;
> +
> +fail:
> + if (item) {
> + destroy_request_item(&item);
> + av_freep(&item);
> + }
> + dnn_free_model_th(&model);
> + return NULL;
> +}
> +
> +static int dnn_execute_model_th(const DNNModel *model,
> DNNExecBaseParams *exec_params)
> +{
> + THModel *th_model = (THModel *)model->model;
> + THContext *ctx = &th_model->ctx;
> + TaskItem *task;
> + THRequestItem *request;
> + int ret = 0;
> +
> + ret = ff_check_exec_params(ctx, DNN_TH, model->func_type,
> exec_params);
> + if (ret != 0) {
> + av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n");
> + return ret;
> + }
> +
> + task = (TaskItem *)av_malloc(sizeof(TaskItem));
> + if (!task) {
> + av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task
> item.\n");
> + return AVERROR(ENOMEM);
> + }
> +
> + ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
> + if (ret != 0) {
> + av_freep(&task);
> + av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
> + return ret;
> + }
> +
> + ret = ff_queue_push_back(th_model->task_queue, task);
> + if (ret < 0) {
> + av_freep(&task);
> + av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
> + return ret;
> + }
> +
> + ret = extract_lltask_from_task(task, th_model->lltask_queue);
> + if (ret != 0) {
> + av_log(ctx, AV_LOG_ERROR, "unable to extract last level task
> from task.\n");
> + return ret;
> + }
> +
> + request = (THRequestItem
> *)ff_safe_queue_pop_front(th_model->request_queue);
> + if (!request) {
> + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
> + return AVERROR(EINVAL);
> + }
> +
> + return execute_model_th(request, th_model->lltask_queue);
> +}
> +
> +static DNNAsyncStatusType dnn_get_result_th(const DNNModel *model,
> AVFrame **in, AVFrame **out)
> +{
> + THModel *th_model = (THModel *)model->model;
> + return ff_dnn_get_result_common(th_model->task_queue, in, out);
> +}
> +
> +static int dnn_flush_th(const DNNModel *model)
> +{
> + THModel *th_model = (THModel *)model->model;
> + THRequestItem *request;
> +
> + if (ff_queue_size(th_model->lltask_queue) == 0)
> + // no pending task need to flush
> + return 0;
> +
> + request = (THRequestItem
> *)ff_safe_queue_pop_front(th_model->request_queue);
> + if (!request) {
> + av_log(&th_model->ctx, AV_LOG_ERROR, "unable to get infer
> request.\n");
> + return AVERROR(EINVAL);
> + }
> +
> + return execute_model_th(request, th_model->lltask_queue);
> +}
> +
/*
 * Torch backend vtable, returned by ff_get_dnn_module() for DNN_TH.
 * "extern" is needed because a namespace-scope const object has internal
 * linkage in C++, while the symbol is referenced from C (dnn_interface.c).
 * NOTE(review): designated initializers are only standard from C++20; with
 * -std=c++14 this relies on a compiler extension, and GCC's C++ front end
 * expects the designators in declaration order - keep them ordered.
 */
extern const DNNModule ff_dnn_backend_torch = {
    .load_model = dnn_load_model_th,
    .execute_model = dnn_execute_model_th,
    .get_result = dnn_get_result_th,
    .flush = dnn_flush_th,
    .free_model = dnn_free_model_th,
};
> diff --git a/libavfilter/dnn/dnn_interface.c
> b/libavfilter/dnn/dnn_interface.c
> index e843826aa6..b9f71aea53 100644
> --- a/libavfilter/dnn/dnn_interface.c
> +++ b/libavfilter/dnn/dnn_interface.c
> @@ -28,6 +28,7 @@
>
> extern const DNNModule ff_dnn_backend_openvino;
> extern const DNNModule ff_dnn_backend_tf;
> +extern const DNNModule ff_dnn_backend_torch;
>
> const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void
> *log_ctx)
> {
> @@ -40,6 +41,10 @@ const DNNModule *ff_get_dnn_module(DNNBackendType
> backend_type, void *log_ctx)
> case DNN_OV:
> return &ff_dnn_backend_openvino;
> #endif
> + #if (CONFIG_LIBTORCH == 1)
> + case DNN_TH:
> + return &ff_dnn_backend_torch;
> + #endif
> default:
> av_log(log_ctx, AV_LOG_ERROR,
> "Module backend_type %d is not supported or
> enabled.\n",
> diff --git a/libavfilter/dnn_filter_common.c
> b/libavfilter/dnn_filter_common.c
> index f012d450a2..7d194c9ade 100644
> --- a/libavfilter/dnn_filter_common.c
> +++ b/libavfilter/dnn_filter_common.c
> @@ -53,12 +53,22 @@ static char **separate_output_names(const char
> *expr, const char *val_sep, int *
>
> int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type,
> AVFilterContext *filter_ctx)
> {
> + DNNBackendType backend = ctx->backend_type;
> +
> if (!ctx->model_filename) {
> av_log(filter_ctx, AV_LOG_ERROR, "model file for network is
> not specified\n");
> return AVERROR(EINVAL);
> }
>
> - if (ctx->backend_type == DNN_TF) {
> + if (backend == DNN_TH) {
> + if (ctx->model_inputname)
> + av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do
> not require inputname, "\
> + "inputname will be
> ignored.\n");
> + if (ctx->model_outputnames)
> + av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do
> not require outputname(s), "\
> + "all outputname(s) will
> be ignored.\n");
> + ctx->nb_outputs = 1;
> + } else if (backend == DNN_TF) {
> if (!ctx->model_inputname) {
> av_log(filter_ctx, AV_LOG_ERROR, "input name of the model
> network is not specified\n");
> return AVERROR(EINVAL);
> @@ -115,7 +125,8 @@ int ff_dnn_get_input(DnnContext *ctx, DNNData
> *input)
>
> int ff_dnn_get_output(DnnContext *ctx, int input_width, int
> input_height, int *output_width, int *output_height)
> {
> - char * output_name = ctx->model_outputnames ?
> ctx->model_outputnames[0] : NULL;
> + char * output_name = ctx->model_outputnames && ctx->backend_type
> != DNN_TH ?
> + ctx->model_outputnames[0] : NULL;
> return ctx->model->get_output(ctx->model->model,
> ctx->model_inputname, input_width, input_height,
> (const char *)output_name,
> output_width, output_height);
> }
> diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
> index 852d88baa8..63f492e690 100644
> --- a/libavfilter/dnn_interface.h
> +++ b/libavfilter/dnn_interface.h
> @@ -32,7 +32,7 @@
>
> #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
>
> -typedef enum {DNN_TF = 1, DNN_OV} DNNBackendType;
> +typedef enum {DNN_TF = 1, DNN_OV, DNN_TH} DNNBackendType;
>
> typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
>
> diff --git a/libavfilter/vf_dnn_processing.c
> b/libavfilter/vf_dnn_processing.c
> index e7d21eef32..fdac31665e 100644
> --- a/libavfilter/vf_dnn_processing.c
> +++ b/libavfilter/vf_dnn_processing.c
> @@ -50,6 +50,9 @@ static const AVOption dnn_processing_options[] = {
> #endif
> #if (CONFIG_LIBOPENVINO == 1)
> { "openvino", "openvino backend flag", 0,
> AV_OPT_TYPE_CONST, { .i64 = DNN_OV }, 0, 0, FLAGS, .unit =
> "backend" },
> +#endif
> +#if (CONFIG_LIBTORCH == 1)
> + { "torch", "torch backend flag", 0,
> AV_OPT_TYPE_CONST, { .i64 = DNN_TH }, 0, 0, FLAGS,
> "backend" },
> #endif
> DNN_COMMON_OPTIONS
> { NULL }
> --
> 2.34.1
>
> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel at ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request at ffmpeg.org with subject "unsubscribe".
--
Jean-Baptiste Kempf - President
+33 672 704 734
More information about the ffmpeg-devel
mailing list