Skip to content

Commit

Permalink
Revert "Increased compile time max GPUs to 512. Switched to int16_t D…
Browse files Browse the repository at this point in the history
…eviceIndex. (pytorch#119639)"

This reverts commit 7c55642.

Reverted pytorch#119639 on behalf of https://github.com/kit1980 due to breaking internal builds, see D54286923 ([comment](pytorch#119639 (comment)))
  • Loading branch information
pytorchmergebot committed Feb 28, 2024
1 parent 1c67f6c commit a9d9077
Show file tree
Hide file tree
Showing 23 changed files with 99 additions and 140 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/core/op_registration/infer_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ constexpr int checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(std::conjunction<
bool_t<!std::is_integral<Types>::value || std::is_same<Types, int16_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
>::value, "INVALID TYPE: Only int16_t, int64_t and bool are supported as an integral argument type");
bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>...
>::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
static_assert(std::conjunction<
bool_t<!std::is_same<Types, float>::value>...
>::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/jiterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
ss << extra_args_types;
ss << vec_size;
ss << dev_idx;
// DeviceIndex, e.g. int8_t, is not treated as a number by the stream, cast to int as a workaround
ss << static_cast<int>(dev_idx);
const std::string cache_key = ss.str();

static std::mutex _jiterator_mutex;
Expand Down
15 changes: 4 additions & 11 deletions aten/src/ATen/native/ForeachUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,17 +252,10 @@ using IndicesT = std::vector<size_t>;
using nested_optional_tensorvec_t =
std::vector<std::vector<c10::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;

// Warning: Do not use ParamsHash for keys with potentially uninitialized
// padding bytes!
//
// Hash functor for DeviceDtypeKey, used as the hasher of the FlatMap alias
// declared right after it. It hashes the key's two members individually
// rather than the key object's raw memory.
struct _DeviceDtypeHasher {
// Returns std::hash of the device (k.first) XOR-ed with std::hash of the
// scalar type (k.second).
std::size_t operator()(const DeviceDtypeKey& k) const noexcept {
return std::hash<at::Device>{}(k.first) ^
std::hash<at::ScalarType>{}(k.second);
}
};
using FlatMap =
std::unordered_map<DeviceDtypeKey, TensorsAndIndicesT, _DeviceDtypeHasher>;
using FlatMap = std::unordered_map<
DeviceDtypeKey,
TensorsAndIndicesT,
ParamsHash<DeviceDtypeKey>>;

inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
const nested_optional_tensorvec_t& nested_tensorlist,
Expand Down
2 changes: 0 additions & 2 deletions aten/src/ATen/native/utils/ParamsHash.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ namespace at::native {
// Fowler–Noll–Vo hash function
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
// WARNING: This hash function will produce unexpected results for `Params` with uninitialized padding values, as the
// padding is also part of the hash. Use with caution.
template <typename Params>
struct ParamsHash {
// Params must be a POD because we read out its memory
Expand Down
30 changes: 10 additions & 20 deletions c10/core/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,29 +125,19 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {

TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");

if (!device_index_str.empty()) {
// If the user passed an index in the device string, check if it is a valid
// int between 0 and c10::Device::MAX_NUM_DEVICES - 1 inclusively
int full_index = -1;
try {
full_index = std::stoi(device_index_str);
} catch (const std::exception&) {
TORCH_CHECK(
false,
"Could not parse device index '",
device_index_str,
"' in device string '",
device_string,
"'");
try {
if (!device_index_str.empty()) {
index_ = static_cast<c10::DeviceIndex>(std::stoi(device_index_str));
}
} catch (const std::exception&) {
TORCH_CHECK(
0 <= full_index && full_index < c10::Device::MAX_NUM_DEVICES,
"Device index must be between 0 and ",
c10::Device::MAX_NUM_DEVICES - 1,
" inclusively.");
index_ = static_cast<c10::DeviceIndex>(full_index);
false,
"Could not parse device index '",
device_index_str,
"' in device string '",
device_string,
"'");
}

type_ = parse_type(device_name);
validate();
}
Expand Down
25 changes: 5 additions & 20 deletions c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace c10 {
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated; try to use Device rather than
/// DeviceIndex directly.
using DeviceIndex = int16_t;
using DeviceIndex = int8_t;

/// Represents a compute device on which a tensor is located. A device is
/// uniquely identified by a type, which specifies the type of machine it is
Expand All @@ -29,18 +29,6 @@ using DeviceIndex = int16_t;
/// represents a specific, concrete device,
/// 2. When the device type is CPU, the device index must be zero.
struct C10_API Device final {
/// The maximum number of devices that we recognize (formerly known as
/// C10_COMPILE_TIME_MAX_GPUS). This value cannot be more than 32767 because
/// our DeviceIndex is a int16_t. Note that this does not include the default
/// device index -1, but instead defines the range from 0 to MAX_NUM_DEVICES-1
/// inclusively.
#ifdef FBCODE_CAFFE2
// fbcode depends on this value being 16
static constexpr DeviceIndex MAX_NUM_DEVICES = 16;
#else
static constexpr DeviceIndex MAX_NUM_DEVICES = 512;
#endif

using Type = DeviceType;

/// Constructs a new `Device` from a `DeviceType` and an optional device
Expand Down Expand Up @@ -72,7 +60,6 @@ struct C10_API Device final {
/// Sets the device index.
void set_index(DeviceIndex index) {
index_ = index;
validate();
}

/// Returns the type of device this is.
Expand Down Expand Up @@ -188,10 +175,8 @@ struct C10_API Device final {
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
index_ >= -1 && index_ < MAX_NUM_DEVICES,
"Device index must be between -1 and ",
MAX_NUM_DEVICES - 1,
" inclusively, got ",
index_ >= -1,
"Device index must be -1 or non-negative, got ",
static_cast<int>(index_));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
Expand All @@ -211,7 +196,7 @@ struct hash<c10::Device> {
// Are you here because this static assert failed? Make sure you ensure
// that the bitmasking code below is updated accordingly!
static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit");
static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
// Note [Hazard when concatenating signed integers]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// We must first convert to a same-sized unsigned type, before promoting to
Expand All @@ -224,7 +209,7 @@ struct hash<c10::Device> {
// sake.
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
<< 16 |
static_cast<uint32_t>(static_cast<uint16_t>(d.index()));
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
return std::hash<uint32_t>{}(bits);
}
};
Expand Down
6 changes: 3 additions & 3 deletions c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3169,7 +3169,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
#if UINTPTR_MAX == 0xFFFFFFFF
// This is a 32-bit system
static constexpr bool check_sizes() {
constexpr size_t tsize = 21 * sizeof(int64_t);
constexpr size_t tsize = 20 * sizeof(int64_t);

// clang-format off
are_equal<sizeof(storage_), 4, FieldNameEnum::storage_>();
Expand All @@ -3181,7 +3181,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
// clang-format on
Expand All @@ -3206,7 +3206,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
are_equal<sizeof(data_type_), 2, FieldNameEnum::data_type_>();
are_equal<sizeof(device_opt_), 6, FieldNameEnum::device_opt_>();
are_equal<sizeof(device_opt_), 3, FieldNameEnum::device_opt_>();
are_equal<sizeof(key_set_), 8, FieldNameEnum::key_set_>();
is_le<sizeof(TensorImpl), tsize, FieldNameEnum::TOTAL_SIZE>();
// clang-format on
Expand Down
13 changes: 8 additions & 5 deletions c10/cuda/CUDAFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ DeviceIndex device_count() noexcept {
try {
auto result = device_count_impl(/*fail_if_no_driver=*/false);
TORCH_INTERNAL_ASSERT(
result <= c10::Device::MAX_NUM_DEVICES,
result <= std::numeric_limits<DeviceIndex>::max(),
"Too many CUDA devices, DeviceIndex overflowed");
return result;
} catch (const c10::Error& ex) {
Expand All @@ -118,7 +118,7 @@ DeviceIndex device_count_ensure_non_zero() {
// Zero gpus doesn't produce a warning in `device_count` but we fail here
TORCH_CHECK(count, "No CUDA GPUs are available");
TORCH_INTERNAL_ASSERT(
count <= c10::Device::MAX_NUM_DEVICES,
count <= std::numeric_limits<DeviceIndex>::max(),
"Too many CUDA devices, DeviceIndex overflowed");
return static_cast<DeviceIndex>(count);
}
Expand Down Expand Up @@ -219,7 +219,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
auto err = cudaGetDevice(&tmp_device);
if (err == cudaSuccess) {
TORCH_INTERNAL_ASSERT(
tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
tmp_device >= 0 &&
tmp_device <= std::numeric_limits<DeviceIndex>::max(),
"cudaGetDevice returns invalid device ",
tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
Expand Down Expand Up @@ -269,7 +270,8 @@ DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
int tmp_cur_device = -1;
C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
TORCH_INTERNAL_ASSERT(
tmp_cur_device >= 0 && tmp_cur_device < c10::Device::MAX_NUM_DEVICES,
tmp_cur_device >= 0 &&
tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
"cudaGetDevice returns invalid device ",
tmp_cur_device);
auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
Expand All @@ -295,7 +297,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
auto err = cudaGetDevice(&tmp_device);
if (err == cudaSuccess) {
TORCH_INTERNAL_ASSERT(
tmp_device >= 0 && tmp_device < c10::Device::MAX_NUM_DEVICES,
tmp_device >= 0 &&
tmp_device <= std::numeric_limits<DeviceIndex>::max(),
"cudaGetDevice returns invalid device ",
tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
Expand Down
12 changes: 12 additions & 0 deletions c10/cuda/CUDAMacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@
#else
#define C10_CUDA_API C10_CUDA_IMPORT
#endif

/**
* The maximum number of GPUs that we recognize. Increasing this beyond the
* initial limit of 16 broke Caffe2 testing, hence the ifdef guards.
* This value cannot be more than 127 because our DeviceIndex is an int8_t.
*/
#ifdef FBCODE_CAFFE2
// fbcode depends on this value being 16
#define C10_COMPILE_TIME_MAX_GPUS 16
#else
#define C10_COMPILE_TIME_MAX_GPUS 64
#endif
12 changes: 6 additions & 6 deletions c10/cuda/CUDAStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ static int max_stream_priorities;
// the destruction.
#if !defined(USE_ROCM)
// CUDA-only: used to initialize the stream pools (once)
static c10::once_flag device_flags[c10::Device::MAX_NUM_DEVICES];
static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
#endif
static std::atomic<uint32_t>
priority_counters[c10::cuda::max_compile_time_stream_priorities]
[c10::Device::MAX_NUM_DEVICES];
[C10_COMPILE_TIME_MAX_GPUS];

static cudaStream_t streams[c10::cuda::max_compile_time_stream_priorities]
[c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
[C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
#ifdef USE_ROCM
static c10::once_flag
stream_flags[c10::cuda::max_compile_time_stream_priorities]
[c10::Device::MAX_NUM_DEVICES][kStreamsPerPool];
[C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
#endif

// Note [HIP Lazy Streams]
Expand Down Expand Up @@ -168,10 +168,10 @@ static void initGlobalStreamState() {
// Check if the number of GPUs matches the expected compile-time max number
// of GPUs.
TORCH_CHECK(
num_gpus <= c10::Device::MAX_NUM_DEVICES,
num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
"Number of CUDA devices on the machine is larger than the compiled "
"max number of gpus expected (",
c10::Device::MAX_NUM_DEVICES,
C10_COMPILE_TIME_MAX_GPUS,
"). Increase that and recompile.");
int leastPriority = -1, greatestPriority = -1;
C10_CUDA_CHECK(
Expand Down
18 changes: 9 additions & 9 deletions caffe2/contrib/nccl/cuda_nccl_op_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(

REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
OPERATOR_SCHEMA(NCCLAllreduce)
.NumInputs(1, c10::Device::MAX_NUM_DEVICES)
.NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.CostInferenceFunction(NCCLAllreduceOp::CostInference)
.TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
.IdenticalTypeAndShape()
Expand All @@ -236,8 +236,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);

REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
OPERATOR_SCHEMA(NCCLBroadcast)
.NumInputs(1, c10::Device::MAX_NUM_DEVICES)
.NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.IdenticalTypeAndShape()
.InputsCanCrossDevices()
.EnforceOneToOneInplace()
Expand All @@ -247,7 +247,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);

REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
OPERATOR_SCHEMA(NCCLReduce)
.NumInputs(1, c10::Device::MAX_NUM_DEVICES)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1)
.IdenticalTypeAndShapeOfInput(0)
.InputsCanCrossDevices()
Expand All @@ -257,16 +257,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce);

REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
OPERATOR_SCHEMA(NCCLAllGather)
.NumInputs(1, c10::Device::MAX_NUM_DEVICES)
.NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.InputsCanCrossDevices()
.DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllGather);

REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
OPERATOR_SCHEMA(NCCLReduceScatter)
.NumInputs(1, c10::Device::MAX_NUM_DEVICES)
.NumOutputs(1, c10::Device::MAX_NUM_DEVICES)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.InputsCanCrossDevices()
.DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);
Expand Down
8 changes: 4 additions & 4 deletions caffe2/core/context_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation;
// Data structures for optional memory tracking. Access to these structures
// is guarded by the CUDAContext::mutex.
static std::unordered_map<void*, long> g_size_map;
static std::vector<long> g_total_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
static std::vector<long> g_max_by_gpu_map(c10::Device::MAX_NUM_DEVICES, 0);
static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);

static long g_total_mem = 0;
static long g_last_rep = 0;
Expand Down Expand Up @@ -208,10 +208,10 @@ static void Caffe2InitializeCuda() {
// of GPUs.
CAFFE_ENFORCE_LE(
NumCudaDevices(),
c10::Device::MAX_NUM_DEVICES,
C10_COMPILE_TIME_MAX_GPUS,
"Number of CUDA devices on the machine is larger than the compiled "
"max number of gpus expected (",
c10::Device::MAX_NUM_DEVICES,
C10_COMPILE_TIME_MAX_GPUS,
"). Increase that and recompile.");

for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
Expand Down
4 changes: 2 additions & 2 deletions caffe2/core/context_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {

private:
ThreadLocalCUDAObjects() {
for (DeviceIndex i = 0; i < c10::Device::MAX_NUM_DEVICES; ++i) {
for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
}
}
Expand Down Expand Up @@ -164,7 +164,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
// WARNING: mapping from logical stream ID to c10::cuda::CUDAStream
// is NOT bijective; multiple logical stream IDs may map to the
// same underlying stream ID.
vector<c10::cuda::CUDAStream> cuda_streams_[c10::Device::MAX_NUM_DEVICES];
vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
#ifdef CAFFE2_USE_CUDNN
std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;
Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/cudnn_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class CuDNNWrapper {

using PerGPUCuDNNStates = std::array<
std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
c10::Device::MAX_NUM_DEVICES>;
C10_COMPILE_TIME_MAX_GPUS>;
static PerGPUCuDNNStates& cudnn_states();

C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);
Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/hip/miopen_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class MIOPENWrapper

using PerGPUMIOPENStates = std::array<
std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
c10::Device::MAX_NUM_DEVICES>;
C10_COMPILE_TIME_MAX_GPUS>;
static PerGPUMIOPENStates& miopen_states();

C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);
Expand Down
Loading

0 comments on commit a9d9077

Please sign in to comment.