diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java index fd96945b..30534c71 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java @@ -17,6 +17,7 @@ package org.springframework.graphql.server.webmvc; import java.io.IOException; +import java.time.Duration; import java.util.Map; import java.util.function.Consumer; @@ -31,6 +32,7 @@ import org.springframework.graphql.execution.SubscriptionPublisherException; import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.WebGraphQlResponse; +import org.springframework.lang.Nullable; import org.springframework.web.context.request.async.AsyncRequestTimeoutException; import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; @@ -47,8 +49,29 @@ */ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { + @Nullable + private final Duration timeout; + + + /** + * Constructor with the handler to delegate to, and no timeout, + * i.e. relying on underlying Server async request timeout. + * @param graphQlHandler the handler to delegate to + */ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) { + this(graphQlHandler, null); + } + + /** + * Variant constructor with a timeout to use for SSE subscriptions. + * @param graphQlHandler the handler to delegate to + * @param timeout the timeout value to set on + * {@link org.springframework.web.context.request.async.AsyncWebRequest#setTimeout(Long)} + * @since 1.3.3 + */ + public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) { super(graphQlHandler, null); + this.timeout = timeout; } @@ -76,7 +99,9 @@ protected ServerResponse prepareResponse( .toSpecification()); }); - return ServerResponse.sse(SseSubscriber.connect(resultFlux)); + return ((this.timeout != null) ? + ServerResponse.sse(SseSubscriber.connect(resultFlux), this.timeout) : + ServerResponse.sse(SseSubscriber.connect(resultFlux))); }