【Spring Security】 Oath2 授权服务器异常处理源码分析

Metadata

title: 【Spring Security】 Oath2 授权服务器异常处理源码分析
date: 2023-02-05 10:00
tags:
  - 行动阶段/完成
  - 主题场景/组件
  - 笔记空间/KnowladgeSpace/ProgramSpace/ModuleSpace
  - 细化主题/Module/SpringSecurity
categories:
  - SpringSecurity
keywords:
  - SpringSecurity
description: 【Spring Security】 Oath2 授权服务器异常处理源码分析

【Spring Security】 Oath2 授权服务器异常处理源码分析

在之前 Security 文档中,了解到在异常时框架自己会捕获,然后通过异常处理器处理,使用 @RestControllerAdvice 在这里是不能统一处理的,因为这个注解是对 controller 层进行拦截。

在 Spring Security Oauth2 中,这些异常也是框架自行捕获处理了,下面跟踪源码分析下。

授权服务器异常

OAuth2Exception

OAuth2Exception 类就是 Oauth2 的异常类,继承自 RuntimeException。

其定义了很多常量表示错误信息,基本上对应每个 OAuth2Exception 的子类。

    // 错误 
    public static final String ERROR = "error";
    // 错误描述 
    public static final String DESCRIPTION = "error_description";
    // 错误的URI
    public static final String URI = "error_uri";
    // 无效的请求 InvalidRequestException
    public static final String INVALID_REQUEST = "invalid_request";
    // 无效客户端
    public static final String INVALID_CLIENT = "invalid_client";
    // 无效授权 InvalidGrantException
    public static final String INVALID_GRANT = "invalid_grant";
    // 未经授权的客户端
    public static final String UNAUTHORIZED_CLIENT = "unauthorized_client";
    // 不受支持的授权类型 UnsupportedGrantTypeException
    public static final String UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type";
    // 无效授权范围 InvalidScopeException
    public static final String INVALID_SCOPE = "invalid_scope";
    // 授权范围不足
    public static final String INSUFFICIENT_SCOPE = "insufficient_scope";
    // 令牌无效 InvalidTokenException
    public static final String INVALID_TOKEN = "invalid_token";
    // 重定向uri不匹配 RedirectMismatchException
    public static final String REDIRECT_URI_MISMATCH ="redirect_uri_mismatch";
    // 不支持的响应类型 UnsupportedResponseTypeException
    public static final String UNSUPPORTED_RESPONSE_TYPE ="unsupported_response_type";
    // 拒绝访问 UserDeniedAuthorizationException
    public static final String ACCESS_DENIED = "access_denied";

OAuth2Exception 也定义了很多方法:

    // 添加额外的异常信息
    private Map<String, String> additionalInformation = null;
    // OAuth2 错误代码
    public String getOAuth2ErrorCode() {
        return "invalid_request";
    }
    // 与此错误关联的 HTTP 错误代码
    public int getHttpErrorCode() {
        return 400;
    }
    // 根据定义好的错误代码(常量),创建对应的OAuth2Exception子类
    public static OAuth2Exception create(String errorCode, String errorMessage) {
        if (errorMessage == null) {
            errorMessage = errorCode == null ? "OAuth Error" : errorCode;
        }
        if (INVALID_CLIENT.equals(errorCode)) {
            return new InvalidClientException(errorMessage);
        }
        // 省略.......
    }
    // 从 Map<String,String> 创建一个 {@link OAuth2Exception}。
    public static OAuth2Exception valueOf(Map<String, String> errorParams) {
        // 省略.......
        return ex;
    }
    

    /**
     * @return 以逗号分隔的详细信息列表(键值对)
     */
    public String getSummary() {
        // 省略.......
        return builder.toString();
    }

异常处理源码分析

我们以密码模式,不传入授权类型为例。

1. 端点校验 GrantType 抛出异常

密码模式访问 / oauth/token 端点,在下面代码中,GrantType 不存在,会抛出 InvalidRequestException 异常,这个异常的 msg 为 “Missing grant type”。

创建的异常,包含了下面这些信息。

2. 端点中的 @ExceptionHandler 统一处理异常

在端点类中,定义了多个 @ExceptionHandler,所以只要是在这个端点中的异常,都会被捕获处理。

1 中抛出的 InvalidRequestException 是 OAuth2Exception 的子类,所以最终由下面这个 ExceptionHandler 处理。

    @ExceptionHandler(OAuth2Exception.class)
    public ResponseEntity<OAuth2Exception> handleException(OAuth2Exception e) throws Exception {
    // 打印WARN日志
        if (logger.isWarnEnabled()) {
            logger.warn("Handling error: " + e.getClass().getSimpleName() + ", " + e.getMessage());
        }
        // 调用异常翻译器
        return getExceptionTranslator().translate(e);
    }

3. 异常翻译处理器

最终调用 WebResponseExceptionTranslator 的实现类,对异常进行翻译封装处理,最后由 Spring MVC 返回 ResponseEntity<OAuth2Exception> 对象。ResponseEntity 实际是一个 HttpEntity,是 Spring WEB 提供了一个封装信息响应给请求的对象。

异常翻译默认使用的是 DefaultWebResponseExceptionTranslator 类,最终进入其 translate 方法。

    @Override
    public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {

        // 1. 尝试从堆栈跟踪中提取 SpringSecurityException
        Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);
        Exception ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);
        // 2. 获取OAuth2Exception
        if (ase != null) {
            // 3. 获取到了OAuth2Exception,直接处理
            return handleOAuth2Exception((OAuth2Exception) ase);
        }
        
        ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class,
                causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e));
        }

        ase = (AccessDeniedException) throwableAnalyzer
                .getFirstThrowableOfType(AccessDeniedException.class, causeChain);
        if (ase instanceof AccessDeniedException) {
            return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase));
        }

        ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(
                HttpRequestMethodNotSupportedException.class, causeChain);
        if (ase instanceof HttpRequestMethodNotSupportedException) {
            return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase));
        }

        return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));

    }

真正创建 ResponseEntity 的是 handleOAuth2Exception 方法。

private ResponseEntity<OAuth2Exception> handleOAuth2Exception(OAuth2Exception e) throws IOException {
        // 获取错误码 eg:400 
        int status = e.getHttpErrorCode();
        // 设置响应消息头,禁用缓存
        HttpHeaders headers = new HttpHeaders();
        headers.set("Cache-Control", "no-store");
        headers.set("Pragma", "no-cache");
        // 如果是401,或者是范围不足异常,设置WWW-Authenticate 消息头
        if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
            headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
        }
        // 将异常信息,塞到ResponseEntity的Body中
        ResponseEntity<OAuth2Exception> response = new ResponseEntity<OAuth2Exception>(e, headers,
                HttpStatus.valueOf(status));
        return response;
    }

4. 序列化

因为 OAuth2Exception 上标注了 JsonSerialize 、JsonDeserialize 注解,所以会进行序列化操作。主要是将 OAuth2Exception 中的异常进行序列化处理。

    @Override
    public void serialize(OAuth2Exception value, JsonGenerator jgen, SerializerProvider provider) throws IOException,
            JsonProcessingException {
        jgen.writeStartObject();
        // 序列化error
        jgen.writeStringField("error", value.getOAuth2ErrorCode());
        String errorMessage = value.getMessage();
        if (errorMessage != null) {
            errorMessage = HtmlUtils.htmlEscape(errorMessage);
        }
        // 序列化error_description
        jgen.writeStringField("error_description", errorMessage);
        // 序列化额外的附加信息AdditionalInformation
        if (value.getAdditionalInformation()!=null) {
            for (Entry<String, String> entry : 
            value.getAdditionalInformation().entrySet()) {
                String key = entry.getKey();
                String add = entry.getValue();
                jgen.writeStringField(key, add);				
            }
        }
        jgen.writeEndObject();
    }

5. 前端获取错误信息

最终,OAuth2Exception 经过抛出,ExceptionHandler 捕获,翻译,封装返回 ResponseEntity,序列化处理,就展示给前端了。

自定义授权服务器异常信息

在实际开放中,一般异常都是有固定格式的,OAuth2Exception 直接返回,不是我们想要的,那么我们可以进行改造。

1. 自定义异常

自定义一个异常,继承 OAuth2Exception,并添加序列化

@JsonSerialize(using = MyOauthExceptionJackson2Serializer.class)
@JsonDeserialize(using = MyOAuth2ExceptionJackson2Deserializer.class)
public class MyOAuth2Exception extends OAuth2Exception{

        public MyOAuth2Exception(String msg, Throwable t) {
            super(msg, t);
        }

        public MyOAuth2Exception(String msg) {
            super(msg);
        }
}

2. 编写序列化

参考 OAuth2Exception 的序列化,编写我们自己的异常的序列化与反序列化类。

public class MyOauthExceptionJackson2Serializer extends StdSerializer<MyOAuth2Exception> {


    public MyOauthExceptionJackson2Serializer() {
        super(MyOAuth2Exception.class);
    }


    @Override
    public void serialize(MyOAuth2Exception value, JsonGenerator jgen, SerializerProvider provider) throws IOException,
            JsonProcessingException {
        jgen.writeStartObject();
        //jgen.writeStringField("error", value.getOAuth2ErrorCode());
        String errorMessage = value.getMessage();
        if (errorMessage != null) {
            errorMessage = HtmlUtils.htmlEscape(errorMessage);
        }
        jgen.writeStringField("msg", errorMessage);
        if (value.getAdditionalInformation()!=null) {
            for (Map.Entry<String, String> entry : value.getAdditionalInformation().entrySet()) {
                String key = entry.getKey();
                String add = entry.getValue();
                jgen.writeStringField(key, add);
            }
        }
        jgen.writeEndObject();
    }
}
public class MyOAuth2ExceptionJackson2Deserializer extends StdDeserializer<OAuth2Exception> {

    public MyOAuth2ExceptionJackson2Deserializer() {
        super(OAuth2Exception.class);
    }

    @Override
    public OAuth2Exception deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException,
            JsonProcessingException {

        JsonToken t = jp.getCurrentToken();
        if (t == JsonToken.START_OBJECT) {
            t = jp.nextToken();
        }
        Map<String, Object> errorParams = new HashMap<String, Object>();
        for (; t == JsonToken.FIELD_NAME; t = jp.nextToken()) {
            // Must point to field name
            String fieldName = jp.getCurrentName();
            // And then the value...
            t = jp.nextToken();
            // Note: must handle null explicitly here; value deserializers won't
            Object value;
            if (t == JsonToken.VALUE_NULL) {
                value = null;
            }
            // Some servers might send back complex content
            else if (t == JsonToken.START_ARRAY) {
                value = jp.readValueAs(List.class);
            } else if (t == JsonToken.START_OBJECT) {
                value = jp.readValueAs(Map.class);
            } else {
                value = jp.getText();
            }
            errorParams.put(fieldName, value);
        }

        Object errorCode = errorParams.get("error");
        String errorMessage = errorParams.get("error_description") != null ? errorParams.get("error_description").toString() : null;
        if (errorMessage == null) {
            errorMessage = errorCode == null ? "OAuth Error" : errorCode.toString();
        }

        OAuth2Exception ex;
        if ("invalid_client".equals(errorCode)) {
            ex = new InvalidClientException(errorMessage);
        } else if ("unauthorized_client".equals(errorCode)) {
            ex = new UnauthorizedClientException(errorMessage);
        } else if ("invalid_grant".equals(errorCode)) {
            if (errorMessage.toLowerCase().contains("redirect") && errorMessage.toLowerCase().contains("match")) {
                ex = new RedirectMismatchException(errorMessage);
            } else {
                ex = new InvalidGrantException(errorMessage);
            }
        } else if ("invalid_scope".equals(errorCode)) {
            ex = new InvalidScopeException(errorMessage);
        } else if ("invalid_token".equals(errorCode)) {
            ex = new InvalidTokenException(errorMessage);
        } else if ("invalid_request".equals(errorCode)) {
            ex = new InvalidRequestException(errorMessage);
        } else if ("redirect_uri_mismatch".equals(errorCode)) {
            ex = new RedirectMismatchException(errorMessage);
        } else if ("unsupported_grant_type".equals(errorCode)) {
            ex = new UnsupportedGrantTypeException(errorMessage);
        } else if ("unsupported_response_type".equals(errorCode)) {
            ex = new UnsupportedResponseTypeException(errorMessage);
        } else if ("insufficient_scope".equals(errorCode)) {
            ex = new InsufficientScopeException(errorMessage, OAuth2Utils.parseParameterList((String) errorParams
                    .get("scope")));
        } else if ("access_denied".equals(errorCode)) {
            ex = new UserDeniedAuthorizationException(errorMessage);
        } else {
            ex = new OAuth2Exception(errorMessage);
        }

        Set<Map.Entry<String, Object>> entries = errorParams.entrySet();
        for (Map.Entry<String, Object> entry : entries) {
            String key = entry.getKey();
            if (!"error".equals(key) && !"error_description".equals(key)) {
                Object value = entry.getValue();
                ex.addAdditionalInformation(key, value == null ? null : value.toString());
            }
        }
        return ex;
    }
}

3. 自定义异常翻译器

也是模仿框架写一个。

@Component
public class MyWebResponseExceptionTranslator implements WebResponseExceptionTranslator<OAuth2Exception> {

    private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();

    @Override
    public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {

        // Try to extract a SpringSecurityException from the stacktrace
        Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);
        Exception ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);

        if (ase != null) {
            return handleOAuth2Exception((OAuth2Exception) ase);
        }

        ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class,
                causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new MyWebResponseExceptionTranslator.UnauthorizedException(e.getMessage(), e));
        }

        ase = (AccessDeniedException) throwableAnalyzer
                .getFirstThrowableOfType(AccessDeniedException.class, causeChain);
        if (ase instanceof AccessDeniedException) {
            return handleOAuth2Exception(new MyWebResponseExceptionTranslator.ForbiddenException(ase.getMessage(), ase));
        }

        ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(
                HttpRequestMethodNotSupportedException.class, causeChain);
        if (ase instanceof HttpRequestMethodNotSupportedException) {
            return handleOAuth2Exception(new MyWebResponseExceptionTranslator.MethodNotAllowed(ase.getMessage(), ase));
        }

        return handleOAuth2Exception(new MyWebResponseExceptionTranslator.ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));

    }

    private ResponseEntity<OAuth2Exception> handleOAuth2Exception(OAuth2Exception e) throws IOException {

        int status = e.getHttpErrorCode();
        HttpHeaders headers = new HttpHeaders();
        headers.set("Cache-Control", "no-store");
        headers.set("Pragma", "no-cache");
        if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
            headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
        }
        //自定义异常信息
        MyOAuth2Exception myOAuth2Exception=new MyOAuth2Exception(e.getMessage());
        myOAuth2Exception.addAdditionalInformation("code", "401");
        myOAuth2Exception.addAdditionalInformation("result", "操作失败");

        ResponseEntity<OAuth2Exception> response = new ResponseEntity<OAuth2Exception>(myOAuth2Exception, headers,
                HttpStatus.valueOf(status));

        return response;

    }

    public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
        this.throwableAnalyzer = throwableAnalyzer;
    }

    @SuppressWarnings("serial")
    private static class ForbiddenException extends OAuth2Exception {

        public ForbiddenException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "access_denied";
        }

        @Override
        public int getHttpErrorCode() {
            return 403;
        }

    }

    @SuppressWarnings("serial")
    private static class ServerErrorException extends OAuth2Exception {

        public ServerErrorException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "server_error";
        }

        @Override
        public int getHttpErrorCode() {
            return 500;
        }

    }

    @SuppressWarnings("serial")
    private static class UnauthorizedException extends OAuth2Exception {

        public UnauthorizedException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "unauthorized";
        }

        @Override
        public int getHttpErrorCode() {
            return 401;
        }

    }

    @SuppressWarnings("serial")
    private static class MethodNotAllowed extends OAuth2Exception {

        public MethodNotAllowed(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "method_not_allowed";
        }

        @Override
        public int getHttpErrorCode() {
            return 405;
        }

    }
}

4. endpoints 配置自定义异常翻译器

在 MyAuthorizationServerConfiguration 中添加以下配置

    @Autowired
    MyWebResponseExceptionTranslator myWebResponseExceptionTranslator;
    // 端点配置
    @Override
    public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
        // 配置端点允许的请求方式
        endpoints.allowedTokenEndpointRequestMethods(HttpMethod.GET, HttpMethod.POST);
        // 配置认证管理器
        endpoints.authenticationManager(authenticationManager);
        // 自定义异常翻译器,用于处理OAuth2Exception
        endpoints.exceptionTranslator(myWebResponseExceptionTranslator);
    }

5. 测试

总结

以上简单的梳理了授权服务器异常机制及自定义异常信息,只是指出了实现思路和简单案例,如果有其他问题,可以自己分析源码,找出原因。