webflux的websocket连接与生命周期

1、配置入口:

  1. import com.mti.handler.MessageHandler;
  2. import org.springframework.beans.factory.annotation.Autowired;
  3. import org.springframework.context.annotation.Bean;
  4. import org.springframework.context.annotation.Configuration;
  5. import org.springframework.core.Ordered;
  6. import org.springframework.web.reactive.HandlerMapping;
  7. import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
  8. import org.springframework.web.reactive.socket.WebSocketHandler;
  9. import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter;
  10. import java.util.HashMap;
  11. import java.util.Map;
  12. /**
  13. * ThreadConfig class
  14. *
  15. * @author zhaoyj
  16. * @date 2019/3/12
  17. */
  18. @Configuration
  19. public class WebSocketConfiguration {
  20. @Autowired
  21. @Bean
  22. public HandlerMapping webSocketMapping(final MessageHandler echoHandler) {
  23. final Map<String, WebSocketHandler> map = new HashMap<>();
  24. map.put("/echo", echoHandler);
  25. final SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping();
  26. mapping.setOrder(Ordered.HIGHEST_PRECEDENCE);
  27. mapping.setUrlMap(map);
  28. return mapping;
  29. }
  30. @Bean
  31. public WebSocketHandlerAdapter handlerAdapter() {
  32. return new WebSocketHandlerAdapter();
  33. }
  34. }

2、配置Handler

  1. import com.alibaba.fastjson.JSONObject;
  2. import com.google.protobuf.InvalidProtocolBufferException;
  3. import com.mti.configuration.Systemconfig;
  4. import com.mti.enums.ReferenceMsgType;
  5. import com.mti.exception.BusinessException;
  6. import com.mti.handler.up.StreamReferenceReq;
  7. import com.mti.proto.Linkproto;
  8. import com.mti.vo.Message;
  9. import com.mti.websocket.SocketSessionRegistry;
  10. import lombok.extern.slf4j.Slf4j;
  11. import org.springframework.beans.factory.annotation.Autowired;
  12. import org.springframework.core.io.buffer.DataBufferFactory;
  13. import org.springframework.core.task.TaskExecutor;
  14. import org.springframework.stereotype.Component;
  15. import org.springframework.web.reactive.socket.WebSocketHandler;
  16. import org.springframework.web.reactive.socket.WebSocketMessage;
  17. import org.springframework.web.reactive.socket.WebSocketSession;
  18. import reactor.core.publisher.Flux;
  19. import reactor.core.publisher.Mono;
  20. import java.util.Optional;
  21. import java.util.concurrent.ScheduledThreadPoolExecutor;
  22. import java.util.concurrent.TimeUnit;
  23. /**
  24. * MessageHandler class
  25. *
  26. * @author zhaoyj
  27. * @date 2019/3/12
  28. */
  29. @Component
  30. @Slf4j
  31. public class MessageHandler implements WebSocketHandler {
  32. @Autowired
  33. private SocketSessionRegistry sessionRegistry;
  34. @Autowired
  35. private ScheduledThreadPoolExecutor executor;
  36. @Autowired
  37. private Systemconfig systemconfig;
  38. @Autowired
  39. DispatchFactory dispatchFactory;
  40. @Autowired
  41. TaskExecutor taskExecutor;
  42. @Override
  43. public Mono<Void> handle(WebSocketSession session) {
  44. return session.receive().doOnSubscribe(s -> {
  45. log.info("发起连接:{}",s);
  46. /**
  47. * 你有10秒时间登陆,不登陆就关掉连接;并且不给任何错误信息
  48. */
  49. if(systemconfig.getLoginInterval() != 0){
  50. executor.schedule(() -> sessionRegistry.checkAndRemove(session),systemconfig.getLoginInterval(),TimeUnit.SECONDS);
  51. }
  52. if(systemconfig.getPingInterval() != 0){
  53. executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
  54. }
  55. }).doOnTerminate(() -> {
  56. sessionRegistry.unregisterSession(session);
  57. StreamReferenceReq req = (StreamReferenceReq) dispatchFactory.getCommand(ReferenceMsgType.SEND_VALUE);
  58. taskExecutor.execute(() -> Optional.ofNullable(req.removeSession(session)).ifPresent(list -> list.forEach(req::sendStopStreamConfig)));
  59. log.info("doOnTerminate");
  60. }).doOnComplete(() -> {
  61. log.info("doOnComplete");
  62. }).doOnCancel(() -> {
  63. log.info("doOnCancel");
  64. }).doOnNext(message -> {
  65. if(message.getType().equals(WebSocketMessage.Type.BINARY)){
  66. log.info("收到二进制消息");
  67. Linkproto.LinkCmd linkCmd = null;
  68. try {
  69. linkCmd = Optional.ofNullable(Linkproto.LinkCmd.parseFrom(message.getPayload().asByteBuffer())).orElseThrow(() -> new BusinessException(500,"解析出错了"));
  70. BaseDispatch<Linkproto.LinkCmd> dispatch = dispatchFactory.getCommand(linkCmd.getTypeValue());
  71. log.info("处理session,{},消息实体,{},类型,{},dispatch:{}",session,linkCmd,linkCmd.getTypeValue(),dispatch);
  72. dispatch.excuted(session, linkCmd);
  73. } catch (InvalidProtocolBufferException e) {
  74. e.printStackTrace();
  75. }
  76. }else if(message.getType().equals(WebSocketMessage.Type.TEXT)){
  77. String content = message.getPayloadAsText();
  78. log.info("收到文本消息:{}",content);
  79. Message msg = null;
  80. try{
  81. msg = JSONObject.parseObject(content, Message.class);
  82. }catch (Exception e){
  83. JSONObject obj = new JSONObject();
  84. obj.put("content","无法理解你发过来的消息内容,不予处理:"+content);
  85. obj.put("msgType",Linkproto.LinkCmdType.LINK_CMD_ZERO_VALUE);
  86. session.send(Flux.just(session.textMessage(obj.toJSONString()))).then().toProcessor();
  87. log.error("解析消息内容出错");
  88. return;
  89. }
  90. BaseDispatch<Linkproto.LinkCmd> dispatch = dispatchFactory.getCommand(msg.getMsgType());
  91. if(dispatch != null){
  92. dispatch.executeMsg(session, msg);
  93. }
  94. }else if(message.getType().equals(WebSocketMessage.Type.PING)){
  95. session.send(Flux.just(session.pongMessage(s -> s.wrap(new byte[256]))));
  96. log.info("收到ping消息");
  97. }else if(message.getType().equals(WebSocketMessage.Type.PONG)){
  98. log.info("收到pong消息");
  99. if(systemconfig.getPingInterval() != 0){
  100. executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
  101. }
  102. }
  103. }).doOnError(e -> {
  104. e.printStackTrace();
  105. log.error("doOnError");
  106. }).doOnRequest(r -> {
  107. log.info("doOnRequest");
  108. }).then();
  109. }

这边显示的是整个从连接建立到连接断开的生命周期,可以区分二进制消息还是文本消息。发送消息时,一定要加上toProcessor()(或者subscribe()):session.send()返回的是一个冷的Mono,不订阅就不会真正发送。

如果要发送消息到其它客户端,需要在后台将连接过来的session保存起来,根据用户名或者其它方式保存之后,获取到session进行发送:如下面这个SocketSessionRegistry类

  1. import com.mti.enums.SocketCloseStatus;
  2. import lombok.extern.slf4j.Slf4j;
  3. import org.springframework.stereotype.Service;
  4. import org.springframework.util.CollectionUtils;
  5. import org.springframework.util.StringUtils;
  6. import org.springframework.web.reactive.socket.WebSocketSession;
  7. import java.util.*;
  8. import java.util.concurrent.ConcurrentHashMap;
  9. import java.util.concurrent.ConcurrentMap;
  10. import java.util.concurrent.CopyOnWriteArraySet;
  11. import java.util.concurrent.CountDownLatch;
  12. /**
  13. *用户session记录类
  14. *
  15. * @author zhaoyj
  16. * @date 2019/3/12
  17. */
  18. @Service
  19. @Slf4j
  20. public class SocketSessionRegistry {
  21. /**
  22. * 这个集合存储session
  23. */
  24. private final ConcurrentMap<String, Set<String>> userSessionIds = new ConcurrentHashMap<>();
  25. private final ConcurrentMap<String, WebSocketSession> clientInfoSessionIds = new ConcurrentHashMap<>();
  26. private final ConcurrentMap<String,String> sessionIdUser = new ConcurrentHashMap<>();
  27. private ConcurrentMap<String, CountDownLatch> cacheTimestamp = new ConcurrentHashMap<>();
  28. private final Object lock = new Object();
  29. /**
  30. * 获取sessionId
  31. *
  32. * @param user
  33. * @return
  34. */
  35. private Set<String> getSessionIds(String user) {
  36. Set<String> set = this.userSessionIds.get(user);
  37. return set != null ? set : Collections.emptySet();
  38. }
  39. /**
  40. * 获取用户session
  41. * @param user
  42. * @return
  43. */
  44. public Collection<WebSocketSession> getSessionByUser(String user){
  45. Set<String> sessionIds = Optional.ofNullable(getSessionIds(user)).orElse(new CopyOnWriteArraySet<>());
  46. List<WebSocketSession> sessions = new ArrayList<>();
  47. for (String sessionId : sessionIds) {
  48. sessions.add(clientInfoSessionIds.get(sessionId));
  49. }
  50. return sessions;
  51. }
  52. /**
  53. * 获取用户session
  54. * @param users
  55. * @return
  56. */
  57. public Collection<WebSocketSession> getSessionByUsers(Collection<String> users){
  58. List<WebSocketSession> sessions = new ArrayList<>();
  59. if(!CollectionUtils.isEmpty(users)){
  60. for (String user : users) {
  61. sessions.addAll(getSessionByUser(user));
  62. }
  63. }
  64. return sessions;
  65. }
  66. /**
  67. * 获取所有session
  68. * @return Collection<WebSocketSession>
  69. */
  70. public Collection<WebSocketSession> getAllSessions(){
  71. return clientInfoSessionIds.values();
  72. }
  73. /**
  74. * 获取所有session
  75. *
  76. * @return
  77. */
  78. public ConcurrentMap<String, Set<String>> getAllSessionIds() {
  79. return this.userSessionIds;
  80. }
  81. /**
  82. * 获取所有session
  83. *
  84. * @return
  85. */
  86. public ConcurrentMap<String, WebSocketSession> getAllSessionWebSocketInfos() {
  87. return this.clientInfoSessionIds;
  88. }
  89. /**
  90. * register session
  91. *
  92. * @param user
  93. * @param sessionId
  94. */
  95. private void registerSessionId(String user, String sessionId) {
  96. synchronized (this.lock) {
  97. Set<String> set = this.userSessionIds.get(user);
  98. if (set == null) {
  99. set = new CopyOnWriteArraySet<>();
  100. this.userSessionIds.put(user, set);
  101. }
  102. set.add(sessionId);
  103. }
  104. }
  105. /**
  106. * 保存session
  107. * @param session WebSocketSession
  108. */
  109. public void registerSession(WebSocketSession session,String user){
  110. if(StringUtils.isEmpty(user)){
  111. user = parseUserByURI(session).get("user");
  112. }
  113. if(!StringUtils.isEmpty(user)){
  114. String sessionId = session.getId();
  115. registerSessionId(user,sessionId);
  116. registerSessionId(session);
  117. sessionIdUser.putIfAbsent(sessionId,user);
  118. }
  119. }
  120. /**
  121. * 从session里面解析参数
  122. * @param session
  123. * @return
  124. */
  125. private Map<String, String> parseUserByURI(WebSocketSession session){
  126. Map<String, String> map = new HashMap<>();
  127. String[] params = Optional.ofNullable(session.getHandshakeInfo().getUri().getQuery()).orElse("").split("&");
  128. for (String param : params) {
  129. String[] temp = param.split("=");
  130. if(temp.length == 2){
  131. map.put(temp[0],temp[1]);
  132. }
  133. }
  134. return map;
  135. }
  136. public WebSocketSession getSessionBySessionId(String sessionId){
  137. return this.clientInfoSessionIds.get(sessionId);
  138. }
  139. private void registerSessionId(WebSocketSession websocketInfo) {
  140. String sessionId = websocketInfo.getId();
  141. CountDownLatch signal = cacheTimestamp.putIfAbsent(sessionId, new CountDownLatch(1));
  142. if (signal == null) {
  143. signal = cacheTimestamp.get(sessionId);
  144. try {
  145. if (!clientInfoSessionIds.containsKey(sessionId)) {
  146. WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
  147. if (set == null) {
  148. clientInfoSessionIds.putIfAbsent(sessionId, websocketInfo);
  149. }
  150. }
  151. } finally {
  152. signal.countDown();
  153. cacheTimestamp.remove(sessionId);
  154. }
  155. } else {
  156. try {
  157. signal.await();
  158. } catch (InterruptedException e) {
  159. e.printStackTrace();
  160. }
  161. }
  162. }
  163. private void unregisterSessionId(String userName, String sessionId) {
  164. synchronized (this.lock) {
  165. Set set = this.userSessionIds.get(userName);
  166. if (set != null && set.remove(sessionId) && set.isEmpty()) {
  167. this.userSessionIds.remove(userName);
  168. }
  169. }
  170. }
  171. private void unregisterSessionId(String sessionId) {
  172. synchronized (this.lock) {
  173. WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
  174. if (set != null) {
  175. this.clientInfoSessionIds.remove(sessionId);
  176. }
  177. }
  178. }
  179. public void unregisterSession(WebSocketSession session){
  180. String sessionId = session.getId();
  181. String user = sessionIdUser.get(sessionId);
  182. if(!StringUtils.isEmpty(user)){
  183. unregisterSessionId(sessionId);
  184. unregisterSessionId(user,sessionId);
  185. sessionIdUser.remove(sessionId);
  186. }
  187. }
  188. public void checkAndRemove(WebSocketSession session){
  189. String sessionId = session.getId();
  190. if(!this.clientInfoSessionIds.containsKey(sessionId)){
  191. log.info("sessionId:{} 10秒内没有登陆,关掉它",sessionId);
  192. session.close(SocketCloseStatus.UN_LOGIN.getCloseStatus()).toProcessor();
  193. }else{
  194. log.info("sessinId:{}已经登陆,是合法的",sessionId);
  195. }
  196. }
  197. }
userSessionIds是保存用户所属的sessionId列表的,因为同一个用户可能会在不同地方登陆,会有多个session
clientInfoSessionIds这个是保存session的,可以根据sessionId取到对应的session;sessionId到用户的对应关系则保存在sessionIdUser里。

这几周慢慢摸索出来的结果,网上资料很少,官网上的也不是很全,可能有不对的地方,在此做个记录!