@@ -120,7 +120,7 @@ public MethodSpec generate(ServiceDefinition def) {
120
120
Packages .getPrefixedPackage (
121
121
def .getServiceName ().getPackage (), options .packagePrefix ()),
122
122
className .simpleName (),
123
- ErrorGenerationUtils .responseTypeName (endpoint .getEndpointName ())),
123
+ ErrorGenerationUtils .endpointResponseResultTypeName (endpoint .getEndpointName ())),
124
124
endpoint ,
125
125
endpoint .getEndpointName (),
126
126
endpoint .getReturns ())
@@ -178,16 +178,19 @@ private Optional<FieldSpec> serializer(EndpointName endpointName, Type type) {
178
178
private Optional <FieldSpec > deserializer (
179
179
TypeName responseType , EndpointDefinition endpointDef , EndpointName endpointName , Optional <Type > type ) {
180
180
TypeName className = Primitives .box (returnTypes .baseType (type ));
181
- if (isBinaryOrOptionalBinary (className , returnTypes ) && !options .generateDialogueEndpointErrorResultTypes ()) {
181
+ boolean generateResultTypes = shouldGenerateDialogueEndpointErrorResultTypesForEndpoint (options , endpointDef );
182
+
183
+ if (isBinaryOrOptionalBinary (className , returnTypes ) && !generateResultTypes ) {
182
184
return Optional .empty ();
183
185
}
184
- ParameterizedTypeName deserializerType =
185
- ParameterizedTypeName .get (ClassName .get (Deserializer .class ), responseType );
186
+
187
+ ParameterizedTypeName deserializerType = ParameterizedTypeName .get (
188
+ ClassName .get (Deserializer .class ), generateResultTypes ? responseType : className );
186
189
187
190
CodeBlock initializer = CodeBlock .of (
188
191
"$L.bodySerDe().$L" ,
189
192
StaticFactoryMethodGenerator .RUNTIME ,
190
- options . generateDialogueEndpointErrorResultTypes ()
193
+ generateResultTypes
191
194
? constructDeserializerWithEndpointErrors (endpointDef , className , responseType )
192
195
: constructDeserializer (type , className ));
193
196
@@ -208,16 +211,14 @@ private CodeBlock constructDeserializerWithEndpointErrors(
208
211
CodeBlock .Builder deserializerArgsBuilder = CodeBlock .builder ()
209
212
.add ("$T.<$T>builder()" , DeserializerArgs .class , responseType )
210
213
.add (".baseType(new $T<>() {})" , TypeMarker .class )
211
- // TODO(pm): consider making "Success" a constant string for re-use in the record creation.
212
214
.add (".success(new $T<$T.Success>() {})" , TypeMarker .class , responseType );
213
215
for (EndpointError err : endpointDef .getErrors ()) {
214
216
ErrorTypeName errorTypeName = err .getError ();
215
217
String errorName = errorTypeName .getName ();
216
- String errorType = CaseFormat .UPPER_CAMEL .to (CaseFormat .UPPER_UNDERSCORE , errorName );
217
218
ClassName errorClass = ClassName .get (
218
219
errorTypeName .getPackage (),
219
220
ErrorGenerationUtils .errorTypesClassName (errorTypeName .getNamespace ()),
220
- errorType );
221
+ CaseFormat . UPPER_CAMEL . to ( CaseFormat . UPPER_UNDERSCORE , errorName ) );
221
222
deserializerArgsBuilder .add (
222
223
".error($T.name(), new $T<$T.$L>() {})" , errorClass , TypeMarker .class , responseType , errorName );
223
224
}
@@ -239,6 +240,12 @@ private static boolean isBinaryOrOptionalBinary(TypeName className, ReturnTypeMa
239
240
return isBinary (className , returnTypes ) || isOptionalBinary (className , returnTypes );
240
241
}
241
242
243
+ private static boolean shouldGenerateDialogueEndpointErrorResultTypesForEndpoint (
244
+ Options options , EndpointDefinition endpointDefinition ) {
245
+ return options .generateDialogueEndpointErrorResultTypes ()
246
+ && !endpointDefinition .getErrors ().isEmpty ();
247
+ }
248
+
242
249
private static boolean isBinary (TypeName className , ReturnTypeMapper returnTypes ) {
243
250
return className .equals (returnTypes .baseType (Type .primitive (PrimitiveType .BINARY )));
244
251
}
@@ -280,13 +287,14 @@ private MethodSpec clientImpl(ClassName className, EndpointDefinition def) {
280
287
.build ();
281
288
String codeBlock = methodType .switchBy (
282
289
"$L.clients().callBlocking($L, $L.build(), $L);" , "$L.clients().call($L, $L.build" + "(), $L);" );
290
+ boolean generateResultTypes = shouldGenerateDialogueEndpointErrorResultTypesForEndpoint (options , def );
283
291
CodeBlock execute = CodeBlock .of (
284
292
codeBlock ,
285
293
StaticFactoryMethodGenerator .RUNTIME ,
286
294
Names .endpointChannel (def ),
287
295
REQUEST ,
288
296
def .getReturns ()
289
- .filter (type -> !options . generateDialogueEndpointErrorResultTypes ()
297
+ .filter (type -> !generateResultTypes
290
298
&& isBinaryOrOptionalBinary (returnTypes .baseType (type ), returnTypes ))
291
299
.map (type -> StaticFactoryMethodGenerator .RUNTIME
292
300
+ (isOptionalBinary (returnTypes .baseType (type ), returnTypes )
@@ -295,27 +303,22 @@ && isBinaryOrOptionalBinary(returnTypes.baseType(type), returnTypes))
295
303
.orElseGet (() -> def .getEndpointName ().get () + "Deserializer" ));
296
304
297
305
methodBuilder .addCode (request );
298
- methodBuilder .addCode (methodType .switchBy (
299
- def .getReturns ().isPresent ()
300
- || (options .generateDialogueEndpointErrorResultTypes ()
301
- && !def .getErrors ().isEmpty ())
302
- ? "return "
303
- : "" ,
304
- "return " ));
306
+ methodBuilder .addCode (
307
+ methodType .switchBy (def .getReturns ().isPresent () || generateResultTypes ? "return " : "" , "return " ));
305
308
methodBuilder .addCode (execute );
306
309
307
310
return methodBuilder .build ();
308
311
}
309
312
310
313
private TypeName getReturnType (EndpointDefinition def , ClassName className ) {
311
- if (options .generateDialogueEndpointErrorResultTypes ()
312
- && !def .getErrors ().isEmpty ()) {
313
- ClassName returnType = ClassName .get (
314
+ if (shouldGenerateDialogueEndpointErrorResultTypesForEndpoint (options , def )) {
315
+ ClassName responseResultTypeName = ClassName .get (
314
316
className .packageName (),
315
317
className .simpleName (),
316
- ErrorGenerationUtils .responseTypeName (def .getEndpointName ()));
318
+ ErrorGenerationUtils .endpointResponseResultTypeName (def .getEndpointName ()));
317
319
return methodType .switchBy (
318
- returnType , ParameterizedTypeName .get (ClassName .get (ListenableFuture .class ), returnType ));
320
+ responseResultTypeName ,
321
+ ParameterizedTypeName .get (ClassName .get (ListenableFuture .class ), responseResultTypeName ));
319
322
}
320
323
return methodType .switchBy (returnTypes .baseType (def .getReturns ()), returnTypes .async (def .getReturns ()));
321
324
}
0 commit comments