diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 0b195b695b..d2daf1c26d 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -185,12 +185,12 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { CHECK_EQ(indices.size(0), weights.size(0)); { auto wlmap = kv_store_.by(shard_id).wlock(); - + auto indices_data_ptr = indices.data_ptr(); for (auto index_iter = indexes.begin(); index_iter != indexes.end(); index_iter++) { const auto& id_index = *index_iter; - auto id = indices[id_index].template item(); + auto id = int64_t(indices_data_ptr[id_index]); wlmap->try_emplace( id, StoreValue(std::vector( @@ -255,7 +255,8 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { ") types mismatch"); auto row_storage_data_ptr = init_storage.template data_ptr(); - + auto wlmap = kv_store_.by(shard_id).wlock(); + auto indices_data_ptr = indices.data_ptr(); { for (auto index_iter = indexes.begin(); index_iter != indexes.end(); @@ -263,8 +264,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { const auto id_index = *index_iter; auto weights_data_ptr = weights.data_ptr(); - auto id = indices[id_index].template item(); - auto wlmap = kv_store_.by(shard_id).wlock(); + auto id = int64_t(indices_data_ptr[id_index]); const auto cached_iter = wlmap->find(id); if (cached_iter == wlmap->end()) { fill_from_row_storage(