[FFmpeg-devel] [PATCH 2/2] libavfi/dnn: add LibTorch as one of DNN backend
Ting Fu
ting.fu at intel.com
Mon May 23 12:29:18 EEST 2022
PyTorch is an open source machine learning framework that accelerates
the path from research prototyping to production deployment. Official
website: https://pytorch.org/. The C++ library of PyTorch is referred to
as LibTorch below.
To build FFmpeg with LibTorch, take the following steps as a reference:
1. Download the LibTorch C++ library from https://pytorch.org/get-started/locally/;
select C++/Java as the language and choose the other options as needed.
2. Unzip the archive into a directory of your choice:
unzip libtorch-shared-with-deps-latest.zip -d your_dir
3. Make the LibTorch headers in libtorch_root/libtorch/include and
libtorch_root/libtorch/include/torch/csrc/api/include visible to the compiler
(step 4 passes them with --extra-cflags), and add libtorch_root/libtorch/lib/
to $LD_LIBRARY_PATH so the shared libraries are found at run time.
4. Configure FFmpeg with ../configure --enable-libtorch --extra-cflags=-I/libtorch_root/libtorch/include --extra-cflags=-I/libtorch_root/libtorch/include/torch/csrc/api/include --extra-ldflags=-L/libtorch_root/libtorch/lib/
5. make
To run FFmpeg DNN inference with LibTorch backend:
./ffmpeg -i input.jpg -vf dnn_processing=dnn_backend=torch:model=LibTorch_model.pt -y output.jpg
The LibTorch_model.pt file can be generated in Python with the torch.jit.script() API. Note that torch.jit.trace() is not recommended, since a traced model does not support variable input sizes.
Signed-off-by: Ting Fu <ting.fu at intel.com>
---
configure | 7 +-
libavfilter/dnn/Makefile | 1 +
libavfilter/dnn/dnn_backend_torch.cpp | 567 ++++++++++++++++++++++++++
libavfilter/dnn/dnn_backend_torch.h | 47 +++
libavfilter/dnn/dnn_interface.c | 12 +
libavfilter/dnn/dnn_io_proc.c | 117 +++++-
libavfilter/dnn_filter_common.c | 31 +-
libavfilter/dnn_interface.h | 3 +-
libavfilter/vf_dnn_processing.c | 3 +
9 files changed, 774 insertions(+), 14 deletions(-)
create mode 100644 libavfilter/dnn/dnn_backend_torch.cpp
create mode 100644 libavfilter/dnn/dnn_backend_torch.h
diff --git a/configure b/configure
index f115b21064..85ce3e67a3 100755
--- a/configure
+++ b/configure
@@ -279,6 +279,7 @@ External library support:
--enable-libtheora enable Theora encoding via libtheora [no]
--enable-libtls enable LibreSSL (via libtls), needed for https support
if openssl, gnutls or mbedtls is not used [no]
+ --enable-libtorch enable Torch as one DNN backend
--enable-libtwolame enable MP2 encoding via libtwolame [no]
--enable-libuavs3d enable AVS3 decoding via libuavs3d [no]
--enable-libv4l2 enable libv4l2/v4l-utils [no]
@@ -1850,6 +1851,7 @@ EXTERNAL_LIBRARY_LIST="
libopus
libplacebo
libpulse
+ libtorch
librabbitmq
librav1e
librist
@@ -2719,7 +2721,7 @@ dct_select="rdft"
deflate_wrapper_deps="zlib"
dirac_parse_select="golomb"
dovi_rpu_select="golomb"
-dnn_suggest="libtensorflow libopenvino"
+dnn_suggest="libtensorflow libopenvino libtorch"
dnn_deps="avformat swscale"
error_resilience_select="me_cmp"
faandct_deps="faan"
@@ -6600,6 +6602,7 @@ enabled libopus && {
}
enabled libplacebo && require_pkg_config libplacebo "libplacebo >= 4.192.0" libplacebo/vulkan.h pl_vulkan_create
enabled libpulse && require_pkg_config libpulse libpulse pulse/pulseaudio.h pa_context_new
+enabled libtorch && add_cppflags -D_GLIBCXX_USE_CXX11_ABI=0 && check_cxxflags -std=c++14 && require_cpp libtorch torch/torch.h "torch::Tensor" -ltorch -lc10 -ltorch_cpu -lstdc++ -lpthread
enabled librabbitmq && require_pkg_config librabbitmq "librabbitmq >= 0.7.1" amqp.h amqp_new_connection
enabled librav1e && require_pkg_config librav1e "rav1e >= 0.4.0" rav1e.h rav1e_context_new
enabled librist && require_pkg_config librist "librist >= 0.2" librist/librist.h rist_receiver_create
@@ -7025,6 +7028,8 @@ check_disable_warning -Wno-pointer-sign
check_disable_warning -Wno-unused-const-variable
check_disable_warning -Wno-bool-operation
check_disable_warning -Wno-char-subscripts
+# suppress the redundant-decls warnings raised when compiling against the LibTorch headers
+check_disable_warning -Wno-redundant-decls
check_disable_warning_headers(){
warning_flag=-W${1#-Wno-}
diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
index 4cfbce0efc..d44dcb847e 100644
--- a/libavfilter/dnn/Makefile
+++ b/libavfilter/dnn/Makefile
@@ -16,5 +16,6 @@ OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layer_mat
DNN-OBJS-$(CONFIG_LIBTENSORFLOW) += dnn/dnn_backend_tf.o
DNN-OBJS-$(CONFIG_LIBOPENVINO) += dnn/dnn_backend_openvino.o
+DNN-OBJS-$(CONFIG_LIBTORCH) += dnn/dnn_backend_torch.o
OBJS-$(CONFIG_DNN) += $(DNN-OBJS-yes)
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
new file mode 100644
index 0000000000..86cc018fbc
--- /dev/null
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -0,0 +1,567 @@
+/*
+ * Copyright (c) 2022
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN Torch backend implementation.
+ */
+
+#include <torch/torch.h>
+#include <torch/script.h>
+#include "dnn_backend_torch.h"
+
+extern "C" {
+#include "dnn_io_proc.h"
+#include "../internal.h"
+#include "dnn_backend_common.h"
+#include "libavutil/opt.h"
+#include "queue.h"
+#include "safe_queue.h"
+}
+/* Backend options, filled from the "options" string by av_opt_set_from_string() in ff_dnn_load_model_th. */
+typedef struct THOptions{
+ char *device_name;                 // device string, default "cpu" (see dnn_th_options)
+ c10::DeviceType device_type;       // derived from device_name at load time (only torch::kCPU so far)
+} THOptions;
+/* Logging and AVOptions context; c_class is the first member so av_log()/av_opt_*() can use it. */
+typedef struct THContext {
+ const AVClass *c_class;
+ THOptions options;
+} THContext;
+/* Per-model state. NOTE(review): jit_model is a C++ object, so a THModel must be constructed/destructed, not only av_mallocz'd/av_freep'd. */
+typedef struct THModel {
+ THContext ctx;
+ DNNModel *model;                   // back-pointer to the generic DNNModel wrapper
+ torch::jit::Module jit_model;      // module returned by torch::jit::load()
+ SafeQueue *request_queue;          // pool of idle THRequestItem
+ Queue *task_queue;                 // TaskItem queue passed to ff_dnn_get_result_common()
+ Queue *lltask_queue;               // LastLevelTaskItem waiting for inference
+} THModel;
+/* Tensors for one inference, heap-allocated with new and released by th_free_request(). */
+typedef struct THInferRequest {
+ torch::Tensor *output;
+ torch::Tensor *input_tensor;
+} THInferRequest;
+/* A reusable inference slot, cycled through THModel.request_queue. */
+typedef struct THRequestItem {
+ THInferRequest *infer_request;
+ LastLevelTaskItem *lltask;         // lltask currently bound to this request, if any
+ DNNAsyncExecModule exec_module;
+} THRequestItem;
+
+
+#define OFFSET(x) offsetof(THContext, x)   // option offsets are relative to THContext
+#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
+static const AVOption dnn_th_options[] = {  // parsed from the backend options string in ff_dnn_load_model_th
+ { "device", "device to run model", OFFSET(options.device_name), AV_OPT_TYPE_STRING, { .str = "cpu" }, 0, 0, FLAGS },
+ { NULL }
+};
+/* Defines dnn_th_class from dnn_th_options; installed as THContext.c_class in ff_dnn_load_model_th. */
+AVFILTER_DEFINE_CLASS(dnn_th);
+/* Forward declarations: these helpers are referenced before their definitions below. */
+static int execute_model_th(THRequestItem *request, Queue *lltask_queue);
+static int th_start_inference(void *args);
+static void infer_completion_callback(void *args);
+/* Wrap task in a single LastLevelTaskItem (one inference per task) and append it to lltask_queue. */
+static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue)
+{
+    THContext *log_ctx = &((THModel *)task->model)->ctx;
+    LastLevelTaskItem *item = (LastLevelTaskItem *)av_malloc(sizeof(*item));
+
+    if (!item) {
+        av_log(log_ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
+        return AVERROR(ENOMEM);
+    }
+    item->task = task;
+    task->inference_done = 0;
+    task->inference_todo = 1;
+
+    if (ff_queue_push_back(lltask_queue, item) >= 0)
+        return 0;
+    av_log(log_ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
+    av_freep(&item);
+    return AVERROR(ENOMEM);
+}
+/* Fixed input description: 3-channel planar RGB float; width/height are -1 here and set from each frame. */
+static int get_input_th(void *model, DNNData *input, const char *input_name)
+{
+    input->channels = 3;
+    input->width = -1;
+    input->height = -1;
+    input->order = DCO_RGB_PLANAR;
+    input->dt = DNN_FLOAT;
+    return 0;
+}
+/* Run one inference on an input_width x input_height frame to discover the model's output frame size. */
+static int get_output_th(void *model, const char *input_name, int input_width, int input_height,
+                         const char *output_name, int *output_width, int *output_height)
+{
+    int ret = 0;
+    THModel *th_model = (THModel*) model;
+    THContext *ctx = &th_model->ctx;
+    TaskItem task = { 0 };   // zeroed so the av_frame_free() calls at err are safe on early failure
+    THRequestItem *request;
+    DNNExecBaseParams exec_params = {
+        .input_name = input_name,
+        .output_names = &output_name,
+        .nb_output = 1,
+        .in_frame = NULL,
+        .out_frame = NULL,
+    };
+    ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx);
+    if (ret != 0) {
+        goto err;
+    }
+
+    ret = extract_lltask_from_task(&task, th_model->lltask_queue);
+    if (ret != 0) {
+        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
+        goto err;
+    }
+
+    request = (THRequestItem*) ff_safe_queue_pop_front(th_model->request_queue);
+    if (!request) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        ret = AVERROR(EINVAL);
+        goto err;
+    }
+
+    ret = execute_model_th(request, th_model->lltask_queue);
+    *output_width = task.out_frame->width;
+    *output_height = task.out_frame->height;
+
+err:
+    av_frame_free(&task.out_frame);
+    av_frame_free(&task.in_frame);
+    return ret;
+}
+/*
+ * Release the tensors held by request; the THInferRequest itself is kept so it can be reused.
+ */
+static void th_free_request(THInferRequest *request)
+{
+    if (request == NULL)
+        return;
+
+    // delete on a null pointer is a no-op, so the fields need no individual checks
+    delete request->output;
+    request->output = NULL;
+
+    delete request->input_tensor;
+    request->input_tensor = NULL;
+}
+/* Free a request item with its tensors, its lltask and its async module state; *arg becomes NULL. */
+static inline void destroy_request_item(THRequestItem **arg)
+{
+    THRequestItem *req = arg ? *arg : NULL;
+
+    if (req == NULL)
+        return;
+
+    th_free_request(req->infer_request);
+    av_freep(&req->infer_request);
+    av_freep(&req->lltask);
+    ff_dnn_async_module_cleanup(&req->exec_module);
+    av_freep(arg);
+}
+/* Allocate an inference request with no tensors attached yet; NULL on allocation failure. */
+static THInferRequest *th_create_inference_request(void)
+{
+    THInferRequest *req = (THInferRequest *)av_malloc(sizeof(*req));
+
+    if (req != NULL) {
+        req->output = NULL;
+        req->input_tensor = NULL;
+    }
+    return req;
+}
+/* Load a TorchScript model (only the "cpu" device is supported) and create its queues; NULL on failure. */
+DNNModel *ff_dnn_load_model_th(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
+{
+    DNNModel *model = NULL;
+    THModel *th_model = NULL;
+    THRequestItem *item = NULL;
+    THContext *ctx;
+
+    model = (DNNModel *)av_mallocz(sizeof(DNNModel));
+    if (!model) {
+        return NULL;
+    }
+
+    th_model = (THModel *)av_mallocz(sizeof(THModel));
+    if (!th_model) {
+        av_freep(&model);
+        return NULL;
+    }
+    new (th_model) THModel();   // jit_model is a C++ object: run the constructor av_mallocz cannot
+    th_model->ctx.c_class = &dnn_th_class;
+    ctx = &th_model->ctx;
+    //parse options
+    av_opt_set_defaults(ctx);
+    if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
+        goto fail;   // release everything allocated so far
+    }
+
+    // c10::Device and torch::jit::load report errors by throwing; exceptions must not reach C code
+    try {
+        c10::Device device = c10::Device(ctx->options.device_name);
+        if (!device.is_cpu()) {
+            av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", ctx->options.device_name);
+            goto fail;
+        }
+        ctx->options.device_type = torch::kCPU;
+        th_model->jit_model = torch::jit::load(model_filename, device);
+    } catch (const std::exception &e) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model: %s\n", e.what());
+        goto fail;
+    }
+
+    th_model->request_queue = ff_safe_queue_create();
+    if (!th_model->request_queue) {
+        goto fail;
+    }
+
+    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
+    if (!item) {
+        goto fail;
+    }
+    item->lltask = NULL;
+    item->infer_request = th_create_inference_request();
+    if (!item->infer_request) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n");
+        goto fail;
+    }
+    item->exec_module.start_inference = &th_start_inference;
+    item->exec_module.callback = &infer_completion_callback;
+    item->exec_module.args = item;
+
+    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
+        goto fail;
+    }
+
+    th_model->task_queue = ff_queue_create();
+    if (!th_model->task_queue) {
+        goto fail;
+    }
+
+    th_model->lltask_queue = ff_queue_create();
+    if (!th_model->lltask_queue) {
+        goto fail;
+    }
+
+    th_model->model = model;
+    model->model = th_model;
+    model->get_input = &get_input_th;
+    model->get_output = &get_output_th;
+    model->options = NULL;
+    model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
+    return model;
+
+fail:
+    destroy_request_item(&item);
+    ff_queue_destroy(th_model->task_queue);
+    ff_queue_destroy(th_model->lltask_queue);
+    ff_safe_queue_destroy(th_model->request_queue);
+    av_opt_free(ctx);          // frees options.device_name
+    th_model->~THModel();      // pairs with the placement new above
+    av_freep(&th_model);
+    av_freep(&model);
+    return NULL;
+}
+/* Pop the next lltask, convert its input frame into a float tensor and attach it to request. */
+static int fill_model_input_th(THModel *th_model, THRequestItem *request)
+{
+    LastLevelTaskItem *lltask = NULL;
+    TaskItem *task = NULL;
+    THInferRequest *infer_request = NULL;
+    DNNData input;
+    THContext *ctx = &th_model->ctx;
+    int ret;
+
+    lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+    if (!lltask) {
+        ret = AVERROR(EINVAL);
+        goto err;
+    }
+    request->lltask = lltask;
+    task = lltask->task;
+    infer_request = request->infer_request;
+
+    ret = get_input_th(th_model, &input, NULL);
+    if (ret != 0)
+        goto err;
+
+    input.height = task->in_frame->height;
+    input.width = task->in_frame->width;
+    // planar float RGB buffer (overflow-checked); ownership passes to the tensor's deleter below
+    input.data = av_malloc_array((size_t)input.height * input.width, 3 * sizeof(float));
+    if (!input.data)
+        return AVERROR(ENOMEM);
+    infer_request->input_tensor = new torch::Tensor();
+    infer_request->output = new torch::Tensor();
+
+    switch (th_model->model->func_type) {
+    case DFT_PROCESS_FRAME:
+        if (task->do_ioproc) {
+            if (th_model->model->frame_pre_proc != NULL)
+                th_model->model->frame_pre_proc(task->in_frame, &input, th_model->model->filter_ctx);
+            else
+                ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
+        }
+        break;
+    default:
+        avpriv_report_missing_feature(ctx, "model function type %d", th_model->model->func_type);
+        break;
+    }
+    // from_blob() does not own the memory by itself; the deleter frees it when the last tensor reference goes away
+    *infer_request->input_tensor = torch::from_blob(input.data, {1, 1, 3, input.height, input.width},
+                                                    [](void *data) { av_free(data); }, torch::kFloat32);
+    return 0;
+
+err:
+    th_free_request(infer_request);
+    return ret;
+}
+/* Run the TorchScript module synchronously on the request's input tensor; 0 on success. */
+static int th_start_inference(void *args)
+{
+    THRequestItem *request = (THRequestItem *)args;
+    THInferRequest *infer_request = NULL;
+    LastLevelTaskItem *lltask = NULL;
+    TaskItem *task = NULL;
+    THModel *th_model = NULL;
+    THContext *ctx = NULL;
+    std::vector<torch::jit::IValue> inputs;
+
+    if (!request) {
+        av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
+        return AVERROR(EINVAL);
+    }
+    infer_request = request->infer_request;
+    lltask = request->lltask;
+    task = lltask->task;
+    th_model = (THModel *)task->model;
+    ctx = &th_model->ctx;
+
+    if (!infer_request->input_tensor || !infer_request->output) {
+        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
+        return DNN_GENERIC_ERROR;
+    }
+    inputs.push_back(*infer_request->input_tensor);
+    try {   // forward()/toTensor() throw on model errors; exceptions must not escape into C callers
+        *infer_request->output = th_model->jit_model.forward(inputs).toTensor();
+    } catch (const std::exception &e) {
+        av_log(ctx, AV_LOG_ERROR, "Torch forward failed: %s\n", e.what());
+        return DNN_GENERIC_ERROR;
+    }
+    return 0;
+}
+/* Convert the output tensor of a finished request into task->out_frame, then return the request to the pool. */
+static void infer_completion_callback(void *args) {
+    THRequestItem *request = (THRequestItem*)args;
+    LastLevelTaskItem *lltask = request->lltask;
+    TaskItem *task = lltask->task;
+    DNNData outputs;
+    THInferRequest *infer_request = request->infer_request;
+    THModel *th_model = (THModel *)task->model;
+    torch::Tensor *output = infer_request->output;
+    if (output->dim() != 5 || output->size(2) != 3) {   // shape depends on the user model: check at run time
+        av_log(&th_model->ctx, AV_LOG_ERROR, "Expected a 5-D output tensor with 3 channels\n");
+        goto err;
+    }
+    *output = output->to(torch::kFloat32).contiguous();   // data_ptr() below needs dense float storage
+    outputs.order = DCO_RGB_PLANAR;
+    outputs.height = output->size(3);
+    outputs.width = output->size(4);
+    outputs.dt = DNN_FLOAT;
+    outputs.channels = 3;
+    switch (th_model->model->func_type) {
+    case DFT_PROCESS_FRAME:
+        if (task->do_ioproc) {
+            outputs.data = output->data_ptr();
+            if (th_model->model->frame_post_proc != NULL)
+                th_model->model->frame_post_proc(task->out_frame, &outputs, th_model->model->filter_ctx);
+            else
+                ff_proc_from_dnn_to_frame(task->out_frame, &outputs, &th_model->ctx);
+        } else {
+            task->out_frame->width = outputs.width;
+            task->out_frame->height = outputs.height;
+        }
+        break;
+    default:
+        avpriv_report_missing_feature(&th_model->ctx, "model function type %d", th_model->model->func_type);
+        goto err;
+    }
+    task->inference_done++;
+err:
+    th_free_request(infer_request);
+    av_freep(&request->lltask);   // popped from lltask_queue earlier; owned by this request
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+        destroy_request_item(&request);
+        av_log(&th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue when failed to start inference.\n");
+    }
+}
+/* Fill and run the next queued lltask on request synchronously; the request goes back to the pool on error. */
+static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
+{
+    THModel *th_model = NULL;
+    LastLevelTaskItem *lltask;
+    TaskItem *task = NULL;
+    int ret = 0;
+
+    if (ff_queue_size(lltask_queue) == 0) {
+        destroy_request_item(&request);
+        return 0;
+    }
+
+    lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
+    if (lltask == NULL) {
+        // th_model is unknown here, so the err path (which dereferences it) must not be taken
+        av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n");
+        destroy_request_item(&request);
+        return AVERROR(EINVAL);
+    }
+    task = lltask->task;
+    th_model = (THModel *)task->model;
+
+    ret = fill_model_input_th(th_model, request);
+    if (ret != 0)
+        goto err;
+    if (task->async) {
+        avpriv_report_missing_feature(&th_model->ctx, "LibTorch async");
+        ret = AVERROR(ENOSYS);   // report the failure instead of returning 0
+    } else {
+        ret = th_start_inference((void *)(request));
+        if (ret != 0)
+            goto err;
+        infer_completion_callback(request);
+        return (task->inference_done == task->inference_todo) ? 0 : DNN_GENERIC_ERROR;
+    }
+err:
+    th_free_request(request->infer_request);
+    av_freep(&request->lltask);   // popped from lltask_queue by fill_model_input_th, owned by the request
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+        destroy_request_item(&request);
+    }
+    return ret;
+}
+/* Public entry: wrap exec_params in a TaskItem, queue it and run it synchronously. */
+int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params)
+{
+    THModel *th_model = (THModel *)model->model;
+    THContext *ctx = &th_model->ctx;
+    THRequestItem *req;
+    TaskItem *new_task;
+    int err;
+
+    err = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params);
+    if (err != 0)
+        return err;
+
+    new_task = (TaskItem *)av_malloc(sizeof(*new_task));
+    if (new_task == NULL) {
+        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
+        return AVERROR(ENOMEM);
+    }
+
+    // async = 0, do_ioproc = 1: execution in this backend is synchronous
+    err = ff_dnn_fill_task(new_task, exec_params, th_model, 0, 1);
+    if (err != 0) {
+        av_freep(&new_task);
+        av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
+        return err;
+    }
+
+    err = ff_queue_push_back(th_model->task_queue, new_task);
+    if (err < 0) {
+        av_freep(&new_task);
+        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
+        return err;
+    }
+
+    // from here on the task is owned by task_queue
+    err = extract_lltask_from_task(new_task, th_model->lltask_queue);
+    if (err != 0) {
+        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
+        return err;
+    }
+
+    req = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+    if (req == NULL) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return AVERROR(EINVAL);
+    }
+    return execute_model_th(req, th_model->lltask_queue);
+}
+
+/* Flush: if an lltask is still queued, take a request from the pool and run it now.
+ * Returns 0 when nothing is pending and AVERROR(EINVAL) when no request is available. */
+int ff_dnn_flush_th(const DNNModel *model)
+{
+    THModel *th_model = (THModel *)model->model;
+    THRequestItem *req;
+
+    if (ff_queue_size(th_model->lltask_queue) == 0)
+        return 0;   // no pending task need to flush
+
+    req = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+    if (req == NULL) {
+        av_log(&th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return AVERROR(EINVAL);
+    }
+
+    return execute_model_th(req, th_model->lltask_queue);
+}
+/* Delegate to ff_dnn_get_result_common() on this model's task_queue. */
+DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out)
+{
+    Queue *tasks = ((THModel *)model->model)->task_queue;
+    return ff_dnn_get_result_common(tasks, in, out);
+}
+/* Free *model and everything it owns (request pool, pending tasks, queues, options); *model becomes NULL. */
+void ff_dnn_free_model_th(DNNModel **model)
+{
+    THModel *th_model = *model ? (THModel *)(*model)->model : NULL;   // never left uninitialized
+    if (th_model) {
+        while (ff_safe_queue_size(th_model->request_queue) != 0) {
+            THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+            destroy_request_item(&item);
+        }
+        ff_safe_queue_destroy(th_model->request_queue);
+        while (ff_queue_size(th_model->lltask_queue) != 0) {
+            LastLevelTaskItem *item = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+            av_freep(&item);
+        }
+        ff_queue_destroy(th_model->lltask_queue);
+        while (ff_queue_size(th_model->task_queue) != 0) {
+            TaskItem *item = (TaskItem *)ff_queue_pop_front(th_model->task_queue);
+            av_frame_free(&item->in_frame);
+            av_frame_free(&item->out_frame);
+            av_freep(&item);
+        }
+        ff_queue_destroy(th_model->task_queue);    // destroy the queue itself, not only its items
+        av_opt_free(&th_model->ctx);               // frees options.device_name
+        th_model->~THModel();                      // runs the jit_model destructor
+    }
+    av_freep(&th_model);
+    av_freep(model);
+}
diff --git a/libavfilter/dnn/dnn_backend_torch.h b/libavfilter/dnn/dnn_backend_torch.h
new file mode 100644
index 0000000000..5d6a08f85f
--- /dev/null
+++ b/libavfilter/dnn/dnn_backend_torch.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN inference functions interface for Torch backend.
+ */
+
+#ifndef AVFILTER_DNN_DNN_BACKEND_TORCH_H
+#define AVFILTER_DNN_DNN_BACKEND_TORCH_H
+
+/* The backend is written in C++; the declarations below get C linkage so C code can call them. */
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include "../dnn_interface.h"
+/* Load a TorchScript model; options is an "&"-separated list of key=value pairs (e.g. "device=cpu"). NULL on failure. */
+DNNModel *ff_dnn_load_model_th(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
+/* Queue one frame and run it synchronously; the finished task is retrieved with ff_dnn_get_result_th(). */
+int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params);
+DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out);
+int ff_dnn_flush_th(const DNNModel *model);
+/* Free the model and everything it owns; *model is set to NULL. */
+void ff_dnn_free_model_th(DNNModel **model);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* AVFILTER_DNN_DNN_BACKEND_TORCH_H */
diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index 554a36b0dc..6f4e02b481 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -27,6 +27,7 @@
#include "dnn_backend_native.h"
#include "dnn_backend_tf.h"
#include "dnn_backend_openvino.h"
+#include "dnn_backend_torch.h"
#include "libavutil/mem.h"
DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
@@ -70,6 +71,17 @@ DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
return NULL;
#endif
break;
+ case DNN_TH:
+ #if (CONFIG_LIBTORCH == 1)
+ dnn_module->load_model = &ff_dnn_load_model_th;
+ dnn_module->execute_model = &ff_dnn_execute_model_th;
+ dnn_module->get_result = &ff_dnn_get_result_th;
+ dnn_module->flush = &ff_dnn_flush_th;
+ dnn_module->free_model = &ff_dnn_free_model_th;
+ #else
+ av_freep(&dnn_module);
+ #endif
+ break;
default:
av_log(NULL, AV_LOG_ERROR, "Module backend_type is not native or tensorflow\n");
av_freep(&dnn_module);
diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c
index 532b089002..cbaa1e601f 100644
--- a/libavfilter/dnn/dnn_io_proc.c
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -24,10 +24,20 @@
#include "libavutil/avassert.h"
#include "libavutil/detection_bbox.h"
+static enum AVPixelFormat get_pixel_format(DNNData *data);
+
int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
{
struct SwsContext *sws_ctx;
+ int frame_size = frame->height * frame->width;
+ int linesize[3];
+ void **dst_data, *middle_data;
+ enum AVPixelFormat fmt;
int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+ linesize[0] = frame->linesize[0];
+ dst_data = (void **)frame->data;
+ fmt = get_pixel_format(output);
+
if (bytewidth < 0) {
return AVERROR(EINVAL);
}
@@ -35,6 +45,18 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
avpriv_report_missing_feature(log_ctx, "data type rather than DNN_FLOAT");
return AVERROR(ENOSYS);
}
+ if (fmt == AV_PIX_FMT_GBRP) {
+ middle_data = malloc(frame_size * 3 * sizeof(uint8_t));
+ if (!middle_data) {
+ av_log(log_ctx, AV_LOG_ERROR, "Failed to malloc memory for middle_data for "
+ "the conversion fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+ av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width, frame->height,
+ av_get_pix_fmt_name(AV_PIX_FMT_GRAY8),frame->width, frame->height);
+ return AVERROR(EINVAL);
+ }
+ dst_data = &middle_data;
+ linesize[0] = frame->width * 3;
+ }
switch (frame->format) {
case AV_PIX_FMT_RGB24:
@@ -51,12 +73,43 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
"fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width * 3, frame->height,
av_get_pix_fmt_name(AV_PIX_FMT_GRAY8), frame->width * 3, frame->height);
+ av_freep(&middle_data);
return AVERROR(EINVAL);
}
sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
(const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height,
- (uint8_t * const*)frame->data, frame->linesize);
+ (uint8_t * const*)dst_data, linesize);
sws_freeContext(sws_ctx);
+ switch (fmt) {
+ case AV_PIX_FMT_GBRP:
+ sws_ctx = sws_getContext(frame->width,
+ frame->height,
+ AV_PIX_FMT_GBRP,
+ frame->width,
+ frame->height,
+ frame->format,
+ 0, NULL, NULL, NULL);
+ if (!sws_ctx) {
+ av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion "
+ "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+ av_get_pix_fmt_name(AV_PIX_FMT_GBRP), frame->width, frame->height,
+ av_get_pix_fmt_name(frame->format),frame->width, frame->height);
+ av_freep(&middle_data);
+ return AVERROR(EINVAL);
+ }
+ sws_scale(sws_ctx, (const uint8_t * const[4]){(uint8_t *)dst_data[0] + frame_size * sizeof(uint8_t),
+ (uint8_t *)dst_data[0] + frame_size * sizeof(uint8_t) * 2,
+ (uint8_t *)dst_data[0], 0},
+ (const int [4]){frame->width * sizeof(uint8_t),
+ frame->width * sizeof(uint8_t),
+ frame->width * sizeof(uint8_t), 0}
+ , 0, frame->height,
+ (uint8_t * const*)frame->data, frame->linesize);
+ break;
+ default:
+ break;
+ }
+ av_freep(&middle_data);
return 0;
case AV_PIX_FMT_GRAYF32:
av_image_copy_plane(frame->data[0], frame->linesize[0],
@@ -101,6 +154,14 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
{
struct SwsContext *sws_ctx;
int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+ int frame_size = frame->height * frame->width;
+ int linesize[3];
+ void **src_data, *middle_data = NULL;
+ enum AVPixelFormat fmt;
+ linesize[0] = frame->linesize[0];
+ src_data = (void **)frame->data;
+ fmt = get_pixel_format(input);
+
if (bytewidth < 0) {
return AVERROR(EINVAL);
}
@@ -112,6 +173,46 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
switch (frame->format) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
+ switch (fmt) {
+ case AV_PIX_FMT_GBRP:
+ middle_data = av_malloc(frame_size * 3 * sizeof(uint8_t));
+ if (!middle_data) {
+ av_log(log_ctx, AV_LOG_ERROR, "Failed to malloc memory for middle_data for "
+ "the conversion fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+ av_get_pix_fmt_name(frame->format), frame->width, frame->height,
+ av_get_pix_fmt_name(AV_PIX_FMT_GBRP),frame->width, frame->height);
+ return AVERROR(EINVAL);
+ }
+ sws_ctx = sws_getContext(frame->width,
+ frame->height,
+ frame->format,
+ frame->width,
+ frame->height,
+ AV_PIX_FMT_GBRP,
+ 0, NULL, NULL, NULL);
+ if (!sws_ctx) {
+ av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion "
+ "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+ av_get_pix_fmt_name(frame->format), frame->width, frame->height,
+ av_get_pix_fmt_name(AV_PIX_FMT_GBRP),frame->width, frame->height);
+ av_freep(&middle_data);
+ return AVERROR(EINVAL);
+ }
+ sws_scale(sws_ctx, (const uint8_t **)frame->data,
+ frame->linesize, 0, frame->height,
+ (uint8_t * const [4]){(uint8_t *)middle_data + frame_size * sizeof(uint8_t),
+ (uint8_t *)middle_data + frame_size * sizeof(uint8_t) * 2,
+ (uint8_t *)middle_data, 0},
+ (const int [4]){frame->width * sizeof(uint8_t),
+ frame->width * sizeof(uint8_t),
+ frame->width * sizeof(uint8_t), 0});
+ sws_freeContext(sws_ctx);
+ src_data = &middle_data;
+ linesize[0] = frame->width * 3;
+ break;
+ default:
+ break;
+ }
sws_ctx = sws_getContext(frame->width * 3,
frame->height,
AV_PIX_FMT_GRAY8,
@@ -124,13 +225,15 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
"fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
av_get_pix_fmt_name(AV_PIX_FMT_GRAY8), frame->width * 3, frame->height,
av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32),frame->width * 3, frame->height);
+ av_freep(&middle_data);
return AVERROR(EINVAL);
}
- sws_scale(sws_ctx, (const uint8_t **)frame->data,
- frame->linesize, 0, frame->height,
+ sws_scale(sws_ctx, (const uint8_t **)src_data,
+ linesize, 0, frame->height,
(uint8_t * const [4]){input->data, 0, 0, 0},
(const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
sws_freeContext(sws_ctx);
+ av_freep(&middle_data);
break;
case AV_PIX_FMT_GRAYF32:
av_image_copy_plane(input->data, bytewidth,
@@ -184,6 +287,14 @@ static enum AVPixelFormat get_pixel_format(DNNData *data)
av_assert0(!"unsupported data pixel format.\n");
return AV_PIX_FMT_BGR24;
}
+ } else if (data->dt == DNN_FLOAT) {
+ switch (data->order) {
+ case DCO_RGB_PLANAR:
+ return AV_PIX_FMT_GBRP;
+ default:
+ av_assert0(!"unsupported data pixel format.\n");
+ return AV_PIX_FMT_GBRP;
+ }
}
av_assert0(!"unsupported data type.\n");
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index 5083e3de19..a4e1147fb9 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -53,19 +53,31 @@ static char **separate_output_names(const char *expr, const char *val_sep, int *
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
{
+ DNNBackendType backend = ctx->backend_type;
+
if (!ctx->model_filename) {
av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
return AVERROR(EINVAL);
}
- if (!ctx->model_inputname) {
- av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
- return AVERROR(EINVAL);
- }
- ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
- if (!ctx->model_outputnames) {
- av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
- return AVERROR(EINVAL);
+ if (backend == DNN_TH) {
+ if (ctx->model_inputname)
+ av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do not require inputname, "\
+ "inputname will be ignored.\n");
+ if (ctx->model_outputnames)
+ av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend do not require outputname(s), "\
+ "all outputname(s) will be ignored.\n");
+ ctx->nb_outputs = 1;
+ } else {
+ if (!ctx->model_inputname) {
+ av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
+ return AVERROR(EINVAL);
+ }
+ ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
+ if (!ctx->model_outputnames) {
+ av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
+ return AVERROR(EINVAL);
+ }
}
ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
@@ -113,8 +125,9 @@ int ff_dnn_get_input(DnnContext *ctx, DNNData *input)
int ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
{
+ const char *model_outputnames = ctx->backend_type == DNN_TH ? NULL : ctx->model_outputnames[0]; // ff_dnn_init() does not parse output names for DNN_TH, so model_outputnames must not be indexed
return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
- (const char *)ctx->model_outputnames[0], output_width, output_height);
+ model_outputnames, output_width, output_height);
}
int ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index d94baa90c4..32698f788b 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,7 +32,7 @@
#define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
-typedef enum {DNN_NATIVE, DNN_TF, DNN_OV} DNNBackendType;
+typedef enum {DNN_NATIVE, DNN_TF, DNN_OV, DNN_TH} DNNBackendType;
typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
@@ -40,6 +40,7 @@ typedef enum {
DCO_NONE,
DCO_BGR_PACKED,
DCO_RGB_PACKED,
+ DCO_RGB_PLANAR,
} DNNColorOrder;
typedef enum {
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index cac096a19f..ac1dc6e1d9 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -52,6 +52,9 @@ static const AVOption dnn_processing_options[] = {
#endif
#if (CONFIG_LIBOPENVINO == 1)
{ "openvino", "openvino backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 2 }, 0, 0, FLAGS, "backend" },
+#endif
+#if (CONFIG_LIBTORCH == 1)
+ { "torch", "torch backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 3 }, 0, 0, FLAGS, "backend" },
#endif
DNN_COMMON_OPTIONS
{ NULL }
--
2.17.1
More information about the ffmpeg-devel
mailing list