[FFmpeg-devel] [PATCH] libavfilter/dnn: enable LibTorch xpu device option support

wenbin.chen at intel.com wenbin.chen at intel.com
Mon Jun 3 08:09:35 EEST 2024


From: Wenbin Chen <wenbin.chen at intel.com>

Add xpu device support to the libtorch backend.
To enable xpu support, you need to add
 "-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed" to
"--extra-libs" when configuring ffmpeg.

Signed-off-by: Wenbin Chen <wenbin.chen at intel.com>
---
 libavfilter/dnn/dnn_backend_torch.cpp | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 2557264713..ea493f5873 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -250,6 +250,10 @@ static int th_start_inference(void *args)
         av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
         return DNN_GENERIC_ERROR;
     }
+    // Transfer tensor to the same device as model
+    c10::Device device = (*th_model->jit_model->parameters().begin()).device();
+    if (infer_request->input_tensor->device() != device)
+        *infer_request->input_tensor = infer_request->input_tensor->to(device);
     inputs.push_back(*infer_request->input_tensor);
 
     *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
@@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) {
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
+            // Post process can only deal with CPU memory.
+            if (output->device() != torch::kCPU)
+                *output = output->to(torch::kCPU);
             outputs.scale = 255;
             outputs.data = output->data_ptr();
             if (th_model->model.frame_post_proc != NULL) {
@@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
     th_model->ctx = ctx;
 
     c10::Device device = c10::Device(device_name);
-    if (!device.is_cpu()) {
+    if (device.is_xpu()) {
+        if (!at::hasXPU()) {
+            av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
+            goto fail;
+        }
+        at::detail::getXPUHooks().initXPU();
+    } else if (!device.is_cpu()) {
         av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
         goto fail;
     }
@@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
     try {
         th_model->jit_model = new torch::jit::Module;
         (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
+        th_model->jit_model->to(device);
     } catch (const c10::Error& e) {
         av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
         goto fail;
-- 
2.34.1



More information about the ffmpeg-devel mailing list