|  | 
| 5 | 5 | 
 | 
| 6 | 6 | package software.amazon.smithy.rust.codegen.core.testutil | 
| 7 | 7 | 
 | 
|  | 8 | +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait | 
|  | 9 | +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait | 
|  | 10 | +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait | 
|  | 11 | +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait | 
| 8 | 12 | import software.amazon.smithy.build.PluginContext | 
| 9 | 13 | import software.amazon.smithy.model.Model | 
| 10 | 14 | import software.amazon.smithy.model.node.ObjectNode | 
| 11 | 15 | import software.amazon.smithy.model.node.ToNode | 
|  | 16 | +import software.amazon.smithy.model.shapes.ServiceShape | 
|  | 17 | +import software.amazon.smithy.model.shapes.ShapeId | 
|  | 18 | +import software.amazon.smithy.model.traits.AbstractTrait | 
|  | 19 | +import software.amazon.smithy.model.transform.ModelTransformer | 
|  | 20 | +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait | 
| 12 | 21 | import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig | 
| 13 | 22 | import software.amazon.smithy.rust.codegen.core.util.runCommand | 
| 14 | 23 | import java.io.File | 
| @@ -153,3 +162,128 @@ fun codegenIntegrationTest( | 
| 153 | 162 |     logger.fine(out.toString()) | 
| 154 | 163 |     return testDir | 
| 155 | 164 | } | 
|  | 165 | + | 
|  | 166 | +/** | 
|  | 167 | + * Metadata associated with a protocol that provides additional information needed for testing. | 
|  | 168 | + * | 
|  | 169 | + * @property protocol The protocol enum value this metadata is associated with | 
|  | 170 | + * @property contentType The HTTP Content-Type header value associated with this protocol. | 
|  | 171 | + */ | 
|  | 172 | +data class ProtocolMetadata( | 
|  | 173 | +    val protocol: ModelProtocol, | 
|  | 174 | +    val contentType: String, | 
|  | 175 | +) | 
|  | 176 | + | 
|  | 177 | +/** | 
|  | 178 | + * Represents the supported protocol traits in Smithy models. | 
|  | 179 | + * | 
|  | 180 | + * @property trait The Smithy trait instance with which the service shape must be annotated. | 
|  | 181 | + */ | 
|  | 182 | +enum class ModelProtocol(val trait: AbstractTrait) { | 
|  | 183 | +    AwsJson10(AwsJson1_0Trait.builder().build()), | 
|  | 184 | +    AwsJson11(AwsJson1_1Trait.builder().build()), | 
|  | 185 | +    RestJson(RestJson1Trait.builder().build()), | 
|  | 186 | +    RestXml(RestXmlTrait.builder().build()), | 
|  | 187 | +    Rpcv2Cbor(Rpcv2CborTrait.builder().build()), | 
|  | 188 | +    ; | 
|  | 189 | + | 
|  | 190 | +    // Create metadata after enum is initialized | 
|  | 191 | +    val metadata: ProtocolMetadata by lazy { | 
|  | 192 | +        when (this) { | 
|  | 193 | +            AwsJson10 -> ProtocolMetadata(this, "application/x-amz-json-1.0") | 
|  | 194 | +            AwsJson11 -> ProtocolMetadata(this, "application/x-amz-json-1.1") | 
|  | 195 | +            RestJson -> ProtocolMetadata(this, "application/json") | 
|  | 196 | +            RestXml -> ProtocolMetadata(this, "application/xml") | 
|  | 197 | +            Rpcv2Cbor -> ProtocolMetadata(this, "application/cbor") | 
|  | 198 | +        } | 
|  | 199 | +    } | 
|  | 200 | + | 
|  | 201 | +    companion object { | 
|  | 202 | +        private val TRAIT_IDS = values().map { it.trait.toShapeId() }.toSet() | 
|  | 203 | +        val ALL: Set<ModelProtocol> = values().toSet() | 
|  | 204 | + | 
|  | 205 | +        fun getTraitIds() = TRAIT_IDS | 
|  | 206 | +    } | 
|  | 207 | +} | 
|  | 208 | + | 
|  | 209 | +/** | 
|  | 210 | + * Removes all existing protocol traits annotated on the given service, | 
|  | 211 | + * then sets the provided `protocol` as the sole protocol trait for the service. | 
|  | 212 | + */ | 
|  | 213 | +fun Model.replaceProtocolTraitOnServerShapeId( | 
|  | 214 | +    serviceShapeId: ShapeId, | 
|  | 215 | +    modelProtocol: ModelProtocol, | 
|  | 216 | +): Model { | 
|  | 217 | +    val serviceShape = this.expectShape(serviceShapeId, ServiceShape::class.java) | 
|  | 218 | +    return replaceProtocolTraitOnServiceShape(serviceShape, modelProtocol) | 
|  | 219 | +} | 
|  | 220 | + | 
|  | 221 | +/** | 
|  | 222 | + * Removes all existing protocol traits annotated on the given service shape, | 
|  | 223 | + * then sets the provided `protocol` as the sole protocol trait for the service. | 
|  | 224 | + */ | 
|  | 225 | +fun Model.replaceProtocolTraitOnServiceShape( | 
|  | 226 | +    serviceShape: ServiceShape, | 
|  | 227 | +    modelProtocol: ModelProtocol, | 
|  | 228 | +): Model { | 
|  | 229 | +    val serviceBuilder = serviceShape.toBuilder() | 
|  | 230 | +    ModelProtocol.getTraitIds().forEach { traitId -> | 
|  | 231 | +        serviceBuilder.removeTrait(traitId) | 
|  | 232 | +    } | 
|  | 233 | +    val service = serviceBuilder.addTrait(modelProtocol.trait).build() | 
|  | 234 | +    return ModelTransformer.create().replaceShapes(this, listOf(service)) | 
|  | 235 | +} | 
|  | 236 | + | 
|  | 237 | +/** | 
|  | 238 | + * Processes a Smithy model string by applying different protocol traits and invoking the tests block on the model. | 
|  | 239 | + * For each protocol, this function: | 
|  | 240 | + *  1. Parses the Smithy model string | 
|  | 241 | + *  2. Replaces any existing protocol traits on service shapes with the specified protocol | 
|  | 242 | + *  3. Runs the provided test with the transformed model and protocol metadata | 
|  | 243 | + * | 
|  | 244 | + * @param protocolTraitIds Set of protocols to test against | 
|  | 245 | + * @param test Function that receives the transformed model and protocol metadata for testing | 
|  | 246 | + */ | 
|  | 247 | +fun String.forProtocols( | 
|  | 248 | +    protocolTraitIds: Set<ModelProtocol>, | 
|  | 249 | +    test: (Model, ProtocolMetadata) -> Unit, | 
|  | 250 | +) { | 
|  | 251 | +    val baseModel = this.asSmithyModel(smithyVersion = "2") | 
|  | 252 | +    val serviceShapes = baseModel.serviceShapes.toList() | 
|  | 253 | + | 
|  | 254 | +    protocolTraitIds.forEach { protocol -> | 
|  | 255 | +        val transformedModel = | 
|  | 256 | +            serviceShapes.fold(baseModel) { acc, shape -> | 
|  | 257 | +                acc.replaceProtocolTraitOnServiceShape(shape, protocol) | 
|  | 258 | +            } | 
|  | 259 | +        test(transformedModel, protocol.metadata) | 
|  | 260 | +    } | 
|  | 261 | +} | 
|  | 262 | + | 
|  | 263 | +/** | 
|  | 264 | + * Convenience overload that accepts vararg protocols instead of a Set. | 
|  | 265 | + * | 
|  | 266 | + * @param protocols Variable number of protocols to test against | 
|  | 267 | + * @param test Function that receives the transformed model and protocol metadata for testing | 
|  | 268 | + * @see forProtocols | 
|  | 269 | + */ | 
|  | 270 | +fun String.forProtocols( | 
|  | 271 | +    vararg protocols: ModelProtocol, | 
|  | 272 | +    test: (Model, ProtocolMetadata) -> Unit, | 
|  | 273 | +) { | 
|  | 274 | +    forProtocols(protocols.toSet(), test) | 
|  | 275 | +} | 
|  | 276 | + | 
|  | 277 | +/** | 
|  | 278 | + * Tests a Smithy model string against all supported protocols, with optional exclusions. | 
|  | 279 | + * | 
|  | 280 | + * @param exclude Set of protocols to exclude from testing (default is empty) | 
|  | 281 | + * @param test Function that receives the transformed model and protocol metadata for testing | 
|  | 282 | + * @see forProtocols | 
|  | 283 | + */ | 
|  | 284 | +fun String.forAllProtocols( | 
|  | 285 | +    exclude: Set<ModelProtocol> = emptySet(), | 
|  | 286 | +    test: (Model, ProtocolMetadata) -> Unit, | 
|  | 287 | +) { | 
|  | 288 | +    forProtocols(ModelProtocol.ALL - exclude, test) | 
|  | 289 | +} | 
0 commit comments