From 861974ad3499601bda7ce4513c0bc34607ab2e16 Mon Sep 17 00:00:00 2001
From: CaoE <e.cao@intel.com>
Date: Mon, 22 Aug 2022 13:57:49 +0800
Subject: [PATCH] do contiguous for q, k, and v to get better performance for
 normalize

---
 basicsr/models/archs/restormer_arch.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/basicsr/models/archs/restormer_arch.py b/basicsr/models/archs/restormer_arch.py
index a41221e..1874ba8 100644
--- a/basicsr/models/archs/restormer_arch.py
+++ b/basicsr/models/archs/restormer_arch.py
@@ -113,10 +113,10 @@ def forward(self, x):
 
         qkv = self.qkv_dwconv(self.qkv(x))
         q,k,v = qkv.chunk(3, dim=1)   
-        
-        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
-        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
-        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
+        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
+        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
 
         q = torch.nn.functional.normalize(q, dim=-1)
         k = torch.nn.functional.normalize(k, dim=-1)