1、配置入口:
- import com.mti.handler.MessageHandler;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.context.annotation.Bean;
- import org.springframework.context.annotation.Configuration;
- import org.springframework.core.Ordered;
- import org.springframework.web.reactive.HandlerMapping;
- import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
- import org.springframework.web.reactive.socket.WebSocketHandler;
- import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter;
-
- import java.util.HashMap;
- import java.util.Map;
- /**
- * ThreadConfig class
- *
- * @author zhaoyj
- * @date 2019/3/12
- */
- @Configuration
- public class WebSocketConfiguration {
- @Autowired
- @Bean
- public HandlerMapping webSocketMapping(final MessageHandler echoHandler) {
- final Map<String, WebSocketHandler> map = new HashMap<>();
- map.put("/echo", echoHandler);
- final SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping();
- mapping.setOrder(Ordered.HIGHEST_PRECEDENCE);
- mapping.setUrlMap(map);
- return mapping;
- }
-
- @Bean
- public WebSocketHandlerAdapter handlerAdapter() {
- return new WebSocketHandlerAdapter();
- }
- }
2、配置Handler
- import com.alibaba.fastjson.JSONObject;
- import com.google.protobuf.InvalidProtocolBufferException;
- import com.mti.configuration.Systemconfig;
- import com.mti.enums.ReferenceMsgType;
- import com.mti.exception.BusinessException;
- import com.mti.handler.up.StreamReferenceReq;
- import com.mti.proto.Linkproto;
- import com.mti.vo.Message;
- import com.mti.websocket.SocketSessionRegistry;
- import lombok.extern.slf4j.Slf4j;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.core.io.buffer.DataBufferFactory;
- import org.springframework.core.task.TaskExecutor;
- import org.springframework.stereotype.Component;
- import org.springframework.web.reactive.socket.WebSocketHandler;
- import org.springframework.web.reactive.socket.WebSocketMessage;
- import org.springframework.web.reactive.socket.WebSocketSession;
- import reactor.core.publisher.Flux;
- import reactor.core.publisher.Mono;
-
- import java.util.Optional;
- import java.util.concurrent.ScheduledThreadPoolExecutor;
- import java.util.concurrent.TimeUnit;
-
- /**
- * MessageHandler class
- *
- * @author zhaoyj
- * @date 2019/3/12
- */
- @Component
- @Slf4j
- public class MessageHandler implements WebSocketHandler {
-
- @Autowired
- private SocketSessionRegistry sessionRegistry;
- @Autowired
- private ScheduledThreadPoolExecutor executor;
- @Autowired
- private Systemconfig systemconfig;
-
- @Autowired
- DispatchFactory dispatchFactory;
-
- @Autowired
- TaskExecutor taskExecutor;
-
- @Override
- public Mono<Void> handle(WebSocketSession session) {
-
- return session.receive().doOnSubscribe(s -> {
- log.info("发起连接:{}",s);
- /**
- * 你有10秒时间登陆,不登陆就关掉连接;并且不给任何错误信息
- */
- if(systemconfig.getLoginInterval() != 0){
- executor.schedule(() -> sessionRegistry.checkAndRemove(session),systemconfig.getLoginInterval(),TimeUnit.SECONDS);
- }
- if(systemconfig.getPingInterval() != 0){
- executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
- }
- }).doOnTerminate(() -> {
- sessionRegistry.unregisterSession(session);
- StreamReferenceReq req = (StreamReferenceReq) dispatchFactory.getCommand(ReferenceMsgType.SEND_VALUE);
- taskExecutor.execute(() -> Optional.ofNullable(req.removeSession(session)).ifPresent(list -> list.forEach(req::sendStopStreamConfig)));
- log.info("doOnTerminate");
- }).doOnComplete(() -> {
- log.info("doOnComplete");
- }).doOnCancel(() -> {
- log.info("doOnCancel");
- }).doOnNext(message -> {
- if(message.getType().equals(WebSocketMessage.Type.BINARY)){
- log.info("收到二进制消息");
- Linkproto.LinkCmd linkCmd = null;
- try {
- linkCmd = Optional.ofNullable(Linkproto.LinkCmd.parseFrom(message.getPayload().asByteBuffer())).orElseThrow(() -> new BusinessException(500,"解析出错了"));
- BaseDispatch<Linkproto.LinkCmd> dispatch = dispatchFactory.getCommand(linkCmd.getTypeValue());
- log.info("处理session,{},消息实体,{},类型,{},dispatch:{}",session,linkCmd,linkCmd.getTypeValue(),dispatch);
- dispatch.excuted(session, linkCmd);
- } catch (InvalidProtocolBufferException e) {
- e.printStackTrace();
- }
- }else if(message.getType().equals(WebSocketMessage.Type.TEXT)){
- String content = message.getPayloadAsText();
- log.info("收到文本消息:{}",content);
- Message msg = null;
- try{
- msg = JSONObject.parseObject(content, Message.class);
- }catch (Exception e){
- JSONObject obj = new JSONObject();
- obj.put("content","无法理解你发过来的消息内容,不予处理:"+content);
- obj.put("msgType",Linkproto.LinkCmdType.LINK_CMD_ZERO_VALUE);
- session.send(Flux.just(session.textMessage(obj.toJSONString()))).then().toProcessor();
- log.error("解析消息内容出错");
- return;
- }
- BaseDispatch<Linkproto.LinkCmd> dispatch = dispatchFactory.getCommand(msg.getMsgType());
- if(dispatch != null){
- dispatch.executeMsg(session, msg);
- }
- }else if(message.getType().equals(WebSocketMessage.Type.PING)){
- session.send(Flux.just(session.pongMessage(s -> s.wrap(new byte[256]))));
- log.info("收到ping消息");
- }else if(message.getType().equals(WebSocketMessage.Type.PONG)){
- log.info("收到pong消息");
- if(systemconfig.getPingInterval() != 0){
- executor.schedule(() -> session.send(Flux.just(session.pingMessage(DataBufferFactory::allocateBuffer))).toProcessor(),systemconfig.getPingInterval(), TimeUnit.SECONDS);
- }
- }
- }).doOnError(e -> {
- e.printStackTrace();
- log.error("doOnError");
- }).doOnRequest(r -> {
- log.info("doOnRequest");
- }).then();
- }
这里展示的是从连接建立到连接断开的整个生命周期，可以区分二进制消息和文本消息。发送消息时，一定要订阅 session.send(...) 返回的 Mono（例如调用 toProcessor() 或 subscribe()）：它是惰性的，不订阅就不会真正发送。
如果要发送消息到其它客户端,需要在后台将连接过来的session保存起来,根据用户名或者其它方式保存之后,获取到session进行发送:如下面这个SocketSessionRegistry类
- import com.mti.enums.SocketCloseStatus;
- import lombok.extern.slf4j.Slf4j;
- import org.springframework.stereotype.Service;
- import org.springframework.util.CollectionUtils;
- import org.springframework.util.StringUtils;
- import org.springframework.web.reactive.socket.WebSocketSession;
-
- import java.util.*;
- import java.util.concurrent.ConcurrentHashMap;
- import java.util.concurrent.ConcurrentMap;
- import java.util.concurrent.CopyOnWriteArraySet;
- import java.util.concurrent.CountDownLatch;
-
- /**
- *用户session记录类
- *
- * @author zhaoyj
- * @date 2019/3/12
- */
- @Service
- @Slf4j
- public class SocketSessionRegistry {
-
- /**
- * 这个集合存储session
- */
- private final ConcurrentMap<String, Set<String>> userSessionIds = new ConcurrentHashMap<>();
-
- private final ConcurrentMap<String, WebSocketSession> clientInfoSessionIds = new ConcurrentHashMap<>();
-
- private final ConcurrentMap<String,String> sessionIdUser = new ConcurrentHashMap<>();
- private ConcurrentMap<String, CountDownLatch> cacheTimestamp = new ConcurrentHashMap<>();
- private final Object lock = new Object();
-
-
- /**
- * 获取sessionId
- *
- * @param user
- * @return
- */
- private Set<String> getSessionIds(String user) {
- Set<String> set = this.userSessionIds.get(user);
- return set != null ? set : Collections.emptySet();
- }
-
- /**
- * 获取用户session
- * @param user
- * @return
- */
- public Collection<WebSocketSession> getSessionByUser(String user){
- Set<String> sessionIds = Optional.ofNullable(getSessionIds(user)).orElse(new CopyOnWriteArraySet<>());
- List<WebSocketSession> sessions = new ArrayList<>();
- for (String sessionId : sessionIds) {
- sessions.add(clientInfoSessionIds.get(sessionId));
- }
- return sessions;
- }
-
- /**
- * 获取用户session
- * @param users
- * @return
- */
- public Collection<WebSocketSession> getSessionByUsers(Collection<String> users){
- List<WebSocketSession> sessions = new ArrayList<>();
- if(!CollectionUtils.isEmpty(users)){
- for (String user : users) {
- sessions.addAll(getSessionByUser(user));
- }
- }
- return sessions;
- }
-
- /**
- * 获取所有session
- * @return Collection<WebSocketSession>
- */
- public Collection<WebSocketSession> getAllSessions(){
- return clientInfoSessionIds.values();
- }
-
- /**
- * 获取所有session
- *
- * @return
- */
- public ConcurrentMap<String, Set<String>> getAllSessionIds() {
- return this.userSessionIds;
- }
- /**
- * 获取所有session
- *
- * @return
- */
- public ConcurrentMap<String, WebSocketSession> getAllSessionWebSocketInfos() {
- return this.clientInfoSessionIds;
- }
- /**
- * register session
- *
- * @param user
- * @param sessionId
- */
- private void registerSessionId(String user, String sessionId) {
-
- synchronized (this.lock) {
- Set<String> set = this.userSessionIds.get(user);
- if (set == null) {
- set = new CopyOnWriteArraySet<>();
- this.userSessionIds.put(user, set);
- }
- set.add(sessionId);
- }
- }
-
- /**
- * 保存session
- * @param session WebSocketSession
- */
- public void registerSession(WebSocketSession session,String user){
- if(StringUtils.isEmpty(user)){
- user = parseUserByURI(session).get("user");
- }
- if(!StringUtils.isEmpty(user)){
- String sessionId = session.getId();
- registerSessionId(user,sessionId);
- registerSessionId(session);
- sessionIdUser.putIfAbsent(sessionId,user);
- }
- }
- /**
- * 从session里面解析参数
- * @param session
- * @return
- */
- private Map<String, String> parseUserByURI(WebSocketSession session){
- Map<String, String> map = new HashMap<>();
- String[] params = Optional.ofNullable(session.getHandshakeInfo().getUri().getQuery()).orElse("").split("&");
- for (String param : params) {
- String[] temp = param.split("=");
- if(temp.length == 2){
- map.put(temp[0],temp[1]);
- }
- }
- return map;
- }
- public WebSocketSession getSessionBySessionId(String sessionId){
- return this.clientInfoSessionIds.get(sessionId);
- }
- private void registerSessionId(WebSocketSession websocketInfo) {
- String sessionId = websocketInfo.getId();
- CountDownLatch signal = cacheTimestamp.putIfAbsent(sessionId, new CountDownLatch(1));
- if (signal == null) {
- signal = cacheTimestamp.get(sessionId);
- try {
- if (!clientInfoSessionIds.containsKey(sessionId)) {
- WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
- if (set == null) {
- clientInfoSessionIds.putIfAbsent(sessionId, websocketInfo);
- }
- }
- } finally {
- signal.countDown();
- cacheTimestamp.remove(sessionId);
- }
- } else {
- try {
- signal.await();
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
- }
-
- private void unregisterSessionId(String userName, String sessionId) {
-
- synchronized (this.lock) {
- Set set = this.userSessionIds.get(userName);
- if (set != null && set.remove(sessionId) && set.isEmpty()) {
- this.userSessionIds.remove(userName);
- }
- }
- }
- private void unregisterSessionId(String sessionId) {
-
- synchronized (this.lock) {
- WebSocketSession set = this.clientInfoSessionIds.get(sessionId);
- if (set != null) {
- this.clientInfoSessionIds.remove(sessionId);
- }
- }
- }
-
- public void unregisterSession(WebSocketSession session){
- String sessionId = session.getId();
- String user = sessionIdUser.get(sessionId);
- if(!StringUtils.isEmpty(user)){
- unregisterSessionId(sessionId);
- unregisterSessionId(user,sessionId);
- sessionIdUser.remove(sessionId);
- }
- }
-
- public void checkAndRemove(WebSocketSession session){
- String sessionId = session.getId();
- if(!this.clientInfoSessionIds.containsKey(sessionId)){
- log.info("sessionId:{} 10秒内没有登陆,关掉它",sessionId);
- session.close(SocketCloseStatus.UN_LOGIN.getCloseStatus()).toProcessor();
- }else{
- log.info("sessinId:{}已经登陆,是合法的",sessionId);
- }
- }
- }
userSessionIds 保存每个用户所属的 sessionId 集合，因为同一个用户可能在不同地方登录，会有多个 session。
clientInfoSessionIds 保存 sessionId 到 session 的映射，可以根据 sessionId 取到对应的 session；sessionId 对应的用户则保存在 sessionIdUser 中。
这几周慢慢摸索出来的结果,网上资料很少,官网上的也不是很全,可能有不对的地方,在此做个记录!