Implementing OAuth2 in Spring Boot - Flexible Redirects

Java
Spring Boot

Today

I've found that in most Spring Boot applications with a separate frontend, the URL to redirect to after authentication is not always the same. The best approach I've found is for the frontend application to pass an fe_redirect_uri parameter, with the backend verifying that redirecting to that URI is allowed. Here's my step-by-step guide to integrating this into your Spring Boot application.

Take a look at the previous article if you need a refresher on how to set up basic JWT authentication.

Theory

The basic idea is that when a user starts the OAuth2 flow, we store the frontend redirect URI they provided in a cookie. When the OAuth2 flow finishes (whether it succeeded or not), we read that cookie and redirect the user accordingly.

0. application.yml

Before we start, let's configure OAuth2 in application.yml:

spring:
  security:
    oauth2:
      client:
        registration:
          google:
            # Credentials are injected from environment variables rather than committed to the repo.
            client-id: ${GOOGLE_CLIENT_ID}
            client-secret: ${GOOGLE_CLIENT_SECRET}
            scope:
              - openid
              - profile
              - email

1. Create cookie utils

When the frontend starts the login flow, we store its redirect URL in a cookie.

package com.prizm.service.auth;
 
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.util.SerializationUtils;
 
import java.util.Base64;
import java.util.Optional;
 
public final class CookieUtils {

  private CookieUtils() {
    // Static helpers only; this class is not meant to be instantiated.
  }

  /**
   * Retrieves a cookie by name from the request.
   *
   * <p>If several cookies share the same name, the first one in {@code request.getCookies()}
   * order is returned.
   *
   * @param request The HTTP request.
   * @param name    The name of the cookie.
   * @return An Optional containing the cookie if found, otherwise empty.
   */
  public static Optional<Cookie> getCookie(HttpServletRequest request, String name) {
    Cookie[] cookies = request.getCookies();
    // getCookies() returns null (not an empty array) when the request carries no cookies.
    if (cookies == null) {
      return Optional.empty();
    }
    for (Cookie cookie : cookies) {
      if (cookie.getName().equals(name)) {
        return Optional.of(cookie);
      }
    }
    return Optional.empty();
  }

  /**
   * Adds a new HttpOnly cookie, scoped to path "/", to the HTTP response.
   *
   * @param response The HTTP response.
   * @param name     The name of the cookie.
   * @param value    The value of the cookie.
   * @param maxAge   The maximum age in seconds.
   */
  public static void addCookie(HttpServletResponse response, String name, String value, int maxAge) {
    Cookie cookie = new Cookie(name, value);
    cookie.setPath("/");
    // Keep the value out of reach of page scripts.
    cookie.setHttpOnly(true);
    cookie.setMaxAge(maxAge);
    response.addCookie(cookie);
  }

  /**
   * Deletes every request cookie with the given name by re-sending it with an empty value, path
   * "/" and max age 0.
   *
   * <p>The {@link Cookie} objects obtained from the request are modified in place. Depending on the
   * servlet container, later reads of the same cookie within this request may therefore observe
   * the emptied value, so read anything you need before calling this.
   *
   * @param request  The HTTP request.
   * @param response The HTTP response.
   * @param name     The name of the cookie to delete.
   */
  public static void deleteCookie(HttpServletRequest request, HttpServletResponse response, String name) {
    Cookie[] cookies = request.getCookies();
    if (cookies == null) {
      return;
    }
    for (Cookie cookie : cookies) {
      if (cookie.getName().equals(name)) {
        cookie.setValue("");
        cookie.setPath("/");
        cookie.setMaxAge(0);
        response.addCookie(cookie);
      }
    }
  }

  /**
   * Serializes an object with Java serialization and encodes it as URL-safe Base64.
   *
   * @param object The object to serialize. Spring's SerializationUtils uses standard Java
   *     serialization, so the object graph is expected to be {@link java.io.Serializable}.
   * @return A URL-safe Base64 encoded string representation of the object.
   */
  public static String serialize(Object object) {
    return Base64.getUrlEncoder()
            .encodeToString(SerializationUtils.serialize(object));
  }

  /**
   * Decodes a URL-safe Base64 cookie value and deserializes it with Java serialization.
   *
   * <p><b>Security warning:</b> cookie values are fully client-controlled. Java native
   * deserialization of untrusted bytes is a well-known remote-code-execution vector (gadget
   * chains), and Spring deprecated {@code SerializationUtils.deserialize} for that reason. Before
   * relying on this in production, do one of the following:
   * <ul>
   *   <li>sign the value (e.g. HMAC) and verify the signature before deserializing;</li>
   *   <li>restrict the allowed classes with a {@link java.io.ObjectInputFilter};</li>
   *   <li>switch to a data format such as JSON.</li>
   * </ul>
   *
   * @param cookie The cookie containing the Base64 string.
   * @param cls    The class to deserialize into.
   * @return The deserialized object.
   * @throws IllegalArgumentException if the value is not valid Base64, or (via Spring's
   *     SerializationUtils) if the bytes cannot be read as a serialized object
   * @throws ClassCastException if the decoded object is not an instance of {@code cls}
   */
  public static <T> T deserialize(Cookie cookie, Class<T> cls) {
    byte[] data = Base64.getUrlDecoder().decode(cookie.getValue());
    return cls.cast(SerializationUtils.deserialize(data));
  }
}

2. Create OAuth2 Handlers

CustomOAuth2FailureHandler
package com.prizm.service.auth;
 
import static com.prizm.service.auth.CustomOAuth2RequestHandler.REDIRECT_URI_PARAM_COOKIE_NAME;
 
import com.prizm.service.user.UserService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import org.springframework.stereotype.Service;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class CustomOAuth2FailureHandler extends SimpleUrlAuthenticationFailureHandler {
  // NOTE(review): not used by this handler; kept so the injected constructor stays unchanged.
  private final UserService userService;

  /**
   * Sends the browser back to the frontend failure page after an unsuccessful OAuth2 login.
   *
   * <p>The frontend base URL comes from the redirect-URI cookie written when the flow started. The
   * per-login cookies are then expired, so that neither a stale redirect target nor a serialized
   * authorization request is left behind in the browser.
   *
   * @throws IllegalArgumentException if the redirect-URI cookie is missing
   */
  @Override
  public void onAuthenticationFailure(
      HttpServletRequest request, HttpServletResponse response, AuthenticationException exception)
      throws IOException {
    log.error("Authentication failed", exception);
    // Read the target before clearing: CookieUtils.deleteCookie blanks the request's Cookie object.
    String targetUrl = determineTargetUrl(request);
    clearOAuth2Cookies(request, response);
    getRedirectStrategy()
        .sendRedirect(request, response, targetUrl + "/auth/callback-failure?error=auth_failed");
  }

  /**
   * Returns the frontend base URL stored in the redirect-URI cookie.
   *
   * <p>NOTE(review): this value is client-supplied. Unless it is validated where the cookie is
   * written, this handler is an open redirect.
   *
   * @throws IllegalArgumentException if the cookie is absent
   */
  protected String determineTargetUrl(HttpServletRequest request) {
    return CookieUtils.getCookie(request, REDIRECT_URI_PARAM_COOKIE_NAME)
        .orElseThrow(
            () ->
                new IllegalArgumentException(
                    "Sorry! We've got an Unauthorized Redirect URI and can't proceed with the authentication."))
        .getValue();
  }

  /** Expires the cookies written by CustomOAuth2RequestHandler for this login attempt. */
  private void clearOAuth2Cookies(HttpServletRequest request, HttpServletResponse response) {
    CookieUtils.deleteCookie(
        request, response, CustomOAuth2RequestHandler.OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME);
    CookieUtils.deleteCookie(request, response, REDIRECT_URI_PARAM_COOKIE_NAME);
  }
}
CustomOAuth2RequestHandler
package com.prizm.service.auth;
 
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.stereotype.Component;
 
@Slf4j
@Component
public class CustomOAuth2RequestHandler
    implements AuthorizationRequestRepository<OAuth2AuthorizationRequest> {
  public static final String OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME = "oauth2_auth_request";
  public static final String REDIRECT_URI_PARAM_COOKIE_NAME = "fe_redirect_uri";
  private static final int COOKIE_EXPIRE_SECONDS = 180;
  private static final String UNAUTHORIZED_REDIRECT_MESSAGE =
      "Sorry! We've got an Unauthorized Redirect URI and can't proceed with the authentication.";

  /**
   * Restores the authorization request from its cookie.
   *
   * @return the stored request, or {@code null} if the cookie is absent or cannot be decoded. A
   *     {@code null} result makes Spring fail the login through the normal failure handler
   *     instead of producing a server error.
   */
  @Override
  public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
    try {
      return CookieUtils.getCookie(request, OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME)
          .map(cookie -> CookieUtils.deserialize(cookie, OAuth2AuthorizationRequest.class))
          .orElse(null);
    } catch (IllegalArgumentException | IllegalStateException | ClassCastException ex) {
      // The cookie is client-controlled; treat garbage as "no request" rather than a 500.
      log.warn("Discarding unreadable OAuth2 authorization request cookie", ex);
      return null;
    }
  }

  /**
   * Persists the authorization request and the frontend redirect URI as short-lived cookies.
   *
   * <p>Passing a {@code null} request clears both cookies.
   *
   * @throws IllegalArgumentException if the {@code fe_redirect_uri} request parameter is missing
   *     or blank
   */
  @Override
  public void saveAuthorizationRequest(
      OAuth2AuthorizationRequest authorizationRequest,
      HttpServletRequest request,
      HttpServletResponse response) {
    if (authorizationRequest == null) {
      removeAuthorizationRequestCookies(request, response);
      return;
    }
    String redirectUriAfterLogin = request.getParameter(REDIRECT_URI_PARAM_COOKIE_NAME);
    // Validate before writing anything, so a rejected request leaves no half-written cookie state.
    if (Objects.isNull(redirectUriAfterLogin) || redirectUriAfterLogin.isBlank()) {
      throw new IllegalArgumentException(UNAUTHORIZED_REDIRECT_MESSAGE);
    }
    // TODO: Put your regex/allow-list validation here to check if redirect uri is correct.
    // Without it the success/failure handlers will redirect to any URI a caller supplies.
    CookieUtils.addCookie(
        response,
        OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME,
        CookieUtils.serialize(authorizationRequest),
        COOKIE_EXPIRE_SECONDS);
    CookieUtils.addCookie(
        response, REDIRECT_URI_PARAM_COOKIE_NAME, redirectUriAfterLogin, COOKIE_EXPIRE_SECONDS);
  }

  /**
   * Returns the stored authorization request and expires its cookie so it is single-use.
   *
   * <p>The redirect-URI cookie is deliberately left in place: the success and failure handlers
   * read it later while processing this same request.
   */
  @Override
  public OAuth2AuthorizationRequest removeAuthorizationRequest(
      HttpServletRequest request, HttpServletResponse response) {
    OAuth2AuthorizationRequest authorizationRequest = loadAuthorizationRequest(request);
    CookieUtils.deleteCookie(request, response, OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME);
    return authorizationRequest;
  }

  /** Expires both the authorization-request cookie and the redirect-URI cookie. */
  public void removeAuthorizationRequestCookies(
      HttpServletRequest request, HttpServletResponse response) {
    CookieUtils.deleteCookie(request, response, OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME);
    CookieUtils.deleteCookie(request, response, REDIRECT_URI_PARAM_COOKIE_NAME);
  }
}
CustomOAuth2Service
package com.prizm.service.auth;
 
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class CustomOAuth2Service implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {

  /** Performs the UserInfo endpoint call; created once instead of on every login. */
  private final DefaultOAuth2UserService delegate = new DefaultOAuth2UserService();

  /**
   * Loads the user from the provider and returns a trimmed-down principal with a fixed
   * {@code USER} authority.
   *
   * <p>The returned attributes are {@code id}, {@code email} and the provider's user-name
   * attribute.
   *
   * <p>NOTE(review): when a registration requests the {@code openid} scope, Spring Security loads
   * the user through its OIDC user service instead. In that case this service is presumably not
   * invoked at all; confirm this for your registrations.
   */
  @Override
  public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
    OAuth2User oAuth2User = delegate.loadUser(userRequest);

    String userNameAttributeName =
        userRequest
            .getClientRegistration()
            .getProviderDetails()
            .getUserInfoEndpoint()
            .getUserNameAttributeName();
    Map<String, Object> attributes = oAuth2User.getAttributes();

    Map<String, Object> userAttributes = new HashMap<>();
    userAttributes.put("id", attributes.get("id"));
    userAttributes.put("email", attributes.get("email"));
    // DefaultOAuth2User rejects attribute maps that lack the name attribute, and getName()
    // dereferences it. Some providers use a key other than "id" (e.g. "sub"), so always carry it
    // over.
    userAttributes.put(userNameAttributeName, attributes.get(userNameAttributeName));

    return new DefaultOAuth2User(
        Collections.singleton(new SimpleGrantedAuthority("USER")),
        userAttributes,
        userNameAttributeName);
  }
}
CustomOauth2SuccessHandler
package com.prizm.service.auth;
 
import static com.prizm.service.auth.CustomOAuth2RequestHandler.REDIRECT_URI_PARAM_COOKIE_NAME;
 
import com.prizm.service.user.UserService;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
import org.springframework.stereotype.Service;
 
@Slf4j
@Service
@RequiredArgsConstructor
public class CustomOAuth2SuccessHandler extends SimpleUrlAuthenticationSuccessHandler {

  private final UserService userService;

  /**
   * Logs the user in (or creates them) and redirects to the frontend callback with the tokens.
   *
   * <p>The frontend base URL comes from the redirect-URI cookie written when the flow started. If
   * the provider returned no email, the user is sent to the failure page instead, and no user is
   * created.
   *
   * @throws IllegalArgumentException if the redirect-URI cookie is missing
   */
  @Override
  public void onAuthenticationSuccess(
      HttpServletRequest request, HttpServletResponse response, Authentication authentication)
      throws IOException {
    OAuth2User oAuth2User = (OAuth2User) authentication.getPrincipal();
    String email = oAuth2User.getAttribute("email");
    log.info("email = {}", email);

    // Read the target before clearing: CookieUtils.deleteCookie blanks the request's Cookie object.
    var targetUrl = determineTargetUrl(request);
    clearOAuth2Cookies(request, response);
    clearAuthenticationAttributes(request);

    if (Objects.isNull(email) || email.isBlank()) {
      getRedirectStrategy()
          .sendRedirect(request, response, targetUrl + "/auth/callback-failure?error=no_email");
      // Stop here. Falling through would create a user without an email and issue a second
      // redirect on an already-committed response.
      return;
    }

    // TODO: Replace this is with your user service
    var userTokens = userService.loginOrCreateUser(email);

    // NOTE(review): tokens in the query string can leak via browser history, server/proxy logs
    // and Referer headers; consider a one-time code exchanged over a POST instead.
    getRedirectStrategy()
        .sendRedirect(
            request,
            response,
            targetUrl
                + "/auth/callback?token="
                + userTokens.accessToken()
                + "&refreshToken="
                + userTokens.refreshToken());
  }

  /**
   * Returns the frontend base URL stored in the redirect-URI cookie.
   *
   * <p>NOTE(review): this value is client-supplied. Unless it is validated where the cookie is
   * written, this handler is an open redirect that would hand tokens to an arbitrary host.
   *
   * @throws IllegalArgumentException if the cookie is absent
   */
  protected String determineTargetUrl(HttpServletRequest request) {
    return CookieUtils.getCookie(request, REDIRECT_URI_PARAM_COOKIE_NAME)
        .orElseThrow(
            () ->
                new IllegalArgumentException(
                    "Sorry! We've got an Unauthorized Redirect URI and can't proceed with the authentication."))
        .getValue();
  }

  /** Expires the cookies written by CustomOAuth2RequestHandler for this login attempt. */
  private void clearOAuth2Cookies(HttpServletRequest request, HttpServletResponse response) {
    CookieUtils.deleteCookie(
        request, response, CustomOAuth2RequestHandler.OAUTH2_AUTHORIZATION_REQUEST_COOKIE_NAME);
    CookieUtils.deleteCookie(request, response, REDIRECT_URI_PARAM_COOKIE_NAME);
  }
}

3. Putting it all together inside SecurityConfig.java

package com.prizm.service.config;
 
import com.prizm.service.auth.CustomOAuth2FailureHandler;
import com.prizm.service.auth.CustomOAuth2Service;
import com.prizm.service.auth.CustomOAuth2SuccessHandler;
import java.util.List;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
 
@Configuration
@EnableWebSecurity
@EnableMethodSecurity
@RequiredArgsConstructor
public class SecurityConfig {

  /**
   * Assembles the application's security filter chain.
   *
   * <p>CSRF protection is off and every request is permitted. Bearer JWTs are accepted as a
   * resource server. OAuth2 login is wired to the custom request repository, the custom user
   * service and the custom success/failure handlers.
   */
  @Bean
  public SecurityFilterChain filterChain(
      CustomOAuth2Service service,
      CustomOAuth2SuccessHandler successHandler,
      CustomOAuth2FailureHandler failureHandler,
      AuthorizationRequestRepository<OAuth2AuthorizationRequest> customOAuth2RequestHandler,
      HttpSecurity http)
      throws Exception {
    http.csrf(AbstractHttpConfigurer::disable);
    http.cors(cors -> cors.configurationSource(corsConfigurationSource()));
    http.authorizeHttpRequests(requests -> requests.anyRequest().permitAll());
    http.oauth2ResourceServer(resourceServer -> resourceServer.jwt(Customizer.withDefaults()));
    http.oauth2Login(
        login -> {
          // The authorization request (and frontend redirect URI) live in cookies, not the session.
          login.authorizationEndpoint(
              endpoint -> endpoint.authorizationRequestRepository(customOAuth2RequestHandler));
          login.userInfoEndpoint(userInfo -> userInfo.userService(service));
          login.successHandler(successHandler);
          login.failureHandler(failureHandler);
        });
    return http.build();
  }

  /** Permissive CORS for every path: any origin and header, common HTTP methods. */
  @Bean
  CorsConfigurationSource corsConfigurationSource() {
    CorsConfiguration permissive = new CorsConfiguration();
    permissive.setAllowedOrigins(List.of("*"));
    permissive.setAllowedMethods(List.of("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"));
    permissive.setAllowedHeaders(List.of("*"));

    UrlBasedCorsConfigurationSource corsSource = new UrlBasedCorsConfigurationSource();
    corsSource.registerCorsConfiguration("/**", permissive);
    return corsSource;
  }

  /** Decodes incoming JWTs using the secret key supplied by {@code JwtPropertiesService}. */
  @Bean
  public JwtDecoder jwtDecoder(JwtPropertiesService jwtPropertiesService) {
    var signingKey = jwtPropertiesService.getKey();
    return NimbusJwtDecoder.withSecretKey(signingKey).build();
  }
}