diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 329ece6af75..885c7ca21a7 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -4506,7 +4506,9 @@ private ReactiveOAuth2UserService getOidcUserService( if (bean != null) { return bean; } - return new OidcReactiveOAuth2UserService(); + OidcReactiveOAuth2UserService reactiveOAuth2UserService = new OidcReactiveOAuth2UserService(); + reactiveOAuth2UserService.setOauth2UserService(getOauth2UserService()); + return reactiveOAuth2UserService; } private ReactiveOAuth2UserService getOauth2UserService() { diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 67d7a816f01..4597d3e7865 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -64,6 +64,8 @@ import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; @@ -84,6 +86,7 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.TestOAuth2Users; import org.springframework.security.oauth2.jwt.Jwt; @@ -664,6 +667,41 @@ public void oauth2LoginWhenDefaultsThenNoOidcSessionRegistry() { .block()).isEmpty(); } + @Test + public void oauth2LoginWhenOauth2UserServiceBeanPresent() { + this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class, OAuth2LoginWithOauth2UserService.class) + .autowire(); + WebTestClient webTestClient = WebTestClientBuilder.bindToWebFilters(this.springSecurity).build(); + OAuth2LoginWithOauth2UserService config = this.spring.getContext() + .getBean(OAuth2LoginWithOauth2UserService.class); + OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build(); + OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build(); + OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid"); + OAuth2AuthorizationCodeAuthenticationToken token = new OAuth2AuthorizationCodeAuthenticationToken(google, + exchange, accessToken); + ServerAuthenticationConverter converter = config.authenticationConverter; + given(converter.convert(any())).willReturn(Mono.just(token)); + ServerSecurityContextRepository securityContextRepository = config.securityContextRepository; + given(securityContextRepository.save(any(), any())).willReturn(Mono.empty()); + given(securityContextRepository.load(any())).willReturn(authentication(token)); + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()) + .additionalParameters(additionalParameters) + .build(); + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = config.tokenResponseClient; + given(tokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + ReactiveOAuth2UserService userService = config.reactiveOAuth2UserService; + given(userService.loadUser(any())).willReturn(Mono + .just(new DefaultOAuth2User(AuthorityUtils.createAuthorityList("USER"), Map.of("sub", "subject"), "sub"))); + webTestClient.get().uri("/login/oauth2/code/google").exchange().expectStatus().is3xxRedirection(); + verify(userService).loadUser(any()); + + } + Mono authentication(Authentication authentication) { SecurityContext context = new SecurityContextImpl(); context.setAuthentication(authentication); @@ -674,6 +712,51 @@ T getBean(Class beanClass) { return this.spring.getContext().getBean(beanClass); } + @Configuration + static class OAuth2LoginWithOauth2UserService { + + ReactiveOAuth2AccessTokenResponseClient tokenResponseClient = mock( + ReactiveOAuth2AccessTokenResponseClient.class); + + ReactiveOAuth2UserService reactiveOAuth2UserService = mock( + DefaultReactiveOAuth2UserService.class); + + ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class); + + ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class); + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + http.authorizeExchange((authorize) -> authorize.anyExchange().authenticated()) + .oauth2Login((c) -> c.authenticationConverter(this.authenticationConverter) + .securityContextRepository(this.securityContextRepository)); + return http.build(); + } + + @Bean + ReactiveOAuth2UserService customOAuth2UserService() { + return this.reactiveOAuth2UserService; + } + + @Bean + ReactiveJwtDecoderFactory jwtDecoderFactory() { + return (clientRegistration) -> (token) -> { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.SUB, "subject"); + claims.put(IdTokenClaimNames.ISS, "http://localhost/issuer"); + claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client")); + claims.put(IdTokenClaimNames.AZP, "client"); + return Mono.just(TestJwts.jwt().claims((c) -> c.putAll(claims)).build()); + }; + } + + @Bean + ReactiveOAuth2AccessTokenResponseClient requestReactiveOAuth2AccessTokenResponseClient() { + return this.tokenResponseClient; + } + + } + @Configuration @EnableWebFluxSecurity static class OAuth2LoginWithMultipleClientRegistrations {