Feb 14, 2021

12. Integrating SpringBoot with Shiro (Part 4)

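One pain point when using Shiro is its weak support for RESTful APIs: its filter chain matches requests by URL only, so GET /user and DELETE /user cannot be given different permissions. A RESTful API generally involves two parts:

  1. The request method: GET, POST, PUT, DELETE, PATCH
  2. The URI, e.g. http://127.0.0.1:8080/user?id=1&username=jack or http://127.0.0.1:8080/user/id/1/username/jack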

Analysis

First, let's review Shiro's filter chain. A typical configuration looks like this:

Map<String,String> map = new LinkedHashMap<>();
map.put("/index","anon");
map.put("/login","anon");

map.put("/users","perms[user:list]");

map.put("/**","authc");

The equivalent XML-style definition may be more intuitive:

/index = anon
/login = anon
/users = perms[user:list]
/** = authc
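Shiro walks these definitions in insertion order and applies the first pattern that matches, which is why a LinkedHashMap is used and why /** = authc comes last. The sketch below illustrates that rule in plain Java; the class and method names are hypothetical, and the matcher is a deliberately simplified stand-in for Shiro's AntPathMatcher that understands only exact paths and a trailing /**.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Simplified illustration of "first matching pattern wins" (not Shiro's real matcher).
class FirstMatchDemo {

    // Supports exact paths and a trailing "/**" wildcard only.
    static boolean simpleMatch(String pattern, String uri) {
        if ("/**".equals(pattern)) {
            return true; // catch-all
        }
        if (pattern.endsWith("/**")) {
            String prefix = pattern.substring(0, pattern.length() - 3);
            return uri.equals(prefix) || uri.startsWith(prefix + "/");
        }
        return pattern.equals(uri);
    }

    // Walks the definitions in insertion order and returns the first matching filter definition.
    static String resolve(Map<String, String> chains, String uri) {
        for (Map.Entry<String, String> entry : chains.entrySet()) {
            if (simpleMatch(entry.getKey(), uri)) {
                return entry.getValue();
            }
        }
        return null; // no chain applies
    }

    // The same definitions as above, in the same order.
    static Map<String, String> exampleChains() {
        Map<String, String> map = new LinkedHashMap<>();
        map.put("/index", "anon");
        map.put("/login", "anon");
        map.put("/users", "perms[user:list]");
        map.put("/**", "authc");
        return map;
    }
}
```

With these definitions, /users resolves to perms[user:list] and /orders/1 falls through to authc; if /** were inserted first, every request, /users included, would resolve to authc.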

Here the /users request maps to the perms filter, whose class is org.apache.shiro.web.filter.authz.PermissionsAuthorizationFilter. Its onAccessDenied method is called when access is denied; the source in its parent class AuthorizationFilter is:

protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws IOException {

    Subject subject = this.getSubject(request, response);
    // Not logged in: redirect to the configured loginUrl
    if (subject.getPrincipal() == null) {
        this.saveRequestAndRedirectToLogin(request, response);
    } else {
        // Logged in but lacking permission: redirect to unauthorizedUrl.
        // If no unauthorizedUrl is configured, respond with status 401.
        String unauthorizedUrl = this.getUnauthorizedUrl();
        if (StringUtils.hasText(unauthorizedUrl)) {
            WebUtils.issueRedirect(request, response, unauthorizedUrl);
        } else {
            WebUtils.toHttp(response).sendError(401);
        }
    }

    return false;
}

Here we can check whether the current request is an AJAX request; if it is, we return JSON instead of redirecting to the loginUrl or unauthorizedUrl page.

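The other method is pathsMatch. It matches the current request URL against each path configured for the perms filter: on a match, the permission check is performed; otherwise the next configured path is tried. Its source is: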

protected boolean pathsMatch(String path, ServletRequest request) {
    String requestURI = this.getPathWithinApplication(request);
    if (requestURI != null && !"/".equals(requestURI) && requestURI.endsWith("/")) {
        requestURI = requestURI.substring(0, requestURI.length() - 1);
    }

    if (path != null && !"/".equals(path) && path.endsWith("/")) {
        path = path.substring(0, path.length() - 1);
    }

    log.trace("Attempting to match pattern '{}' with current requestURI '{}'...", path, Encode.forHtml(requestURI));
    return this.pathsMatch(path, requestURI);
}
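Both pathsMatch and the getChain method of PathMatchingFilterChainResolver first strip a single trailing slash, except from the root path, so /users/ and /users are treated the same. Extracted into a stand-alone helper (the name stripTrailingSlash is hypothetical), the rule looks like this:

```java
// Normalization applied to both the configured pattern and the request URI before matching.
class PathNormalizer {
    static String stripTrailingSlash(String path) {
        // Leave null and the root path "/" untouched; otherwise drop one trailing "/".
        if (path != null && !"/".equals(path) && path.endsWith("/")) {
            return path.substring(0, path.length() - 1);
        }
        return path;
    }
}
```

Only one slash is removed, so /users// becomes /users/ rather than /users.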

Approach

With these two methods understood, let's see how to use them to implement the feature.

We can start from the filter chain configuration. The original entry is:

/users = perms[user:list]

We can change it to entries such as /user==GET and /user==POST, where == is the separator and the part after it is the HTTP method.
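A minimal sketch of how such a key splits into its URL part and its method part. RestPattern and its members are hypothetical names, not Shiro API; a key without == imposes no method restriction, and the method comparison is case-insensitive, as in the filters later in this article.

```java
import java.util.Locale;

// Hypothetical value type for a filter-chain key of the form "<url>==<HTTP method>".
final class RestPattern {
    final String path;    // URL part, e.g. "/user"
    final String method;  // upper-case method, or null if the key has no "==METHOD" suffix

    private RestPattern(String path, String method) {
        this.path = path;
        this.method = method;
    }

    static RestPattern parse(String key) {
        int sep = key.indexOf("==");
        if (sep < 0) {
            return new RestPattern(key, null); // plain URL pattern such as "/users"
        }
        String method = key.substring(sep + 2).trim().toUpperCase(Locale.ROOT);
        return new RestPattern(key.substring(0, sep), method.isEmpty() ? null : method);
    }

    // A pattern without a method accepts any method; otherwise the methods must agree.
    boolean acceptsMethod(String requestMethod) {
        return method == null || method.equalsIgnoreCase(requestMethod);
    }
}
```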

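This approach also requires attention to one more method:

the getChain method of org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver, which determines the filter chain to apply to the current request URL. Its source is: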

public FilterChain getChain(ServletRequest request, ServletResponse response, FilterChain originalChain) {
    // 1. If no filter chains are configured at all, return null
    FilterChainManager filterChainManager = this.getFilterChainManager();
    if (!filterChainManager.hasChains()) {
        return null;
    } else {
        // 2. Get the URI of the current request
        String requestURI = this.getPathWithinApplication(request);
        if (requestURI != null && !"/".equals(requestURI) && requestURI.endsWith("/")) {
            requestURI = requestURI.substring(0, requestURI.length() - 1);
        }

        Iterator var6 = filterChainManager.getChainNames().iterator();

        String pathPattern;
        // 3. Iterate over all filter chains
        do {
            if (!var6.hasNext()) {
                return null;
            }

            pathPattern = (String)var6.next();
            if (pathPattern != null && !"/".equals(pathPattern) && pathPattern.endsWith("/")) {
                pathPattern = pathPattern.substring(0, pathPattern.length() - 1);
            }
            // 4. Check whether the current request URL matches the URL of this filter chain
        } while(!this.pathMatches(pathPattern, requestURI));

        if (log.isTraceEnabled()) {
            log.trace("Matched path pattern [" + pathPattern + "] for requestURI [" + Encode.forHtml(requestURI) + "].  Utilizing corresponding filter chain...");
        }

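        // 5. On a match, return the filter chain registered under this pattern (perms[user:list] and perms[user:delete] both use the perms filter).
        // The concrete check of perms[user:list] versus perms[user:delete] happens in PermissionsAuthorizationFilter's pathsMatch method discussed above.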
        return filterChainManager.proxy(originalChain, pathPattern);
    }
}

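Note the check in step 4: we have changed the filter chain names (pathPattern here) to the form /xxx==GET, while the request URL contains only /xxx, so pathMatches can never succeed. In step 4 we therefore need to compare only the URL part in front of ==.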
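The sketch below shows the mismatch and one way around it; the names are hypothetical and exact string comparison stands in for Shiro's Ant-style pathMatches. It also checks the method part: if only the URL part were compared, a POST /user request would be handed whichever of /user==GET or /user==POST happened to be registered first.

```java
import java.util.List;

// Illustrates why "/user==GET" never matches the request path "/user" and how splitting fixes it.
class ChainKeyMatcher {

    // What an unmodified resolver effectively compares: the whole chain name.
    static boolean naiveMatch(String chainName, String requestUri) {
        return chainName.equals(requestUri);
    }

    // Compare only the URL part; if the chain name carries a method, require it to match too.
    static boolean restMatch(String chainName, String requestUri, String requestMethod) {
        String[] parts = chainName.split("==");
        if (!parts[0].equals(requestUri)) {
            return false;
        }
        return parts.length < 2 || parts[1].equalsIgnoreCase(requestMethod);
    }

    // Ordered first-match lookup, like the loop in getChain.
    static String firstMatch(List<String> chainNames, String requestUri, String requestMethod) {
        for (String name : chainNames) {
            if (restMatch(name, requestUri, requestMethod)) {
                return name;
            }
        }
        return null;
    }
}
```

In real Shiro the URL comparison is Ant-style, so a DELETE /user with no matching entry would fall through to a later /** chain instead of returning null.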

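The whole process is:

  1. In the filter chain, configure the required HTTP method for RESTful requests, e.g. /user==DELETE.
  2. Modify the getChain method of PathMatchingFilterChainResolver so that, when the current request URL is matched against the filter chains, only the URL part of each chain name is compared.
  3. Modify the filter's pathsMatch method to check that both the request URL and the request method match what is configured in the filter chain.
  4. Modify the filter's onAccessDenied method so that, when access is denied, normal requests get HTML and AJAX requests get JSON.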

Implementation

1. Add the HTTP method to the filter chain

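In practice the filter chain definitions are usually loaded from the database, so this logic normally lives in a service. We can create a dedicated ShiroService for building the filter chain: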

// Registered as a Spring bean so that ShiroConfig can inject it
@Service
public class ShiroService {

    @Autowired
    private IMenuService menuService;

    @Autowired
    private IOperatorService operatorService;

    /**
     * Load all menu permissions and API permissions from the database.
     */
    public Map<String, String> getUrlPermsMap() {
        Map<String, String> filterChainDefinitionMap = new LinkedHashMap<>();

        // Default system filters: static resources and public pages
        filterChainDefinitionMap.put("/favicon.ico", "anon");
        filterChainDefinitionMap.put("/css/**", "anon");
        filterChainDefinitionMap.put("/fonts/**", "anon");
        filterChainDefinitionMap.put("/images/**", "anon");
        filterChainDefinitionMap.put("/js/**", "anon");
        filterChainDefinitionMap.put("/lib/**", "anon");
        filterChainDefinitionMap.put("/active/**", "anon");
        filterChainDefinitionMap.put("/login", "anon");
        filterChainDefinitionMap.put("/register", "anon");
        filterChainDefinitionMap.put("/403", "anon");
        filterChainDefinitionMap.put("/404", "anon");
        filterChainDefinitionMap.put("/500", "anon");
        filterChainDefinitionMap.put("/error", "anon");

        // Get all leaf menus
        List<Menu> menuList = menuService.getLeafNodeMenu();
        for (Menu menu : menuList) {
            String url = menu.getUrl();
            if (url != null) {
                String perms = "perms[" + menu.getPerms() + "]";
                // Add the leaf menu to the filter chain
                filterChainDefinitionMap.put(url, perms);
            }
        }

        // Get all operations
        List<Operator> operatorList = operatorService.list(); // list() is provided by MyBatis-Plus
        for (Operator operator : operatorList) {
            String url = operator.getUrl();
            if (url != null) {
                if (operator.getHttpMethod() != null
                        && !"".equals(operator.getHttpMethod())) {
                    url += ("==" + operator.getHttpMethod());
                }
                String perms = "perms[" + operator.getPerms() + "]";
                filterChainDefinitionMap.put(url, perms);
            }
        }

        filterChainDefinitionMap.put("/**", "authc");

        return filterChainDefinitionMap;
    }
}

For example, in /xxx==GET = perms[user:list], getUrl, getHttpMethod and getPerms return /xxx, GET and user:list respectively.
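To make the key construction concrete, here is a stand-alone version of the operator loop. OperatorRow is a hypothetical stand-in for the Operator entity, and the sample rows mentioned after the code are invented for illustration.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the Operator entity: URL, optional HTTP method, permission string.
final class OperatorRow {
    final String url;
    final String httpMethod;
    final String perms;

    OperatorRow(String url, String httpMethod, String perms) {
        this.url = url;
        this.httpMethod = httpMethod;
        this.perms = perms;
    }
}

class ChainMapBuilder {
    // Builds "<url>[==<method>]" -> "perms[<perms>]" entries in row order, then appends the catch-all.
    static Map<String, String> build(List<OperatorRow> rows) {
        Map<String, String> map = new LinkedHashMap<>();
        for (OperatorRow op : rows) {
            if (op.url == null) {
                continue; // rows without a URL are skipped, as in getUrlPermsMap
            }
            String key = op.url;
            if (op.httpMethod != null && !op.httpMethod.isEmpty()) {
                key += "==" + op.httpMethod;
            }
            map.put(key, "perms[" + op.perms + "]");
        }
        map.put("/**", "authc"); // must stay last so it only catches unmatched requests
        return map;
    }
}
```

Rows ("/user", "GET", "user:list"), ("/user", "DELETE", "user:delete") and ("/report", "", "report:view") yield the keys /user==GET, /user==DELETE, /report and /** in that order.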

2. Override the getChain method of PathMatchingFilterChainResolver

Since Shiro provides no extension point for this and we cannot modify its source directly, we create a class that extends PathMatchingFilterChainResolver, override getChain, and use it in place of PathMatchingFilterChainResolver.

public class RestPathMatchingFilterChainResolver extends PathMatchingFilterChainResolver {
    private static final Logger log = LoggerFactory.getLogger(RestPathMatchingFilterChainResolver.class);
    @Override
    public FilterChain getChain(ServletRequest request, ServletResponse response, FilterChain originalChain) {
        FilterChainManager filterChainManager = this.getFilterChainManager();
        if (!filterChainManager.hasChains()) {
            return null;
        } else {
            String requestURI = this.getPathWithinApplication(request);
            if (requestURI != null && !"/".equals(requestURI) && requestURI.endsWith("/")) {
                requestURI = requestURI.substring(0, requestURI.length() - 1);
            }

            //the 'chain names' in this implementation are actually path patterns defined by the user.  We just use them
            //as the chain name for the FilterChainManager's requirements
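            // HTTP method of the current request, compared with chain names of the form "<url>==<METHOD>"
            String httpMethod = WebUtils.toHttp(request).getMethod();
            for (String pathPattern : filterChainManager.getChainNames()) {
                String[] pathPatternArray = pathPattern.split("==");
                // Match only the URL part of the chain name against the request URL
                boolean urlMatches = pathMatches(pathPatternArray[0], requestURI);
                // If the chain name also names a method, require it to match; otherwise a POST /user
                // could be handed the chain configured for /user==GET and skip the /** fallback
                boolean methodMatches = pathPatternArray.length < 2
                        || pathPatternArray[1].equalsIgnoreCase(httpMethod);
                if (urlMatches && methodMatches) {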
                    if (log.isTraceEnabled()) {
                        log.trace("Matched path pattern [" + pathPattern + "] for requestURI [" + requestURI + "].  " +
                                "Utilizing corresponding filter chain...");
                    }
                    return filterChainManager.proxy(originalChain, pathPattern);
                }
            }
            return null;
        }
    }
}

Next, replace PathMatchingFilterChainResolver. It is instantiated in the createInstance method of ShiroFilterFactoryBean.

The source of createInstance in ShiroFilterFactoryBean is:

protected AbstractShiroFilter createInstance() throws Exception {
    log.debug("Creating Shiro Filter instance.");
    SecurityManager securityManager = this.getSecurityManager();
    String msg;
    if (securityManager == null) {
        msg = "SecurityManager property must be set.";
        throw new BeanInitializationException(msg);
    } else if (!(securityManager instanceof WebSecurityManager)) {
        msg = "The security manager does not implement the WebSecurityManager interface.";
        throw new BeanInitializationException(msg);
    } else {
        FilterChainManager manager = this.createFilterChainManager();
        PathMatchingFilterChainResolver chainResolver = new PathMatchingFilterChainResolver();
        chainResolver.setFilterChainManager(manager);
        return new ShiroFilterFactoryBean.SpringShiroFilter((WebSecurityManager)securityManager, chainResolver);
    }
}

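Using the same trick, extend ShiroFilterFactoryBean and override createInstance, changing new PathMatchingFilterChainResolver() to new RestPathMatchingFilterChainResolver(). Because SpringShiroFilter is a private nested class of ShiroFilterFactoryBean, the subclass declares its own equivalent: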

public class RestShiroFilterFactoryBean extends ShiroFilterFactoryBean {
    private static final Logger log = LoggerFactory.getLogger(RestShiroFilterFactoryBean.class);

    @Override
    protected AbstractShiroFilter createInstance() {

        log.debug("Creating Shiro Filter instance.");
        SecurityManager securityManager = getSecurityManager();
        if (securityManager == null) {
            String msg = "SecurityManager property must be set.";
            throw new BeanInitializationException(msg);
        }

        if (!(securityManager instanceof WebSecurityManager)) {
            String msg = "The security manager does not implement the WebSecurityManager interface.";
            throw new BeanInitializationException(msg);
        }

        FilterChainManager manager = createFilterChainManager();

        //Expose the constructed FilterChainManager by first wrapping it in a
        // FilterChainResolver implementation. The AbstractShiroFilter implementations
        // do not know about FilterChainManagers - only resolvers:
        PathMatchingFilterChainResolver chainResolver = new RestPathMatchingFilterChainResolver();
        chainResolver.setFilterChainManager(manager);

        //Now create a concrete ShiroFilter instance and apply the acquired SecurityManager and built
        //FilterChainResolver.  It doesn't matter that the instance is an anonymous inner class
        //here - we're just using it because it is a concrete AbstractShiroFilter instance that accepts
        //injection of the SecurityManager and FilterChainResolver:
        return new SpringShiroFilter((WebSecurityManager) securityManager, chainResolver);
    }

    private static final class SpringShiroFilter extends AbstractShiroFilter {
        protected SpringShiroFilter(WebSecurityManager webSecurityManager, FilterChainResolver resolver) {
            super();
            if (webSecurityManager == null) {
                throw new IllegalArgumentException("WebSecurityManager property cannot be null.");
            }
            setSecurityManager(webSecurityManager);
            if (resolver != null) {
                setFilterChainResolver(resolver);
            }
        }
    }
}

Finally, remember to switch the ShiroFilterFactoryBean bean to RestShiroFilterFactoryBean:

@Bean
public ShiroFilterFactoryBean getFilterFactoryBean(DefaultWebSecurityManager securityManager) {
    ShiroFilterFactoryBean shiroFilterFactoryBean = new RestShiroFilterFactoryBean();
    // Other settings omitted
    return shiroFilterFactoryBean;
}

3. Override the filter's pathsMatch and onAccessDenied methods

Likewise, create a class that extends the original PermissionsAuthorizationFilter and override pathsMatch, along with onAccessDenied:

public class RestAuthorizationFilter extends PermissionsAuthorizationFilter {
    private static final Logger log = LoggerFactory
            .getLogger(RestAuthorizationFilter.class);
    @Override
    protected boolean pathsMatch(String path, ServletRequest request) {
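        String requestURI = this.getPathWithinApplication(request);
        // Strip a trailing "/" as the parent's pathsMatch does, so "/user/" is still checked against "/user==DELETE"
        if (requestURI != null && !"/".equals(requestURI) && requestURI.endsWith("/")) {
            requestURI = requestURI.substring(0, requestURI.length() - 1);
        }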
        String[] strings = path.split("==");
        if (strings.length <= 1) {
            // A plain URL: handle it normally
            return this.pathsMatch(strings[0], requestURI);
        } else {
            // Get the HTTP method of the current request
            String httpMethod = WebUtils.toHttp(request).getMethod().toUpperCase();
            // Check that the request's HTTP method matches the one configured in the filter chain
            return httpMethod.equals(strings[1].toUpperCase()) && this.pathsMatch(strings[0], requestURI);
        }
    }

    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws IOException {
        Subject subject = getSubject(request, response);
        // Not logged in
        if (subject.getPrincipal() == null) {
            // AJAX requests get JSON
            if (WebUtil.isAjaxRequest(WebUtils.toHttp(request))) {
                if (log.isDebugEnabled()) {
                    log.debug("User [{}] requested RESTful URL {} and was blocked: not logged in.", subject.getPrincipal(), this.getPathWithinApplication(request));
                }
                WebUtil.writeJson(ResultVO.fail(ResultCode.USER_AUTHENTICATION_ERROR),response);
            } else {
                // Other requests are redirected to the login page
                saveRequestAndRedirectToLogin(request, response);
            }
        } else {
            // Logged in but lacking permission
            // AJAX requests get JSON
            if (WebUtil.isAjaxRequest(WebUtils.toHttp(request))) {
                if (log.isDebugEnabled()) {
                    log.debug("User [{}] requested RESTful URL {} and was blocked: no permission.", subject.getPrincipal(), this.getPathWithinApplication(request));
                }

                WebUtil.writeJson(ResultVO.fail(ResultCode.USER_AUTHORIZATION_ERROR),response);
            } else {
                // Normal requests are redirected to the configured unauthorizedUrl page.
                // If no unauthorizedUrl is set, respond with status 401.
                String unauthorizedUrl = getUnauthorizedUrl();
                if (StringUtils.hasText(unauthorizedUrl)) {
                    WebUtils.issueRedirect(request, response, unauthorizedUrl);
                } else {
                    WebUtils.toHttp(response).sendError(HttpServletResponse.SC_UNAUTHORIZED);
                }
            }

        }
        return false;
    }
}
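The branches of onAccessDenied above can be summarized as a small decision function. This is a sketch with hypothetical names (DenyOutcome, DenyPolicy); the AJAX test follows the X-Requested-With: XMLHttpRequest convention used by WebUtil.isAjaxRequest, and the blank-URL test only approximates Spring's StringUtils.hasText.

```java
// Hypothetical summary of how RestAuthorizationFilter.onAccessDenied answers a denied request.
enum DenyOutcome {
    JSON_AUTHENTICATION_ERROR,    // not logged in, AJAX
    REDIRECT_TO_LOGIN,            // not logged in, normal request
    JSON_AUTHORIZATION_ERROR,     // logged in without permission, AJAX
    REDIRECT_TO_UNAUTHORIZED_URL, // logged in without permission, unauthorizedUrl configured
    SEND_401                      // logged in without permission, no unauthorizedUrl
}

class DenyPolicy {
    // Same convention as WebUtil.isAjaxRequest: X-Requested-With equals XMLHttpRequest.
    static boolean isAjax(String requestedWithHeader) {
        return "XMLHttpRequest".equalsIgnoreCase(requestedWithHeader);
    }

    static DenyOutcome decide(boolean loggedIn, boolean ajax, String unauthorizedUrl) {
        if (!loggedIn) {
            return ajax ? DenyOutcome.JSON_AUTHENTICATION_ERROR : DenyOutcome.REDIRECT_TO_LOGIN;
        }
        if (ajax) {
            return DenyOutcome.JSON_AUTHORIZATION_ERROR;
        }
        // Rough equivalent of StringUtils.hasText(unauthorizedUrl)
        boolean hasUrl = unauthorizedUrl != null && !unauthorizedUrl.trim().isEmpty();
        return hasUrl ? DenyOutcome.REDIRECT_TO_UNAUTHORIZED_URL : DenyOutcome.SEND_401;
    }
}
```

Laying the cases out this way makes it easy to see that AJAX clients never receive a redirect, which they typically cannot follow meaningfully.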

Note that WebUtil is a custom helper class; its code is:

public class WebUtil {
    /**
     * Whether the request is an AJAX request (header X-Requested-With: XMLHttpRequest)
     */
    public static boolean isAjaxRequest(HttpServletRequest request) {
        String requestedWith = request.getHeader("x-requested-with");
        return "XMLHttpRequest".equalsIgnoreCase(requestedWith);
    }

    /**
     * Write the object to the response as JSON
     */
    public static void writeJson(Object object, ServletResponse response) {
        PrintWriter out = null;
        try {
            response.setCharacterEncoding("UTF-8");
            response.setContentType("application/json; charset=utf-8");
            out = response.getWriter();
            ObjectMapper objectMapper = new ObjectMapper();
            out.write(objectMapper.writeValueAsString(object));
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    public static void redirectUrl(String redirectUrl) {
        HttpServletResponse response = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getResponse();
        try {
            response.sendRedirect(redirectUrl);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Get the HTTP method of the current request
     * @return the HTTP method, e.g. GET
     */
    public static String getRequestHTTPMethod() {
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        return request.getMethod();
    }
}

After overriding pathsMatch and onAccessDenied, register this class in place of the original perms filter:

@Bean
public ShiroFilterFactoryBean getFilterFactoryBean(DefaultWebSecurityManager securityManager) {
    ShiroFilterFactoryBean shiroFilterFactoryBean = new RestShiroFilterFactoryBean();
    Map<String, Filter> filters = shiroFilterFactoryBean.getFilters();
    filters.put("perms", new RestAuthorizationFilter());
    // Other settings omitted
    return shiroFilterFactoryBean;
}

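Only the perms filter is changed here, but the same idea applies to other filters: override their pathsMatch and onAccessDenied methods and register them in place of the originals. For example, we can also make FormAuthenticationFilter, the class behind authc, RESTful.

Create RestFormAuthenticationFilter: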

/**
 * Modified authc filter with support for AJAX requests.
 */
public class RestFormAuthenticationFilter extends FormAuthenticationFilter {

    private static final Logger log = LoggerFactory
            .getLogger(RestFormAuthenticationFilter.class);

    @Override
    protected boolean pathsMatch(String path, ServletRequest request) {
        boolean flag;
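        String requestURI = this.getPathWithinApplication(request);
        // Strip a trailing "/" as the parent's pathsMatch does, so "/user/" still matches "/user==GET"
        if (requestURI != null && !"/".equals(requestURI) && requestURI.endsWith("/")) {
            requestURI = requestURI.substring(0, requestURI.length() - 1);
        }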

        String[] strings = path.split("==");

        if (strings.length <= 1) {
            // A plain URL: handle it normally
            flag = this.pathsMatch(strings[0], requestURI);
        } else {
            // Get the HTTP method of the current request
            String httpMethod = WebUtils.toHttp(request).getMethod().toUpperCase();
            // Check that the request's URL and HTTP method match those configured in the filter chain
            flag = httpMethod.equals(strings[1].toUpperCase()) && this.pathsMatch(strings[0], requestURI);
        }

        if (flag) {
            log.debug("URL : [{}] matching authc filter : [{}]", requestURI, path);
        }
        return flag;
    }

    /**
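     * Called when an unauthenticated request is intercepted:
     * login requests (page view or form submission) are handled as in FormAuthenticationFilter;
     * AJAX requests get a JSON response;
     * other requests are redirected to the configured loginUrl.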
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request,
                                     ServletResponse response) throws Exception {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        if (isLoginRequest(request, response)) {
            if (isLoginSubmission(request, response)) {
                if (log.isTraceEnabled()) {
                    log.trace("Login submission detected.  Attempting to execute login.");
                }
                return executeLogin(request, response);
            } else {
                if (log.isTraceEnabled()) {
                    log.trace("Login page view.");
                }
                //allow them to see the login page ;)
                return true;
            }
        } else {
            if (log.isTraceEnabled()) {
                log.trace("Attempting to access a path which requires authentication.  Forwarding to the " +
                        "Authentication url [" + getLoginUrl() + "]");
            }

            if (WebUtil.isAjaxRequest(WebUtils.toHttp(request))) {
                if (log.isDebugEnabled()) {
                    log.debug("sessionId: [{}], ip: [{}] requested RESTful URL {} and was blocked: not logged in.", httpServletRequest.getRequestedSessionId(), IPUtils.getIpAddr(),
                            this.getPathWithinApplication(request));
                }

                WebUtil.writeJson(ResultVO.fail(ResultCode.USER_NOT_LOGGED_IN), response);
            } else {
                saveRequestAndRedirectToLogin(request, response);
            }
            return false;
        }
    }
}
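RestAuthorizationFilter and RestFormAuthenticationFilter now contain the same ==-splitting logic. One way to avoid copying it into every further filter is a shared helper. The sketch below is hypothetical (RestPathSupport is not part of Shiro) and takes the URL matcher as a parameter, so inside a filter one could pass this::pathsMatch, which resolves to PathMatchingFilter's protected pathsMatch(String pattern, String path).

```java
import java.util.Locale;
import java.util.function.BiPredicate;

// Hypothetical shared helper for "<url>==<METHOD>" path definitions.
final class RestPathSupport {

    private RestPathSupport() {
    }

    /**
     * @param configuredPath path from the filter chain, e.g. "/user" or "/user==DELETE"
     * @param requestUri     path of the current request within the application
     * @param requestMethod  HTTP method of the current request
     * @param urlMatcher     URL comparison, e.g. a filter's own pathsMatch(pattern, path)
     */
    static boolean matches(String configuredPath, String requestUri, String requestMethod,
                           BiPredicate<String, String> urlMatcher) {
        String[] parts = configuredPath.split("==");
        if (parts.length <= 1) {
            return urlMatcher.test(parts[0], requestUri); // plain URL: no method restriction
        }
        boolean sameMethod = parts[1].toUpperCase(Locale.ROOT)
                .equals(requestMethod.toUpperCase(Locale.ROOT));
        return sameMethod && urlMatcher.test(parts[0], requestUri);
    }
}
```

Each filter's pathsMatch override would then reduce to a single call, and a fix to the parsing rule would only need to be made once.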

This uses a custom utility class, IPUtils. Here is its code; feel free to reuse it:

public class IPUtils {

    /**
     * Get the client IP of the current request (web services)
     *
     * @return the IP address
     */
    public static String getIpAddr() {
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
            if (ip.equals("127.0.0.1")) {
                // Use the IP configured on the local network interface
                InetAddress inet = null;
                try {
                    inet = InetAddress.getLocalHost();
                    ip = inet.getHostAddress();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
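        // With multiple proxies, the first IP is the real client IP; IPs are separated by ','
        // (checked regardless of length, since a two-IP list such as "1.2.3.4,5.6.7.8" is only 15 characters)
        if (ip != null && ip.indexOf(",") > 0) {
            ip = ip.substring(0, ip.indexOf(",")).trim();
        }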
        return ip;
    }


    /**
     * Get the IP of the current machine
     */
    public static String getHostIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            e.printStackTrace();
        }
        return "127.0.0.1";
    }


    /**
     * Get the name of the current machine
     */
    public static String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException e) {
            e.printStackTrace();
        }
        return "unknown";
    }
}
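X-Forwarded-For can carry a comma-separated list (client, proxy1, proxy2, ...), and by convention the first entry is the original client. A stand-alone sketch of extracting it, with hypothetical names ForwardedFor and firstClientIp, that splits on every comma regardless of the header's total length and skips blank or unknown entries. Because clients can set this header themselves, it is only trustworthy when the application sits behind proxies you control.

```java
// Hypothetical helper: first usable address from an X-Forwarded-For style header value.
class ForwardedFor {
    static String firstClientIp(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        for (String part : headerValue.split(",")) {
            String ip = part.trim();
            // Skip empty entries and the "unknown" placeholder some proxies insert.
            if (!ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
                return ip;
            }
        }
        return null;
    }
}
```

For the header value "1.2.3.4,5.6.7.8", which is exactly 15 characters long, this returns 1.2.3.4.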

4. The final RestShiroFilterFactoryBean configuration

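In ShiroConfig, the method that creates the ShiroFilterFactoryBean has to be rewritten completely (shiroService is the ShiroService shown earlier, injected into ShiroConfig):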

......
@Bean
public ShiroFilterFactoryBean getFilterFactoryBean(DefaultWebSecurityManager securityManager){
    RestShiroFilterFactoryBean shiroFilterFactoryBean = new RestShiroFilterFactoryBean();
    shiroFilterFactoryBean.setSecurityManager(securityManager);
    shiroFilterFactoryBean.setLoginUrl("/login");
    shiroFilterFactoryBean.setUnauthorizedUrl("/403");
    Map<String, Filter> filters = shiroFilterFactoryBean.getFilters();
    filters.put("authc", new RestFormAuthenticationFilter());
    filters.put("perms", new RestAuthorizationFilter());

    Map<String, String> urlPermsMap = shiroService.getUrlPermsMap();

    // Debug output: print the loaded filter chain definitions
    System.out.println("======================= test =========================");
    for (Map.Entry<String, String> entry : urlPermsMap.entrySet()) {
        System.out.println("key = " + entry.getKey() + ", value = " + entry.getValue());
    }
    System.out.println("======================= test =========================");

    shiroFilterFactoryBean.setFilterChainDefinitionMap(urlPermsMap);
    return shiroFilterFactoryBean;
}
......

Original article: http://www.yanhongzhi.com/post/springboot-12.html
