限流 (Rate Limiting)

RateLimit.java
package com.oneboi.springboot3.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimit {

    /**
     * Error code reported when a call is rejected by the limiter.
     *
     * @return the rejection code; {@code "TOO_MANY_REQUESTS"} unless overridden
     */
    String code() default "TOO_MANY_REQUESTS";

    /**
     * Human-readable message reported when a call is rejected by the limiter.
     *
     * @return the rejection message ("too many requests, please retry later" by default)
     */
    String message() default "请求过于频繁,请稍后重试";

    /**
     * Maximum number of requests permitted per second for the annotated method.
     *
     * @return the per-second request budget; {@code 100} unless overridden
     */
    int limit() default 100;
}
ApiRateLimiter.java
package com.oneboi.springboot3.aspect;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

/**
 * Fixed-window rate limiter backed by a {@link Semaphore}.
 *
 * <p>Up to {@code maxPermits} calls to {@link #tryProcessRequest()} succeed per window. Once per
 * second a background task tops the semaphore back up to {@code maxPermits}. Because the window is
 * fixed (not sliding), a burst that straddles a refill can see up to {@code 2 * maxPermits}
 * successful calls within one wall-clock second.
 *
 * <p>The refill task runs on a single daemon thread, so an instance that is never shut down does
 * not keep the JVM alive. Owners should still call {@link #shutdown()} when they are done.
 */
public class ApiRateLimiter {

    private static final System.Logger LOGGER = System.getLogger(ApiRateLimiter.class.getName());

    /** Requests per second used by the no-arg constructor. */
    private static final int DEFAULT_MAX_PERMITS = 100;

    private final Semaphore semaphore;
    private final int maxPermits;
    private final ScheduledExecutorService scheduler;

    /**
     * Creates a limiter allowing {@code maxPermits} requests per one-second window.
     *
     * @param maxPermits requests allowed per window; {@code 0} rejects every request
     * @throws IllegalArgumentException if {@code maxPermits} is negative
     */
    public ApiRateLimiter(int maxPermits) {
        if (maxPermits < 0) {
            throw new IllegalArgumentException("maxPermits must be >= 0, got " + maxPermits);
        }
        this.maxPermits = maxPermits;
        this.semaphore = new Semaphore(maxPermits);
        this.scheduler = Executors.newSingleThreadScheduledExecutor(task -> {
            Thread thread = new Thread(task, "api-rate-limiter-refill");
            // Daemon so a limiter that is never shut down cannot block JVM exit.
            thread.setDaemon(true);
            return thread;
        });

        // Restore all consumed permits once per second.
        scheduler.scheduleAtFixedRate(this::refill, 1, 1, TimeUnit.SECONDS);
    }

    /** Creates a limiter allowing 100 requests per second. */
    public ApiRateLimiter() {
        this(DEFAULT_MAX_PERMITS);
    }

    /**
     * Releases exactly the permits consumed since the semaphore was last full.
     *
     * <p>Only this task releases permits. Acquisitions racing with the read can only lower the
     * available count, so the semaphore never rises above {@code maxPermits}.
     */
    private void refill() {
        try {
            int used = maxPermits - semaphore.availablePermits();
            if (used > 0) {
                semaphore.release(used);
            }
        } catch (RuntimeException e) {
            // Must not propagate: an exception would cancel all future runs of scheduleAtFixedRate.
            LOGGER.log(System.Logger.Level.WARNING, "Failed to refill rate-limiter permits", e);
        }
    }

    /**
     * Attempts to consume one permit without blocking.
     *
     * @return {@code true} if the request may proceed, {@code false} if the window is exhausted
     */
    public boolean tryProcessRequest() {
        return semaphore.tryAcquire();
    }

    /** @return the number of permits currently left in this window */
    public int getAvailablePermits() {
        return semaphore.availablePermits();
    }

    /** @return the configured per-window budget */
    public int getMaxPermits() {
        return maxPermits;
    }

    /** Stops the background refill task; permits are no longer restored afterwards. */
    public void shutdown() {
        scheduler.shutdown();
    }
}
RateLimitAspect.java
package com.oneboi.springboot3.aspect;


import com.oneboi.springboot3.annotation.RateLimit;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;

/**
 * Enforces {@link RateLimit} on annotated methods.
 *
 * <p>One {@link ApiRateLimiter} is created lazily per annotated method and reused for all later
 * calls. When no permit is left, the target method is not invoked and a rejection response is
 * returned instead.
 */
@Component
@Aspect
public class RateLimitAspect {

    // One limiter per annotated method, keyed by the full method signature; never evicted.
    private final Map<String, ApiRateLimiter> rateLimiters = new ConcurrentHashMap<>();

    /**
     * Applies the {@link RateLimit} configured on the intercepted method.
     *
     * <p>If a permit is available, the target method runs normally. Otherwise it is skipped and a
     * rejection body is produced (keys: status, message, code, timestamp, path, limit, available,
     * statusCode). How that body is returned depends on the method's declared return type:
     * <ul>
     *   <li>it is wrapped in a 429 {@link ResponseEntity} when the type can hold one;</li>
     *   <li>it is returned as a plain {@link Map} when the type can hold a map.</li>
     * </ul>
     *
     * @param joinPoint the intercepted invocation
     * @return the target's result, or the rejection response
     * @throws IllegalStateException if the call is rejected and the return type can hold neither a
     *     {@code ResponseEntity} nor a {@code Map}
     * @throws Throwable anything thrown by the target method, propagated unchanged
     */
    @Around("@annotation(com.oneboi.springboot3.annotation.RateLimit)")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();

        RateLimit rateLimit = findRateLimit(joinPoint, method);
        if (rateLimit == null) {
            // Pointcut matched, but the annotation is not reachable by reflection here (presumably an
            // unusual proxy setup). Run the call unthrottled instead of failing with an NPE.
            return joinPoint.proceed();
        }

        // toGenericString() includes parameter types, so overloaded methods get their own limiter
        // instead of sharing whichever one was created first.
        ApiRateLimiter rateLimiter = rateLimiters.computeIfAbsent(
                method.toGenericString(), key -> new ApiRateLimiter(rateLimit.limit()));

        if (rateLimiter.tryProcessRequest()) {
            return joinPoint.proceed();
        }
        return rejectionResponse(joinPoint, method, rateLimit, rateLimiter);
    }

    /**
     * Looks up {@link RateLimit} on the resolved method, falling back to the same-named public
     * method on the target's class.
     *
     * <p>The fallback exists because, with interface-based proxies, the signature's method can be the
     * interface declaration, which does not carry the implementation's annotation.
     *
     * @return the annotation, or {@code null} if it cannot be found
     */
    private static RateLimit findRateLimit(ProceedingJoinPoint joinPoint, Method method) {
        RateLimit annotation = method.getAnnotation(RateLimit.class);
        Object target = joinPoint.getTarget();
        if (annotation != null || target == null) {
            return annotation;
        }
        try {
            return target.getClass()
                    .getMethod(method.getName(), method.getParameterTypes())
                    .getAnnotation(RateLimit.class);
        } catch (NoSuchMethodException e) {
            return null;
        }
    }

    /**
     * Builds the value returned in place of the target's result when the call is rate limited.
     *
     * <p>The shape is chosen from the method's declared return type, so the proxy never has to cast
     * an incompatible object (which previously failed with a ClassCastException).
     */
    private Object rejectionResponse(ProceedingJoinPoint joinPoint, Method method,
                                     RateLimit rateLimit, ApiRateLimiter rateLimiter) {
        Map<String, Object> response = new HashMap<>();
        response.put("status", "error");
        response.put("message", rateLimit.message());
        response.put("code", rateLimit.code());
        response.put("timestamp", System.currentTimeMillis());
        response.put("path", getRequestPath(joinPoint));
        response.put("limit", rateLimit.limit());
        response.put("available", rateLimiter.getAvailablePermits());
        response.put("statusCode", HttpStatus.TOO_MANY_REQUESTS.value());

        Class<?> returnType = method.getReturnType();
        if (returnType.isAssignableFrom(ResponseEntity.class)) {
            // ResponseEntity/Object return types: report 429 as the real HTTP status.
            return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS).body(response);
        }
        if (returnType.isInstance(response)) {
            // Map-returning endpoints keep the previous contract: 429 appears only in the body.
            return response;
        }
        // NOTE(review): consider throwing the project's RateLimitException here (and in the branches
        // above) so GlobalExceptionHandler renders every rejection uniformly with a 429 status.
        throw new IllegalStateException("Rate limit exceeded for " + method.toGenericString()
                + ", whose return type " + returnType.getName()
                + " can hold neither a ResponseEntity nor a Map");
    }

    /**
     * Builds a pseudo path of the form {@code /SimpleClassName/methodName} from the intercepted
     * signature. This is not the actual HTTP request path.
     *
     * @param joinPoint the intercepted invocation
     * @return the pseudo path, or {@code "/unknown"} if the signature cannot be inspected
     */
    private String getRequestPath(ProceedingJoinPoint joinPoint) {
        try {
            MethodSignature signature = (MethodSignature) joinPoint.getSignature();
            String className = signature.getDeclaringType().getSimpleName();
            String methodName = signature.getName();
            return "/" + className + "/" + methodName;
        } catch (Exception e) {
            return "/unknown";
        }
    }
}
ApiController.java
package com.oneboi.springboot3.controller;

import com.oneboi.springboot3.annotation.RateLimit;
import com.oneboi.springboot3.aspect.ApiRateLimiter;
import org.springframework.web.bind.annotation.*;

import java.util.HashMap;
import java.util.Map;


    /**
     * Demo API controller showing two ways to rate-limit endpoints.
     * <ul>
     *   <li>Declaratively, with {@link RateLimit}.</li>
     *   <li>Programmatically, with an {@link ApiRateLimiter} instance.</li>
     * </ul>
     *
     * <p>Implements {@link AutoCloseable}, so the Spring container calls {@link #close()} on shutdown
     * (Spring infers {@code close} as the destroy method of {@code AutoCloseable} beans). That call
     * stops the background refill task owned by {@link #directRateLimiter}.
     */
    @RestController
    @RequestMapping("/api")
    public class ApiController implements AutoCloseable {

        /** Programmatic limiter used by {@link #directLimit()}: 30 requests per second. */
        private final ApiRateLimiter directRateLimiter = new ApiRateLimiter(30);

        /**
         * Example endpoint rate-limited through {@link RateLimit}.
         * Allows at most 1 request per second.
         */
        @RateLimit(limit = 1, message = "普通API调用频率过高,请稍后再试")
        @GetMapping("/test2")
        public Map<String, Object> testApi() {
            Map<String, Object> response = newResponse("success", "API调用成功");
            response.put("data", "Hello World");
            return response;
        }

        /**
         * Sensitive-data endpoint with a stricter limit.
         * Allows at most 20 requests per second.
         */
        @RateLimit(limit = 20, message = "敏感数据访问频率过高,请稍后再试")
        @GetMapping("/sensitive")
        public Map<String, Object> getSensitiveData() {
            Map<String, Object> response = newResponse("success", "敏感数据访问成功");
            response.put("data", "这是受保护的敏感信息");
            return response;
        }

        /**
         * Write endpoint with an even stricter limit; echoes the request body back as {@code data}.
         * Allows at most 10 requests per second.
         */
        @RateLimit(limit = 10, message = "写入操作频率过高,请稍后再试")
        @PostMapping("/create")
        public Map<String, Object> createResource(@RequestBody Map<String, Object> requestBody) {
            Map<String, Object> response = newResponse("success", "资源创建成功");
            response.put("data", requestBody);
            return response;
        }

        /**
         * Example of calling {@link ApiRateLimiter} directly instead of via the annotation.
         * Allows at most 30 requests per second.
         */
        @GetMapping("/direct")
        public Map<String, Object> directLimit() {
            if (!directRateLimiter.tryProcessRequest()) {
                // Rejection is reported only in the response body.
                Map<String, Object> response = newResponse("error", "请求过于频繁,请稍后重试");
                response.put("code", "TOO_MANY_REQUESTS");
                response.put("available", directRateLimiter.getAvailablePermits());
                return response;
            }
            Map<String, Object> response = newResponse("success", "直接限流API调用成功");
            response.put("data", "Direct Rate Limiter");
            response.put("available", directRateLimiter.getAvailablePermits());
            return response;
        }

        /**
         * Health-style status endpoint; not rate limited.
         */
        @GetMapping("/status")
        public Map<String, Object> getStatus() {
            return newResponse("ok", "服务运行正常");
        }

        /** Stops the refill task of {@link #directRateLimiter}; invoked by Spring on shutdown. */
        @Override
        public void close() {
            directRateLimiter.shutdown();
        }

        /** Creates a mutable response body pre-filled with status, message and current timestamp. */
        private static Map<String, Object> newResponse(String status, String message) {
            Map<String, Object> response = new HashMap<>();
            response.put("status", status);
            response.put("message", message);
            response.put("timestamp", System.currentTimeMillis());
            return response;
        }
}
GlobalExceptionHandler.java
package com.oneboi.springboot3.exception;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.context.request.WebRequest;

import java.util.HashMap;
import java.util.Map;
/**
 * Translates exceptions escaping controllers into JSON error bodies.
 */
@ControllerAdvice
public class GlobalExceptionHandler {

    private static final System.Logger LOGGER =
            System.getLogger(GlobalExceptionHandler.class.getName());

    /** Prefix stripped from {@code WebRequest#getDescription(false)} to obtain the request path. */
    private static final String URI_PREFIX = "uri=";

    /**
     * Handles a rate-limit rejection.
     *
     * @param ex      the rate-limit exception, carrying message, code and status code
     * @param request the current request
     * @return an error body with the exception's status code
     */
    @ExceptionHandler(RateLimitException.class)
    public ResponseEntity<?> handleRateLimitException(RateLimitException ex, WebRequest request) {
        Map<String, Object> response = new HashMap<>();
        response.put("status", "error");
        response.put("message", ex.getMessage());
        response.put("code", ex.getCode());
        response.put("timestamp", System.currentTimeMillis());
        response.put("path", requestPath(request));

        return ResponseEntity.status(ex.getStatusCode()).body(response);
    }

    /**
     * Fallback handler for any other exception; logs it and answers with 500.
     *
     * <p>NOTE(review): the {@code error} field echoes the raw exception message to clients, which may
     * leak internals. Also, catching {@code Exception} presumably turns Spring MVC's own 4xx
     * exceptions (e.g. unsupported HTTP method) into 500s; consider narrowing this handler.
     *
     * @param ex      the unhandled exception
     * @param request the current request
     * @return a 500 error body
     */
    @ExceptionHandler(Exception.class)
    public ResponseEntity<?> handleGlobalException(Exception ex, WebRequest request) {
        String path = requestPath(request);
        // Previously the stack trace was dropped entirely; keep it in the server log.
        LOGGER.log(System.Logger.Level.ERROR, "Unhandled exception for " + path, ex);

        Map<String, Object> response = new HashMap<>();
        response.put("status", "error");
        response.put("message", "服务器内部错误");
        response.put("error", ex.getMessage());
        response.put("timestamp", System.currentTimeMillis());
        response.put("path", path);

        return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(response);
    }

    /**
     * Extracts the request path from the request description.
     *
     * <p>Only a leading {@code "uri="} is removed. The previous {@code replace} also mangled paths
     * that contain {@code "uri="} elsewhere.
     */
    private static String requestPath(WebRequest request) {
        String description = request.getDescription(false);
        return description.startsWith(URI_PREFIX)
                ? description.substring(URI_PREFIX.length())
                : description;
    }
}
RateLimitException.java
package com.oneboi.springboot3.exception;

/**
 * Unchecked exception signalling that a request was rejected by rate limiting.
 *
 * <p>Besides the message, it carries an application-level error code and a numeric status code
 * for the response.
 */
public class RateLimitException extends RuntimeException {

    /** Numeric status code to report for this rejection. */
    private final int statusCode;

    /** Application-level error code, e.g. {@code "TOO_MANY_REQUESTS"}. */
    private final String code;

    /**
     * @param message    human-readable description, exposed via {@link #getMessage()}
     * @param code       application-level error code
     * @param statusCode numeric status code to report
     */
    public RateLimitException(String message, String code, int statusCode) {
        super(message);
        this.statusCode = statusCode;
        this.code = code;
    }

    /** @return the numeric status code passed at construction */
    public int getStatusCode() {
        return this.statusCode;
    }

    /** @return the application-level error code passed at construction */
    public String getCode() {
        return this.code;
    }
}

.