Skip to content

Commit d15357d

Browse files
q10facebook-github-bot
authored andcommitted
Relax the checks for dimensions of pooled_embs (#4165)
Summary: Pull Request resolved: #4165 - Relax the checks for dimensions of pooled_embs Reviewed By: spcyppt Differential Revision: D75158424 fbshipit-source-id: 175bf88283d0a566bba7a97ada36194602e44b8e
1 parent e8284e2 commit d15357d

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ Tensor permute_pooled_embs_gpu_impl(
7070
}
7171

7272
TORCH_CHECK(
73-
pooled_embs.dim() == 2,
74-
"pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
73+
pooled_embs.dim() >= 2,
74+
"pooled_embs must be at least a 2-D tensor of size [B_local][Sum_T_global(D)], "
7575
"current shape is: ",
7676
torch_tensor_shape_str(pooled_embs));
7777

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ Tensor permute_pooled_embs_cpu_impl(
2828
}
2929

3030
TORCH_CHECK(
31-
pooled_embs.dim() == 2,
32-
"pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
31+
pooled_embs.dim() >= 2,
32+
"pooled_embs must be at least a 2-D tensor of size [B_local][Sum_T_global(D)], "
3333
"current shape is: ",
3434
torch_tensor_shape_str(pooled_embs));
3535
TORCH_CHECK(

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ Tensor permute_pooled_embs_split_gpu_impl(
7070
}
7171

7272
TORCH_CHECK(
73-
pooled_embs.dim() == 2,
74-
"pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
73+
pooled_embs.dim() >= 2,
74+
"pooled_embs must be at least a 2-D tensor of size [B_local][Sum_T_global(D)], "
7575
"current shape is: ",
7676
torch_tensor_shape_str(pooled_embs));
7777

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ Tensor permute_pooled_embs_split_cpu_impl(
3737
}
3838

3939
TORCH_CHECK(
40-
pooled_embs.dim() == 2,
41-
"pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
40+
pooled_embs.dim() >= 2,
41+
"pooled_embs must be at least a 2-D tensor of size [B_local][Sum_T_global(D)], "
4242
"current shape is: ",
4343
torch_tensor_shape_str(pooled_embs));
4444
TORCH_CHECK(

0 commit comments

Comments
 (0)