package com.demo.config.websocket;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter} as a Spring bean.
     *
     * <p>This bean automatically registers the WebSocket endpoints that are declared with the
     * {@code @ServerEndpoint} annotation.
     *
     * @return a freshly created exporter
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
package com.demo.config.websocket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.demo.config.redis.RedisUtil;
import jakarta.annotation.Resource;
import jakarta.websocket.OnClose;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.OnOpen;
import jakarta.websocket.Session;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j;
@Component
@Slf4j
@ServerEndpoint("/websocket/{userId}")
public class WebSocket {
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
/**
* 用户ID
*/
private String userId;
/**
* WebSocket是当前类名
*/
private static CopyOnWriteArraySet<WebSocket> webSockets = new CopyOnWriteArraySet<>();
/**
* 用来存在线连接用户信息
*/
private static ConcurrentHashMap<String, Session> sessionPool = new ConcurrentHashMap<String, Session>();
private static String WEB_SOCKET_KEY = "web_socket_key_";
/**
* 链接成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam(value = "userId") String userId) {
try {
this.session = session;
this.userId = userId;
webSockets.add(this);
sessionPool.put(userId, session);
// 缓存离线数据
cacheMessageContains(userId);
log.info("【websocket消息】有新的连接,总数为:" + webSockets.size());
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 链接关闭调用的方法
*/
@OnClose
public void onClose() {
try {
webSockets.remove(this);
sessionPool.remove(this.userId);
log.info("【websocket消息】连接断开,总数为:" + webSockets.size());
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 收到客户端消息后调用的方法
*
* @param message
*/
@OnMessage
public void onMessage(String message) {
log.info("【websocket消息】收到客户端消息:" + message);
}
/**
* 发送错误时的处理
*
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
log.error("用户错误,原因:" + error.getMessage());
error.printStackTrace();
}
// 此为广播消息
public static void sendAllMessage(String message) {
log.info("【websocket消息】广播消息:" + message);
for (WebSocket webSocket : webSockets) {
try {
if (webSocket.session.isOpen()) {
webSocket.session.getAsyncRemote().sendText(message);
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
// 此为单点消息
@Async
public void sendOneMessage(String userId, String message) {
if (sessionPool.size() <= 0 || ObjectUtils.isEmpty(sessionPool.get(userId))) {
cacheMessagePut(userId, message);
} else {
Session session = sessionPool.get(userId);
if (session != null && session.isOpen()) {
try {
log.info("【websocket消息】 单点消息:" + message);
session.getAsyncRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
// 此为单点消息(多人)
public void sendMoreMessage(String[] userIds, String message) {
for (String userId : userIds) {
Session session = sessionPool.get(userId);
if (session != null && session.isOpen()) {
try {
log.info("【websocket消息】 单点消息:" + message);
session.getAsyncRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
}
}
}
/**
* 查询是否有离线消息并推送
*/
public void cacheMessageContains(String userId) {
String key = WEB_SOCKET_KEY + userId;
List<String> range = new ArrayList();
//是否有暂存的消息,如果有则发送消息
if (RedisUtil.redis.hasKey(key)) {
range = RedisUtil.redis.opsForList().range(key, 0, -1);
}
if (CollectionUtils.isNotEmpty(range) && range.size() > 0) {
range.forEach(msg -> {
sendOneMessage(userId, msg);
});
RedisUtil.redis.delete(key);
}
}
/**
* 暂存离线消息
*/
public void cacheMessagePut(String userId, String message) {
String key = WEB_SOCKET_KEY + userId;
if (StringUtils.isNotEmpty(message)) {
RedisUtil.redis.opsForList().rightPush(key, message);
log.info(WEB_SOCKET_KEY + userId + "消息暂存成功");
}
}
}
package com.demo.config.websocket;
import cn.hutool.json.JSONObject;
import jakarta.annotation.Resource;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
 * HTTP endpoint for pushing a fixed demo message to a WebSocket user.
 *
 * <p>Other send modes are available on {@link WebSocket}: the static
 * {@code WebSocket.sendAllMessage(String)} broadcasts to everyone, and
 * {@code sendMoreMessage(String[], String)} targets several user ids.
 *
 * @since 2023/9/20 9:01
 */
@RestController
@RequestMapping("/websocket")
public class WebSocketController {

    /** Payload sent by {@link #send(String)}. */
    private static final String DEMO_MESSAGE = "消息来了2222222222222";

    @Resource
    private WebSocket webSocket;

    /**
     * Sends {@link #DEMO_MESSAGE} to a single user.
     *
     * <p>Delegates to {@link WebSocket#sendOneMessage(String, String)}, which caches the message
     * in Redis when the user has no registered session.
     *
     * @param userId target user id from the request path
     */
    @RequestMapping("/send/{userId}")
    public void send(@PathVariable String userId) {
        webSocket.sendOneMessage(userId, DEMO_MESSAGE);
    }
}
package com.demo.config.redis;
import com.alibaba.fastjson.JSON;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scheduling.annotation.Async;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Redis helper bean.
 *
 * <p>Publishes the Spring-managed {@link RedisTemplate} through the static field {@link #redis}.
 * Code that cannot have the template injected (per the original note, the WebSocket endpoint)
 * can then reach it statically.
 */
@Slf4j
@Component
public class RedisUtil {

    @Resource
    private RedisTemplate redisTemplate;

    /**
     * Static handle on the injected template. It is {@code null} until this bean's
     * {@link #getRedisTemplate()} initializer has run.
     */
    public static RedisTemplate redis;

    /** Copies the injected template into {@link #redis} once dependency injection has finished. */
    @PostConstruct
    public void getRedisTemplate() {
        RedisUtil.redis = redisTemplate;
    }
}
# WebSocket reverse-proxy settings (put them in the location block that proxies /websocket).
# The Upgrade handshake exists only in HTTP/1.1, and nginx talks to upstreams with HTTP/1.0 by
# default, so the version must be raised explicitly or the upgrade never reaches the backend.
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
容器(Tomcat/Undertow/其他):如果项目中同时引入了多个 Web 容器,WebSocket 相关接口会存在多个实现类,导致无法确定应使用哪一个,最终 WebSocket 容器加载为 null。
如果使用 Undertow 容器,需要在 spring-boot-starter-web 中排除其自带的 Tomcat 依赖(spring-boot-starter-tomcat)。
如果仍然无法解决,需要排查是否有其他依赖间接引入了 Tomcat;最简单的办法是将 WebSocket 单独部署为一个服务。