4242import javax .persistence .metamodel .SingularAttribute ;
4343import javax .persistence .metamodel .Type ;
4444
45+ import org .slf4j .Logger ;
46+ import org .slf4j .LoggerFactory ;
47+
4548import com .introproventures .graphql .jpa .query .annotation .GraphQLDescription ;
4649import com .introproventures .graphql .jpa .query .annotation .GraphQLIgnore ;
4750import com .introproventures .graphql .jpa .query .schema .GraphQLSchemaBuilder ;
4851import com .introproventures .graphql .jpa .query .schema .JavaScalars ;
4952import com .introproventures .graphql .jpa .query .schema .NamingStrategy ;
5053import com .introproventures .graphql .jpa .query .schema .impl .PredicateFilter .Criteria ;
54+
5155import graphql .Assert ;
5256import graphql .Scalars ;
5357import graphql .schema .Coercing ;
6569import graphql .schema .GraphQLType ;
6670import graphql .schema .GraphQLTypeReference ;
6771import graphql .schema .PropertyDataFetcher ;
68- import org .slf4j .Logger ;
69- import org .slf4j .LoggerFactory ;
7072
7173/**
7274 * JPA specific schema builder implementation of {code #GraphQLSchemaBuilder} interface
@@ -95,7 +97,8 @@ public class GraphQLJpaSchemaBuilder implements GraphQLSchemaBuilder {
9597
9698 private Map <Class <?>, GraphQLType > classCache = new HashMap <>();
9799 private Map <EntityType <?>, GraphQLObjectType > entityCache = new HashMap <>();
98- private Map <EmbeddableType <?>, GraphQLObjectType > embeddableCache = new HashMap <>();
100+ private Map <EmbeddableType <?>, GraphQLObjectType > embeddableOutputCache = new HashMap <>();
101+ private Map <EmbeddableType <?>, GraphQLInputObjectType > embeddableInputCache = new HashMap <>();
99102
100103 private static final Logger log = LoggerFactory .getLogger (GraphQLJpaSchemaBuilder .class );
101104
@@ -292,13 +295,13 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
292295 .field (GraphQLInputObjectField .newInputObjectField ()
293296 .name (Criteria .EQ .name ())
294297 .description ("Equals criteria" )
295- .type (( GraphQLInputType ) getAttributeType (attribute ))
298+ .type (getAttributeInputType (attribute ))
296299 .build ()
297300 )
298301 .field (GraphQLInputObjectField .newInputObjectField ()
299302 .name (Criteria .NE .name ())
300303 .description ("Not Equals criteria" )
301- .type (( GraphQLInputType ) getAttributeType (attribute ))
304+ .type (getAttributeInputType (attribute ))
302305 .build ()
303306 );
304307
@@ -307,25 +310,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
307310 builder .field (GraphQLInputObjectField .newInputObjectField ()
308311 .name (Criteria .LE .name ())
309312 .description ("Less then or Equals criteria" )
310- .type (( GraphQLInputType ) getAttributeType (attribute ))
313+ .type (getAttributeInputType (attribute ))
311314 .build ()
312315 )
313316 .field (GraphQLInputObjectField .newInputObjectField ()
314317 .name (Criteria .GE .name ())
315318 .description ("Greater or Equals criteria" )
316- .type (( GraphQLInputType ) getAttributeType (attribute ))
319+ .type (getAttributeInputType (attribute ))
317320 .build ()
318321 )
319322 .field (GraphQLInputObjectField .newInputObjectField ()
320323 .name (Criteria .GT .name ())
321324 .description ("Greater Then criteria" )
322- .type (( GraphQLInputType ) getAttributeType (attribute ))
325+ .type (getAttributeInputType (attribute ))
323326 .build ()
324327 )
325328 .field (GraphQLInputObjectField .newInputObjectField ()
326329 .name (Criteria .LT .name ())
327330 .description ("Less Then criteria" )
328- .type (( GraphQLInputType ) getAttributeType (attribute ))
331+ .type (getAttributeInputType (attribute ))
329332 .build ()
330333 );
331334 }
@@ -334,25 +337,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
334337 builder .field (GraphQLInputObjectField .newInputObjectField ()
335338 .name (Criteria .LIKE .name ())
336339 .description ("Like criteria" )
337- .type (( GraphQLInputType ) getAttributeType (attribute ))
340+ .type (getAttributeInputType (attribute ))
338341 .build ()
339342 )
340343 .field (GraphQLInputObjectField .newInputObjectField ()
341344 .name (Criteria .CASE .name ())
342345 .description ("Case sensitive match criteria" )
343- .type (( GraphQLInputType ) getAttributeType (attribute ))
346+ .type (getAttributeInputType (attribute ))
344347 .build ()
345348 )
346349 .field (GraphQLInputObjectField .newInputObjectField ()
347350 .name (Criteria .STARTS .name ())
348351 .description ("Starts with criteria" )
349- .type (( GraphQLInputType ) getAttributeType (attribute ))
352+ .type (getAttributeInputType (attribute ))
350353 .build ()
351354 )
352355 .field (GraphQLInputObjectField .newInputObjectField ()
353356 .name (Criteria .ENDS .name ())
354357 .description ("Ends with criteria" )
355- .type (( GraphQLInputType ) getAttributeType (attribute ))
358+ .type (getAttributeInputType (attribute ))
356359 .build ()
357360 );
358361 }
@@ -373,25 +376,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
373376 .field (GraphQLInputObjectField .newInputObjectField ()
374377 .name (Criteria .IN .name ())
375378 .description ("In criteria" )
376- .type (new GraphQLList (getAttributeType (attribute )))
379+ .type (new GraphQLList (getAttributeInputType (attribute )))
377380 .build ()
378381 )
379382 .field (GraphQLInputObjectField .newInputObjectField ()
380383 .name (Criteria .NIN .name ())
381384 .description ("Not In criteria" )
382- .type (new GraphQLList (getAttributeType (attribute )))
385+ .type (new GraphQLList (getAttributeInputType (attribute )))
383386 .build ()
384387 )
385388 .field (GraphQLInputObjectField .newInputObjectField ()
386389 .name (Criteria .BETWEEN .name ())
387390 .description ("Between criteria" )
388- .type (new GraphQLList (getAttributeType (attribute )))
391+ .type (new GraphQLList (getAttributeInputType (attribute )))
389392 .build ()
390393 )
391394 .field (GraphQLInputObjectField .newInputObjectField ()
392395 .name (Criteria .NOT_BETWEEN .name ())
393396 .description ("Not Between criteria" )
394- .type (new GraphQLList (getAttributeType (attribute )))
397+ .type (new GraphQLList (getAttributeInputType (attribute )))
395398 .build ()
396399 );
397400
@@ -404,39 +407,52 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
404407 }
405408
406409 private GraphQLArgument getArgument (Attribute <?,?> attribute ) {
407- GraphQLType type = getAttributeType (attribute );
410+ GraphQLInputType type = getAttributeInputType (attribute );
408411 String description = getSchemaDescription (attribute .getJavaMember ());
409412
410- if (type instanceof GraphQLInputType ) {
411- return GraphQLArgument .newArgument ()
412- .name (attribute .getName ())
413- .type ((GraphQLInputType ) type )
414- .description (description )
415- .build ();
416- }
417-
418- throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Input Argument" );
413+ return GraphQLArgument .newArgument ()
414+ .name (attribute .getName ())
415+ .type ((GraphQLInputType ) type )
416+ .description (description )
417+ .build ();
419418 }
420419
421- private GraphQLObjectType getEmbeddableType (EmbeddableType <?> embeddableType ) {
422- if (embeddableCache .containsKey (embeddableType ))
423- return embeddableCache .get (embeddableType );
424-
425- String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+"EmbeddableType" ;
426-
427- GraphQLObjectType objectType = GraphQLObjectType .newObject ()
428- .name (embeddableTypeName )
429- .description (getSchemaDescription ( embeddableType .getJavaType ()))
430- .fields (embeddableType .getAttributes ().stream ()
431- .filter (this ::isNotIgnored )
432- .map (this ::getObjectField )
433- .collect (Collectors .toList ())
434- )
435- .build ();
436-
437- embeddableCache .putIfAbsent (embeddableType , objectType );
420+ private GraphQLType getEmbeddableType (EmbeddableType <?> embeddableType , boolean input ) {
421+ if (input && embeddableInputCache .containsKey (embeddableType ))
422+ return embeddableInputCache .get (embeddableType );
423+
424+ if (!input && embeddableOutputCache .containsKey (embeddableType ))
425+ return embeddableOutputCache .get (embeddableType );
426+ String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+ (input ? "Input" : "" ) +"EmbeddableType" ;
427+ GraphQLType graphQLType =null ;
428+ if (input ) {
429+ graphQLType = GraphQLInputObjectType .newInputObject ()
430+ .name (embeddableTypeName )
431+ .description (getSchemaDescription (embeddableType .getJavaType ()))
432+ .fields (embeddableType .getAttributes ().stream ()
433+ .filter (this ::isNotIgnored )
434+ .map (this ::getInputObjectField )
435+ .collect (Collectors .toList ())
436+ )
437+ .build ();
438+ } else {
439+ graphQLType = GraphQLObjectType .newObject ()
440+ .name (embeddableTypeName )
441+ .description (getSchemaDescription (embeddableType .getJavaType ()))
442+ .fields (embeddableType .getAttributes ().stream ()
443+ .filter (this ::isNotIgnored )
444+ .map (this ::getObjectField )
445+ .collect (Collectors .toList ())
446+ )
447+ .build ();
448+ }
449+ if (input ) {
450+ embeddableInputCache .putIfAbsent (embeddableType , (GraphQLInputObjectType ) graphQLType );
451+ } else {
452+ embeddableOutputCache .putIfAbsent (embeddableType , (GraphQLObjectType ) graphQLType );
453+ }
438454
439- return objectType ;
455+ return graphQLType ;
440456 }
441457
442458
@@ -462,67 +478,92 @@ private GraphQLObjectType getObjectType(EntityType<?> entityType) {
462478
463479 @ SuppressWarnings ( { "rawtypes" , "unchecked" } )
464480 private GraphQLFieldDefinition getObjectField (Attribute attribute ) {
465- GraphQLType type = getAttributeType (attribute );
466-
467- if (type instanceof GraphQLOutputType ) {
468- List <GraphQLArgument > arguments = new ArrayList <>();
469- DataFetcher dataFetcher = PropertyDataFetcher .fetching (attribute .getName ());
470-
471- // Only add the orderBy argument for basic attribute types
472- if (attribute instanceof SingularAttribute
473- && attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ) {
474- arguments .add (GraphQLArgument .newArgument ()
475- .name (ORDER_BY_PARAM_NAME )
476- .description ("Specifies field sort direction in the query results." )
477- .type (orderByDirectionEnum )
478- .build ()
479- );
480- }
481-
482- // Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
483- if (attribute instanceof SingularAttribute
484- && attribute .getPersistentAttributeType () != Attribute .PersistentAttributeType .BASIC ) {
485- ManagedType foreignType = (ManagedType ) ((SingularAttribute ) attribute ).getType ();
486-
487- // TODO fix page count query
488- arguments .add (getWhereArgument (foreignType ));
489-
490- } // Get Sub-Objects fields queries via DataFetcher
491- else if (attribute instanceof PluralAttribute
492- && (attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ONE_TO_MANY
493- || attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .MANY_TO_MANY )) {
494- EntityType declaringType = (EntityType ) ((PluralAttribute ) attribute ).getDeclaringType ();
495- EntityType elementType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
496-
497- arguments .add (getWhereArgument (elementType ));
498- dataFetcher = new GraphQLJpaOneToManyDataFetcher (entityManager , declaringType , (PluralAttribute ) attribute );
499- }
481+ GraphQLOutputType type = getAttributeOutputType (attribute );
482+
483+ List <GraphQLArgument > arguments = new ArrayList <>();
484+ DataFetcher dataFetcher = PropertyDataFetcher .fetching (attribute .getName ());
485+
486+ // Only add the orderBy argument for basic attribute types
487+ if (attribute instanceof SingularAttribute
488+ && attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ) {
489+ arguments .add (GraphQLArgument .newArgument ()
490+ .name (ORDER_BY_PARAM_NAME )
491+ .description ("Specifies field sort direction in the query results." )
492+ .type (orderByDirectionEnum )
493+ .build ()
494+ );
495+ }
500496
501- return GraphQLFieldDefinition .newFieldDefinition ()
502- .name (attribute .getName ())
503- .description (getSchemaDescription (attribute .getJavaMember ()))
504- .type ((GraphQLOutputType ) type )
505- .dataFetcher (dataFetcher )
506- .argument (arguments )
507- .build ();
497+ // Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
498+ if (attribute instanceof SingularAttribute
499+ && attribute .getPersistentAttributeType () != Attribute .PersistentAttributeType .BASIC ) {
500+ ManagedType foreignType = (ManagedType ) ((SingularAttribute ) attribute ).getType ();
501+
502+ // TODO fix page count query
503+ arguments .add (getWhereArgument (foreignType ));
504+
505+ } // Get Sub-Objects fields queries via DataFetcher
506+ else if (attribute instanceof PluralAttribute
507+ && (attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ONE_TO_MANY
508+ || attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .MANY_TO_MANY )) {
509+ EntityType declaringType = (EntityType ) ((PluralAttribute ) attribute ).getDeclaringType ();
510+ EntityType elementType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
511+
512+ arguments .add (getWhereArgument (elementType ));
513+ dataFetcher = new GraphQLJpaOneToManyDataFetcher (entityManager , declaringType , (PluralAttribute ) attribute );
508514 }
509515
510- throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Output Argument" );
516+ return GraphQLFieldDefinition .newFieldDefinition ()
517+ .name (attribute .getName ())
518+ .description (getSchemaDescription (attribute .getJavaMember ()))
519+ .type (type )
520+ .dataFetcher (dataFetcher )
521+ .argument (arguments )
522+ .build ();
523+ }
524+
525+ @ SuppressWarnings ( { "rawtypes" , "unchecked" } )
526+ private GraphQLInputObjectField getInputObjectField (Attribute attribute ) {
527+ GraphQLInputType type = getAttributeInputType (attribute );
528+
529+ return GraphQLInputObjectField .newInputObjectField ()
530+ .name (attribute .getName ())
531+ .description (getSchemaDescription (attribute .getJavaMember ()))
532+ .type (type )
533+ .build ();
511534 }
512535
513536 private Stream <Attribute <?,?>> findBasicAttributes (Collection <Attribute <?,?>> attributes ) {
514537 return attributes .stream ().filter (it -> it .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC );
515538 }
516539
517540 @ SuppressWarnings ( "rawtypes" )
518- private GraphQLType getAttributeType (Attribute <?,?> attribute ) {
541+ private GraphQLInputType getAttributeInputType (Attribute <?,?> attribute ) {
542+ try {
543+ return (GraphQLInputType ) getAttributeType (attribute , true );
544+ } catch (ClassCastException e ){
545+ throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Input Argument" );
546+ }
547+ }
548+
549+ @ SuppressWarnings ( "rawtypes" )
550+ private GraphQLOutputType getAttributeOutputType (Attribute <?,?> attribute ) {
551+ try {
552+ return (GraphQLOutputType ) getAttributeType (attribute , false );
553+ } catch (ClassCastException e ){
554+ throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Output Argument" );
555+ }
556+ }
557+
558+ @ SuppressWarnings ( "rawtypes" )
559+ private GraphQLType getAttributeType (Attribute <?,?> attribute , boolean input ) {
519560
520561 if (isBasic (attribute )) {
521562 return getGraphQLTypeFromJavaType (attribute .getJavaType ());
522563 }
523564 else if (isEmbeddable (attribute )) {
524565 EmbeddableType embeddableType = (EmbeddableType ) ((SingularAttribute ) attribute ).getType ();
525- return getEmbeddableType (embeddableType );
566+ return getEmbeddableType (embeddableType , input );
526567 }
527568 else if (isToMany (attribute )) {
528569 EntityType foreignType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
@@ -572,7 +613,8 @@ protected final boolean isToOne(Attribute<?,?> attribute) {
572613
573614 protected final boolean isValidInput (Attribute <?,?> attribute ) {
574615 return attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ||
575- attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ELEMENT_COLLECTION ;
616+ attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ELEMENT_COLLECTION ||
617+ attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .EMBEDDED ;
576618 }
577619
578620 private String getSchemaDescription (Member member ) {
0 commit comments