@@ -935,7 +935,8 @@ CreateFileHandler(MemoryBuffer &FirstInput,
935
935
" '" + FilesType + " ': invalid file type specified" );
936
936
}
937
937
938
- OffloadBundlerConfig::OffloadBundlerConfig () {
938
+ OffloadBundlerConfig::OffloadBundlerConfig ()
939
+ : CompressedBundleVersion(CompressedOffloadBundle::DefaultVersion) {
939
940
if (llvm::compression::zstd::isAvailable ()) {
940
941
CompressionFormat = llvm::compression::Format::Zstd;
941
942
// Compression level 3 is usually sufficient for zstd since long distance
@@ -951,16 +952,13 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
951
952
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_IGNORE_ENV_VAR" );
952
953
if (IgnoreEnvVarOpt.has_value () && IgnoreEnvVarOpt.value () == " 1" )
953
954
return ;
954
-
955
955
auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_VERBOSE" );
956
956
if (VerboseEnvVarOpt.has_value ())
957
957
Verbose = VerboseEnvVarOpt.value () == " 1" ;
958
-
959
958
auto CompressEnvVarOpt =
960
959
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESS" );
961
960
if (CompressEnvVarOpt.has_value ())
962
961
Compress = CompressEnvVarOpt.value () == " 1" ;
963
-
964
962
auto CompressionLevelEnvVarOpt =
965
963
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESSION_LEVEL" );
966
964
if (CompressionLevelEnvVarOpt.has_value ()) {
@@ -973,6 +971,26 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
973
971
<< " Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
974
972
<< CompressionLevelStr.str () << " . Ignoring it.\n " ;
975
973
}
974
+ auto CompressedBundleFormatVersionOpt =
975
+ llvm::sys::Process::GetEnv (" COMPRESSED_BUNDLE_FORMAT_VERSION" );
976
+ if (CompressedBundleFormatVersionOpt.has_value ()) {
977
+ llvm::StringRef VersionStr = CompressedBundleFormatVersionOpt.value ();
978
+ uint16_t Version;
979
+ if (!VersionStr.getAsInteger (10 , Version)) {
980
+ if (Version >= 2 && Version <= 3 )
981
+ CompressedBundleVersion = Version;
982
+ else
983
+ llvm::errs ()
984
+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
985
+ << VersionStr.str ()
986
+ << " . Valid values are 2 or 3. Using default version "
987
+ << CompressedBundleVersion << " .\n " ;
988
+ } else
989
+ llvm::errs ()
990
+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
991
+ << VersionStr.str () << " . Using default version "
992
+ << CompressedBundleVersion << " .\n " ;
993
+ }
976
994
}
977
995
978
996
// Utility function to format numbers with commas
@@ -989,12 +1007,11 @@ static std::string formatWithCommas(unsigned long long Value) {
989
1007
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
990
1008
CompressedOffloadBundle::compress (llvm::compression::Params P,
991
1009
const llvm::MemoryBuffer &Input,
992
- bool Verbose) {
1010
+ uint16_t Version, bool Verbose) {
993
1011
if (!llvm::compression::zstd::isAvailable () &&
994
1012
!llvm::compression::zlib::isAvailable ())
995
1013
return createStringError (llvm::inconvertibleErrorCode (),
996
1014
" Compression not supported" );
997
-
998
1015
llvm::Timer HashTimer (" Hash Calculation Timer" , " Hash calculation time" ,
999
1016
*ClangOffloadBundlerTimerGroup);
1000
1017
if (Verbose)
@@ -1011,7 +1028,6 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1011
1028
auto BufferUint8 = llvm::ArrayRef<uint8_t >(
1012
1029
reinterpret_cast <const uint8_t *>(Input.getBuffer ().data ()),
1013
1030
Input.getBuffer ().size ());
1014
-
1015
1031
llvm::Timer CompressTimer (" Compression Timer" , " Compression time" ,
1016
1032
*ClangOffloadBundlerTimerGroup);
1017
1033
if (Verbose)
@@ -1021,22 +1037,54 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1021
1037
CompressTimer.stopTimer ();
1022
1038
1023
1039
uint16_t CompressionMethod = static_cast <uint16_t >(P.format );
1024
- uint32_t UncompressedSize = Input.getBuffer ().size ();
1025
- uint32_t TotalFileSize = MagicNumber.size () + sizeof (TotalFileSize) +
1026
- sizeof (Version) + sizeof (CompressionMethod) +
1027
- sizeof (UncompressedSize) + sizeof (TruncatedHash) +
1028
- CompressedBuffer.size ();
1040
+
1041
+ // Store sizes in 64-bit variables first
1042
+ uint64_t UncompressedSize64 = Input.getBuffer ().size ();
1043
+ uint64_t TotalFileSize64;
1044
+
1045
+ // Calculate total file size based on version
1046
+ if (Version == 2 ) {
1047
+ // For V2, ensure the sizes don't exceed 32-bit limit
1048
+ if (UncompressedSize64 > std::numeric_limits<uint32_t >::max ())
1049
+ return createStringError (llvm::inconvertibleErrorCode (),
1050
+ " Uncompressed size exceeds version 2 limit" );
1051
+ if ((MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1052
+ sizeof (CompressionMethod) + sizeof (uint32_t ) + sizeof (TruncatedHash) +
1053
+ CompressedBuffer.size ()) > std::numeric_limits<uint32_t >::max ())
1054
+ return createStringError (llvm::inconvertibleErrorCode (),
1055
+ " Total file size exceeds version 2 limit" );
1056
+
1057
+ TotalFileSize64 = MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1058
+ sizeof (CompressionMethod) + sizeof (uint32_t ) +
1059
+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1060
+ } else { // Version 3
1061
+ TotalFileSize64 = MagicNumber.size () + sizeof (uint64_t ) + sizeof (Version) +
1062
+ sizeof (CompressionMethod) + sizeof (uint64_t ) +
1063
+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1064
+ }
1029
1065
1030
1066
SmallVector<char , 0 > FinalBuffer;
1031
1067
llvm::raw_svector_ostream OS (FinalBuffer);
1032
1068
OS << MagicNumber;
1033
1069
OS.write (reinterpret_cast <const char *>(&Version), sizeof (Version));
1034
1070
OS.write (reinterpret_cast <const char *>(&CompressionMethod),
1035
1071
sizeof (CompressionMethod));
1036
- OS.write (reinterpret_cast <const char *>(&TotalFileSize),
1037
- sizeof (TotalFileSize));
1038
- OS.write (reinterpret_cast <const char *>(&UncompressedSize),
1039
- sizeof (UncompressedSize));
1072
+
1073
+ // Write size fields according to version
1074
+ if (Version == 2 ) {
1075
+ uint32_t TotalFileSize32 = static_cast <uint32_t >(TotalFileSize64);
1076
+ uint32_t UncompressedSize32 = static_cast <uint32_t >(UncompressedSize64);
1077
+ OS.write (reinterpret_cast <const char *>(&TotalFileSize32),
1078
+ sizeof (TotalFileSize32));
1079
+ OS.write (reinterpret_cast <const char *>(&UncompressedSize32),
1080
+ sizeof (UncompressedSize32));
1081
+ } else { // Version 3
1082
+ OS.write (reinterpret_cast <const char *>(&TotalFileSize64),
1083
+ sizeof (TotalFileSize64));
1084
+ OS.write (reinterpret_cast <const char *>(&UncompressedSize64),
1085
+ sizeof (UncompressedSize64));
1086
+ }
1087
+
1040
1088
OS.write (reinterpret_cast <const char *>(&TruncatedHash),
1041
1089
sizeof (TruncatedHash));
1042
1090
OS.write (reinterpret_cast <const char *>(CompressedBuffer.data ()),
@@ -1046,18 +1094,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1046
1094
auto MethodUsed =
1047
1095
P.format == llvm::compression::Format::Zstd ? " zstd" : " zlib" ;
1048
1096
double CompressionRate =
1049
- static_cast <double >(UncompressedSize ) / CompressedBuffer.size ();
1097
+ static_cast <double >(UncompressedSize64 ) / CompressedBuffer.size ();
1050
1098
double CompressionTimeSeconds = CompressTimer.getTotalTime ().getWallTime ();
1051
1099
double CompressionSpeedMBs =
1052
- (UncompressedSize / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1053
-
1100
+ (UncompressedSize64 / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1054
1101
llvm::errs () << " Compressed bundle format version: " << Version << " \n "
1055
1102
<< " Total file size (including headers): "
1056
- << formatWithCommas (TotalFileSize ) << " bytes\n "
1103
+ << formatWithCommas (TotalFileSize64 ) << " bytes\n "
1057
1104
<< " Compression method used: " << MethodUsed << " \n "
1058
1105
<< " Compression level: " << P.level << " \n "
1059
1106
<< " Binary size before compression: "
1060
- << formatWithCommas (UncompressedSize ) << " bytes\n "
1107
+ << formatWithCommas (UncompressedSize64 ) << " bytes\n "
1061
1108
<< " Binary size after compression: "
1062
1109
<< formatWithCommas (CompressedBuffer.size ()) << " bytes\n "
1063
1110
<< " Compression rate: "
@@ -1069,16 +1116,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1069
1116
<< " Truncated MD5 hash: "
1070
1117
<< llvm::format_hex (TruncatedHash, 16 ) << " \n " ;
1071
1118
}
1119
+
1072
1120
return llvm::MemoryBuffer::getMemBufferCopy (
1073
1121
llvm::StringRef (FinalBuffer.data (), FinalBuffer.size ()));
1074
1122
}
1075
1123
1076
1124
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
1077
1125
CompressedOffloadBundle::decompress (const llvm::MemoryBuffer &Input,
1078
1126
bool Verbose) {
1079
-
1080
1127
StringRef Blob = Input.getBuffer ();
1081
1128
1129
+ // Check minimum header size (using V1 as it's the smallest)
1082
1130
if (Blob.size () < V1HeaderSize)
1083
1131
return llvm::MemoryBuffer::getMemBufferCopy (Blob);
1084
1132
@@ -1091,31 +1139,56 @@ CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
1091
1139
1092
1140
size_t CurrentOffset = MagicSize;
1093
1141
1142
+ // Read version
1094
1143
uint16_t ThisVersion;
1095
1144
memcpy (&ThisVersion, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1096
1145
CurrentOffset += VersionFieldSize;
1097
1146
1147
+ // Verify header size based on version
1148
+ if (ThisVersion >= 2 && ThisVersion <= 3 ) {
1149
+ size_t RequiredSize = (ThisVersion == 2 ) ? V2HeaderSize : V3HeaderSize;
1150
+ if (Blob.size () < RequiredSize)
1151
+ return createStringError (inconvertibleErrorCode (),
1152
+ " Compressed bundle header size too small" );
1153
+ }
1154
+
1155
+ // Read compression method
1098
1156
uint16_t CompressionMethod;
1099
1157
memcpy (&CompressionMethod, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1100
1158
CurrentOffset += MethodFieldSize;
1101
1159
1102
- uint32_t TotalFileSize;
1160
+ // Read total file size (version 2+)
1161
+ uint64_t TotalFileSize = 0 ;
1103
1162
if (ThisVersion >= 2 ) {
1104
- if (Blob.size () < V2HeaderSize)
1105
- return createStringError (inconvertibleErrorCode (),
1106
- " Compressed bundle header size too small" );
1107
- memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1108
- CurrentOffset += FileSizeFieldSize;
1163
+ if (ThisVersion == 2 ) {
1164
+ uint32_t TotalFileSize32;
1165
+ memcpy (&TotalFileSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1166
+ TotalFileSize = TotalFileSize32;
1167
+ CurrentOffset += FileSizeFieldSizeV2;
1168
+ } else { // Version 3
1169
+ memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1170
+ CurrentOffset += FileSizeFieldSizeV3;
1171
+ }
1109
1172
}
1110
1173
1111
- uint32_t UncompressedSize;
1112
- memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1113
- CurrentOffset += UncompressedSizeFieldSize;
1174
+ // Read uncompressed size
1175
+ uint64_t UncompressedSize = 0 ;
1176
+ if (ThisVersion <= 2 ) {
1177
+ uint32_t UncompressedSize32;
1178
+ memcpy (&UncompressedSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1179
+ UncompressedSize = UncompressedSize32;
1180
+ CurrentOffset += UncompressedSizeFieldSizeV2;
1181
+ } else { // Version 3
1182
+ memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1183
+ CurrentOffset += UncompressedSizeFieldSizeV3;
1184
+ }
1114
1185
1186
+ // Read hash
1115
1187
uint64_t StoredHash;
1116
1188
memcpy (&StoredHash, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1117
1189
CurrentOffset += HashFieldSize;
1118
1190
1191
+ // Determine compression format
1119
1192
llvm::compression::Format CompressionFormat;
1120
1193
if (CompressionMethod ==
1121
1194
static_cast <uint16_t >(llvm::compression::Format::Zlib))
@@ -1381,7 +1454,8 @@ Error OffloadBundler::BundleFiles() {
1381
1454
auto CompressionResult = CompressedOffloadBundle::compress (
1382
1455
{BundlerConfig.CompressionFormat , BundlerConfig.CompressionLevel ,
1383
1456
/* zstdEnableLdm=*/ true },
1384
- *BufferMemory, BundlerConfig.Verbose );
1457
+ *BufferMemory, BundlerConfig.CompressedBundleVersion ,
1458
+ BundlerConfig.Verbose );
1385
1459
if (auto Error = CompressionResult.takeError ())
1386
1460
return Error;
1387
1461
0 commit comments