场景:限制请求后端接口的频率,例如1秒内的请求次数不能超过10次。通常的写法是:
1.先去从redis里面拿到当前请求次数
2.判断当前次数是否大于或等于限制次数
3.当前请求次数小于限制次数时进行自增
这三步在请求不密集时执行得很快,一般不会出问题;但如果两个请求几乎在同一时刻到来,第1~3步(读取、判断、自增)作为一个整体无法保证原子性:两个请求可能读到同一个次数并都通过判断,最终导致实际请求次数超过限制。
改进方式:使用Redis的Lua脚本,把"读取值、判断大小、自增"放到Redis的一次操作中。Redis以单线程串行执行命令,一个Lua脚本在执行期间不会被其他请求打断,因此这三步是原子的。
自增的Lua脚本如下。脚本返回当前计数:若计数已超过上限,则不再自增、直接返回当前值;否则执行自增并返回新值,且在计数为1(即key刚创建)时为key设置过期时间:
/**
 * Builds the Lua script that atomically performs a bounded increment on a counter key.
 *
 * <p>KEYS[1] is the counter key, ARGV[1] the limit and ARGV[2] the expiry in seconds.
 * If the stored value is already greater than the limit it is returned untouched;
 * otherwise the key is incremented, its expiry is set when the increment created it
 * (value 1), and the new value is returned.
 */
private String maxCountScriptText() {
    return String.join("\n",
            "local key = KEYS[1]",
            "local count = tonumber(ARGV[1])",
            "local time = tonumber(ARGV[2])",
            "local current = redis.call('get', key);",
            "if current and tonumber(current) > count then",
            " return tonumber(current);",
            "end",
            "current = redis.call('incr', key)",
            "if tonumber(current) == 1 then",
            " redis.call('expire', key, time)",
            "end",
            "return tonumber(current);");
}
Redis工具类:
package com.zhou.redis.util;
import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.exception.LockException;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Redis helper exposing a simple lock and an atomic, bounded counter.
 *
 * <p>The counter operations run the Lua scripts held by the injected
 * {@link MaxCountScript} and {@link MaxCountQueryScript} instances.
 */
@Configuration
@Slf4j
public class RedisUtil {

    /** Lock lease, in seconds, used when {@link #tryLock} receives a null or non-positive time. */
    private static final long DEFAULT_LOCK_SECONDS = 30L;

    // Public in the original API; kept public so existing callers keep compiling.
    public RedisTemplate<String, Object> redisTemplate;
    private final MaxCountScript maxCountScript;
    private final MaxCountQueryScript maxCountQueryScript;

    @SuppressWarnings({"rawtypes", "unchecked"}) // raw parameter type kept for source compatibility
    public RedisUtil(RedisTemplate redisTemplate, MaxCountScript maxCountScript, MaxCountQueryScript maxCountQueryScript) {
        this.redisTemplate = redisTemplate;
        this.maxCountScript = maxCountScript;
        this.maxCountQueryScript = maxCountQueryScript;
    }

    /**
     * Tries to acquire a lock by writing {@code value} under {@code key} only if the key is absent.
     *
     * @param key   lock key
     * @param value value stored while the lock is held
     * @param time  lease in seconds; {@code null} or {@code <= 0} falls back to 30 seconds
     * @return {@code true} if the lock was acquired, {@code false} otherwise (including a null reply)
     */
    public boolean tryLock(String key, Object value, Long time) {
        long seconds = (time == null || time <= 0) ? DEFAULT_LOCK_SECONDS : time;
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(key, value, Duration.ofSeconds(seconds));
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Releases a lock by deleting its key. Only call this after {@link #tryLock} succeeded.
     *
     * <p>NOTE(review): the key is deleted unconditionally. If the lease expired and another
     * caller has since acquired the lock, this removes that caller's lock. A compare-and-delete
     * on the stored value would be needed to make release owner-safe.
     *
     * @param key lock key
     * @return {@code true} if a key was deleted
     */
    public boolean unLock(String key) {
        return Boolean.TRUE.equals(redisTemplate.delete(key));
    }

    /**
     * Atomically increments the counter under {@code key}, bounded by {@code maxCount}.
     *
     * <p>Once the stored value exceeds {@code maxCount} it is no longer incremented. The
     * expiry is set when the increment creates the key, so the counter resets every
     * {@code time} seconds.
     *
     * @param key      counter key
     * @param maxCount maximum count (passed to the script as ARGV[1])
     * @param time     expiry of the counter window in seconds; must be positive
     * @return the counter value after the call
     * @throws IllegalArgumentException if {@code time <= 0}. Redis EXPIRE with a non-positive
     *         timeout deletes the key, so the counter would reset on every call and never limit.
     */
    public Long incr(String key, int maxCount, int time) {
        if (time <= 0) {
            throw new IllegalArgumentException("time must be a positive number of seconds, got " + time);
        }
        List<String> keys = Collections.singletonList(key);
        return redisTemplate.execute(maxCountScript, keys, maxCount, time);
    }

    /**
     * Reads the current counter value without modifying it.
     *
     * @param key counter key
     * @return the current value; presumably {@code null} when the key does not exist
     *         (the query script returns nil in that case)
     */
    public Long incrNow(String key) {
        List<String> keys = Collections.singletonList(key);
        return redisTemplate.execute(maxCountQueryScript, keys);
    }
}
Redis配置类:
package com.zhou.redis.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhou.redis.listener.MyRedisListener;
import com.zhou.redis.script.MaxCountQueryScript;
import com.zhou.redis.script.MaxCountScript;
import com.zhou.redis.util.RedisTopic;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import java.util.Arrays;
import java.util.List;
/**
 * Redis configuration. It provides:
 * <ul>
 *   <li>a {@link RedisTemplate} with String keys and Jackson JSON values;</li>
 *   <li>the pub/sub listener container for {@link RedisTopic#TOPIC1} and {@link RedisTopic#TOPIC2};</li>
 *   <li>the Lua scripts used by the rate limiter.</li>
 * </ul>
 */
@Configuration
public class RedisConfig {

    /**
     * Application-wide template. Keys and hash keys are plain strings; values and hash
     * values are JSON produced by {@link #jsonSerializer()}.
     *
     * @param factory connection factory supplied by Spring
     * @return the initialised template
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(factory);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = jsonSerializer();
        // keys and hash keys are stored as plain strings
        template.setKeySerializer(stringRedisSerializer);
        template.setHashKeySerializer(stringRedisSerializer);
        // values and hash values are stored as JSON
        template.setValueSerializer(jackson2JsonRedisSerializer);
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }

    /**
     * Pub/sub container that routes messages published on TOPIC1 and TOPIC2 to
     * {@link MyRedisListener}.
     *
     * <p>Topic names are serialized with the container's default String topic serializer.
     * A JSON topic serializer used to be set here. Set after {@code addMessageListener} it
     * had no effect on the topics already registered; set before, it made the subscriber
     * receive nothing, because the JSON form of a String topic carries literal double quotes
     * and no longer matches the published channel name. Message bodies are decoded by the
     * listener itself with the template's value serializer.
     *
     * @param redisConnectionFactory connection factory supplied by Spring
     * @param listener               listener handling TOPIC1 / TOPIC2 messages
     * @param adapter                currently unused; kept so this bean's dependencies are unchanged
     * @return the listener container
     */
    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory redisConnectionFactory,
                                                   MyRedisListener listener,
                                                   MessageListenerAdapter adapter) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(redisConnectionFactory);
        // Every subscribed topic is registered here; further listeners/channels can be added the same way.
        List<Topic> topicList = Arrays.asList(
                new PatternTopic(RedisTopic.TOPIC1),
                new PatternTopic(RedisTopic.TOPIC2));
        container.addMessageListener(listener, topicList);
        return container;
    }

    /**
     * Message listener adapter configured with the JSON serializer.
     *
     * <p>NOTE(review): no delegate is set, so the adapter delegates to itself. Per the Spring Data
     * Redis docs, it invokes the delegate's {@code handleMessage} method by default (not
     * {@code onMessage}). The adapter is not registered with {@link #container}.
     *
     * @return the adapter
     */
    @Bean
    public MessageListenerAdapter listenerAdapter() {
        MessageListenerAdapter receiveMessage = new MessageListenerAdapter();
        receiveMessage.setSerializer(jsonSerializer());
        return receiveMessage;
    }

    /** Script bean for the bounded, expiring increment (see {@link #maxCountScriptText()}). */
    @Bean
    public MaxCountScript maxCountScript() {
        return new MaxCountScript(maxCountScriptText());
    }

    /** Script bean that reads the current counter value (see {@link #maxCountQueryScriptText()}). */
    @Bean
    public MaxCountQueryScript maxCountQueryScript() {
        return new MaxCountQueryScript(maxCountQueryScriptText());
    }

    /**
     * Creates a Jackson serializer for arbitrary objects. A new ObjectMapper is built for
     * every call.
     *
     * <p>NOTE(review): NON_FINAL default typing writes class names into the JSON and lets the
     * payload choose which class is instantiated on read. This is only safe when every writer
     * to Redis is trusted.
     */
    @SuppressWarnings("deprecation") // enableDefaultTyping / setObjectMapper are deprecated in newer library versions
    private static Jackson2JsonRedisSerializer<Object> jsonSerializer() {
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        serializer.setObjectMapper(om);
        return serializer;
    }

    /**
     * Lua script for an atomic, bounded increment with expiry.
     *
     * <p>KEYS[1] is the counter key, ARGV[1] the limit and ARGV[2] the expiry in seconds.
     * A value already greater than the limit is returned without incrementing. Otherwise the
     * key is incremented, the expiry is set when the increment created it (value 1), and the
     * new value is returned.
     */
    private String maxCountScriptText() {
        return "local key = KEYS[1]\n" +
                "local count = tonumber(ARGV[1])\n" +
                "local time = tonumber(ARGV[2])\n" +
                "local current = redis.call('get', key);\n" +
                "if current and tonumber(current) > count then\n" +
                " return tonumber(current);\n" +
                "end\n" +
                "current = redis.call('incr', key)\n" +
                "if tonumber(current) == 1 then\n" +
                " redis.call('expire', key, time)\n" +
                "end\n" +
                "return tonumber(current);";
    }

    /**
     * Lua script returning the current counter value of KEYS[1] as a number, or nil when the
     * key does not exist.
     */
    private String maxCountQueryScriptText() {
        return "local key = KEYS[1]\n" +
                "local current = redis.call('get', key);\n" +
                "if current then\n" +
                " return tonumber(current);\n" +
                "else\n" +
                " return current\n" +
                "end\n";
    }
}
拦截模式枚举类:按IP拦截或按方法拦截
package com.zhou.aop;
/**
 * Scope of a rate-limit counter.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:56
 */
public enum LimitType {

    /** One counter per caller: the client IP address becomes part of the limit key. */
    IP,

    /** One counter per annotated method, shared by all callers. */
    DEFAULT
}
封装自定义注解:@RateLimiter
package com.zhou.aop;
import java.lang.annotation.*;
/**
 * Marks a method whose invocations are rate limited.
 *
 * <p>At most {@link #count()} calls are allowed per {@link #time()}-second window. The
 * counter key is built from {@link #key()} and the target method, plus the caller's IP
 * address when {@link #limitType()} is {@link LimitType#IP}.
 *
 * @author lang.zhou
 * @since 2023/1/31 17:49
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /** Prefix of the Redis counter key. */
    String key() default "RateLimiter";

    /** Length of the limiting window, in seconds. */
    int time() default 60;

    /** Maximum number of calls allowed within one window. */
    int count() default 100;

    /** Whether the counter is per method ({@link LimitType#DEFAULT}) or per caller IP ({@link LimitType#IP}). */
    LimitType limitType() default LimitType.DEFAULT;

    /** Message of the exception raised when a call is rejected. */
    String limitMsg() default "访问过于频繁,请稍候再试";
}
注解的切面逻辑:
package com.zhou.aop;
import com.zhou.redis.util.RedisUtil;
import com.zhou.common.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
/**
* 接口限流切面
* @author lang.zhou
* @since 2023/1/31 17:50
*/
@Aspect
@Slf4j
@Component
public class RateLimiterAspect {
@Autowired
private RedisUtil redisUtils;
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter, point);
try {
Long number = redisUtils.incr(combineKey, count, time);
if (number == null || number.intValue() > count){
log.info("请求【{}】被拦截,{}秒内请求次数{}",combineKey,time,number);
throw new RuntimeException(rateLimiter.limitMsg());
}
} catch (ServiceRuntimeException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException("网络繁忙,请稍候再试");
}
}
/**
* 获取限流key
*/
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuilder s = new StringBuilder(rateLimiter.key());
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if(requestAttributes != null){
HttpServletRequest request = requestAttributes.getRequest();
if (rateLimiter.limitType() == LimitType.IP) {
s.append(IpUtil.getIpAddr(request)).append("-");
}
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
s.append(targetClass.getName()).append(".").append(method.getName());
return s.toString();
}
}
Lua自增脚本类:
package com.zhou.redis.script;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Redis script whose reply is read back as a {@link Long}. RedisConfig instantiates it with
 * the bounded-increment Lua script.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountScript extends DefaultRedisScript<Long> {

    /** Java type the script reply is converted to. */
    private static final Class<Long> RESULT_TYPE = Long.class;

    /**
     * @param script Lua source text of the script
     */
    public MaxCountScript(String script) {
        super(script, RESULT_TYPE);
    }
}
Lua查询当前值的脚本类:
package com.zhou.redis.script;
import org.springframework.data.redis.core.script.DefaultRedisScript;
/**
 * Redis script whose reply is read back as a {@link Long}. RedisConfig instantiates it with
 * the script that reads the current counter value.
 *
 * @author lang.zhou
 * @since 2023/2/25
 */
public class MaxCountQueryScript extends DefaultRedisScript<Long> {

    /** Java type the script reply is converted to. */
    private static final Class<Long> RESULT_TYPE = Long.class;

    /**
     * @param script Lua source text of the script
     */
    public MaxCountQueryScript(String script) {
        super(script, RESULT_TYPE);
    }
}
订阅消息通道的常量类:
package com.zhou.redis.util;
/**
 * Names of the Redis pub/sub channels the application subscribes to and dispatches on.
 */
public class RedisTopic {
    public static final String
            TOPIC1 = "TOPIC1",
            TOPIC2 = "TOPIC2";
}
消息实体类:
package com.zhou.redis.dto;
import lombok.Data;
import java.io.Serializable;
/**
 * Payload of the Redis pub/sub messages; MyRedisListener decodes it from TOPIC2 messages.
 * Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 * @since 2022/11/11 17:34
 */
@Data
public class MyRedisMessage implements Serializable {
// Text content of the message.
private String msg;
}
订阅消息监听器:
package com.zhou.redis.listener;
import com.zhou.redis.dto.MyRedisMessage;
import com.zhou.redis.util.RedisTopic;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import javax.script.ScriptException;
/**
 * Listener for the Redis pub/sub topics {@link RedisTopic#TOPIC1} and {@link RedisTopic#TOPIC2}.
 *
 * @author lang.zhou
 */
@Slf4j
@Component
public class MyRedisListener implements MessageListener {
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * Dispatches a message by topic name. TOPIC1 is currently ignored. TOPIC2 bodies are
     * decoded with the template's value serializer (the publisher must use the same
     * serializer) and logged.
     *
     * <p>NOTE(review): the value serializer may use Jackson default typing, which lets the
     * payload choose the class to instantiate. Only trusted publishers should write to these channels.
     *
     * @param message received message
     * @param pattern subscription pattern that matched; may be {@code null} for channel (non-pattern)
     *                subscriptions, in which case the message's channel is used instead
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        byte[] source = pattern != null ? pattern : message.getChannel();
        // Decode explicitly as UTF-8 instead of relying on the platform default charset.
        String topic = new String(source, java.nio.charset.StandardCharsets.UTF_8);
        log.info("channel:{}", topic);
        if (RedisTopic.TOPIC1.equals(topic)) {
            // TOPIC1: no handling implemented yet.
            return;
        }
        if (RedisTopic.TOPIC2.equals(topic)) {
            Object body = redisTemplate.getValueSerializer().deserialize(message.getBody());
            if (body instanceof MyRedisMessage) {
                MyRedisMessage msg = (MyRedisMessage) body;
                log.info("message:{}", msg);
            } else {
                // Avoid a ClassCastException on an unexpected or empty payload.
                log.warn("unexpected payload on {}: {}", topic, body);
            }
        }
    }
}
注解使用方式(1秒内同一个IP最多只能请求10次):
/**
 * Example controller: {@code POST /test/api/limit} accepts at most 10 requests per second
 * from the same client IP.
 */
@RestController
@RequestMapping("/test/api")
public class CheckController {

    @RateLimiter(limitType = LimitType.IP, count = 10, time = 1, limitMsg = "请求过于频繁,请稍后重试")
    @PostMapping("/limit")
    public void limit(HttpServletRequest request) {
        // business logic goes here
    }
}