Netty实现WebSocket服务端及客户端(验签操作)

/ 后端 / 没有评论 / 533浏览

一.URI带参数方式:

1、问题:

Netty对WebSocket提供了很好的支持,在pipeline里添加一个WebSocketServerProtocolHandler就可以方便的暴露一个ws接口出去。但是,开发中却遇到一点小问题,需要在ws的URI带上参数queryString(如:/im/ws?w=221100234&t=99),然而这样会导致ws连接无法建立,浏览器报错:Connection closed before receiving a handshake response。

2、解决:

因为安全问题,ws是不推荐通过queryString携带信息的,所以Netty里WebSocketServerProtocolHandler默认的构造器要求连接建立时传入的URI与程序指定的路径完全匹配。checkStartsWith参数控制是否改为根据startsWith来做URI匹配,所以改用带checkStartsWith参数的重载构造器并传入true即可解决问题(注意queryString上不要携带敏感信息)。checkStartsWith使用的地方如下图中WebSocketServerProtocolHandshakeHandler类:

示例代码

netty服务端(简单鉴权):

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;

import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;

/**
 * Minimal Netty WebSocket server that accepts upgrade requests on {@code /hello}
 * (optionally followed by a query string) and performs a simple token check once
 * the WebSocket handshake has completed.
 *
 * <p>The token is read from the {@code token} query parameter, falling back to the
 * {@code token} HTTP header, and is stored as a channel attribute by
 * {@link MyHttpHandler}. {@link MyWebSocketHandler} validates it when the
 * handshake-complete event is fired.
 */
public class NettyTest {

    /** Channel attribute holding the token presented in the HTTP upgrade request. */
    private static final AttributeKey<String> TOKEN_KEY = AttributeKey.valueOf("token");

    /** The only token value this demo accepts. */
    private static final String VALID_TOKEN = "1";

    /**
     * Close status sent when the token is rejected. 4000-4999 is the range
     * RFC 6455 reserves for private use, so clients can tell this close apart
     * from a protocol-level one.
     */
    private static final int INVALID_TOKEN_CLOSE_CODE = 4001;

    /**
     * Starts the server on port 8888 and blocks until the server channel is closed.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup worker = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();

                            pipeline.addLast(new HttpServerCodec());
                            pipeline.addLast(new ChunkedWriteHandler());
                            pipeline.addLast(new HttpObjectAggregator(8192));

                            // Captures the token before the upgrade request reaches the handshake handler.
                            pipeline.addLast(new MyHttpHandler());
                            pipeline.addLast(new IdleStateHandler(10, 0, 0));
                            // checkStartsWith = true: the request URI only has to start with "/hello",
                            // so "/hello?token=1" is accepted. With false an exact match is required
                            // and any URI carrying a query string fails the handshake.
                            pipeline.addLast(new WebSocketServerProtocolHandler("/hello", true));
                            // Application handler for text frames and pipeline events.
                            pipeline.addLast(new MyWebSocketHandler());
                        }
                    });

            ChannelFuture bound = bootstrap.bind(8888).sync();
            bound.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can observe it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }

    /**
     * Extracts the token from the HTTP upgrade request and stores it on the channel
     * so that it can be verified after the WebSocket handshake.
     */
    static class MyHttpHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
        @Override
        protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
            String token = firstQueryValue(msg.uri(), "token");
            if (token == null) {
                token = msg.headers().get("token");
            }

            // A missing token is left unset here; the post-handshake check treats
            // an absent attribute as an invalid token instead of throwing.
            if (token != null) {
                ctx.channel().attr(TOKEN_KEY).setIfAbsent(token);
            }

            // SimpleChannelInboundHandler releases msg after this method returns,
            // so retain it for the handlers further down the pipeline.
            ctx.fireChannelRead(msg.retain());
        }

        /**
         * Returns the first value of the given query parameter.
         *
         * @param uri  request URI, possibly containing a query string
         * @param name query parameter name
         * @return the first value, or {@code null} if the parameter is absent
         */
        private static String firstQueryValue(String uri, String name) {
            List<String> values = new QueryStringDecoder(uri).parameters().get(name);
            return (values == null || values.isEmpty()) ? null : values.get(0);
        }
    }

    /**
     * Echoes text frames back with the server time, verifies the token after the
     * handshake, and closes connections that hit the read-idle timeout.
     */
    static class MyWebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
        @Override
        protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
            System.out.println("服务器收到消息:" + msg.text());

            String respMsg = "服务器时间:" + LocalDateTime.now() + ";" + msg.text();

            ctx.writeAndFlush(new TextWebSocketFrame(respMsg));
        }

        /** Invoked when this handler is added to a channel's pipeline. */
        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            // asLongText is globally unique, asShortText is not.
            System.out.println("handlerAdded asLongText" + ctx.channel().id().asLongText());
            System.out.println("handlerAdded asShortText" + ctx.channel().id().asShortText());
        }

        /** Invoked when this handler is removed from a channel's pipeline. */
        @Override
        public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
            // asLongText is globally unique, asShortText is not.
            System.out.println("handlerRemoved asLongText" + ctx.channel().id().asLongText());
            System.out.println("handlerRemoved asShortText" + ctx.channel().id().asShortText());
        }

        /** Logs the failure and closes the connection. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            System.out.println("异常发生" + cause.getMessage());
            ctx.close();
        }

        /**
         * Handles the handshake-complete event (token verification) and idle-state
         * events (heartbeat monitoring).
         */
        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object obj) throws Exception {
            if (obj instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
                // Handshake finished: verify the token captured from the upgrade request.
                String token = ctx.channel().attr(TOKEN_KEY).get();
                // Constant-first equals: a missing token (null) is rejected rather than throwing.
                if (!VALID_TOKEN.equals(token)) {
                    // Close with a private-use status code and reason so the client can tell
                    // that the connection was closed because of the token.
                    ctx.writeAndFlush(new CloseWebSocketFrame(INVALID_TOKEN_CLOSE_CODE, "invalid token"))
                            .addListener(ChannelFutureListener.CLOSE);
                }
            } else if (obj instanceof IdleStateEvent) {
                IdleStateEvent event = (IdleStateEvent) obj;
                if (event.state() == IdleState.READER_IDLE) {
                    System.out.println("客户端读超时");
                    ctx.writeAndFlush(new TextWebSocketFrame("读超时")).addListener(ChannelFutureListener.CLOSE);
                } else if (event.state() == IdleState.WRITER_IDLE) {
                    System.out.println("客户端写超时");
                } else if (event.state() == IdleState.ALL_IDLE) {
                    System.out.println("客户端所有操作超时");
                }
            }
        }

    }
}

netty客户端:

package websocket;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler;
import io.netty.util.CharsetUtil;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;

/**
 * Console WebSocket client: connects to the demo server, performs the handshake,
 * then sends each line typed on stdin as a text frame.
 *
 * <p>Typing {@code 再见} sends a close frame and waits for the channel to close;
 * typing {@code ping} sends a ping frame.
 */
public class NettyWebsocketClientTest {

    public static void main(String[] args) throws InterruptedException {
        try {
            start();
        } catch (URISyntaxException e) {
            e.printStackTrace();
        }
    }

    /**
     * Connects to the server named by the WebSocket URI and runs the interactive loop
     * until stdin ends, the user says goodbye, or the channel is closed.
     *
     * @throws URISyntaxException if the hard-coded server URI is malformed
     */
    public static void start() throws URISyntaxException {
        // The port is part of the URI so that the Host header sent during the
        // handshake and the address actually connected to cannot drift apart.
        URI uri = new URI("ws://127.0.0.1:8888/hello?token=2");
        DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders();
        httpHeaders.add("token", "2");

        TestSocketClientHandler handler =
                new TestSocketClientHandler(
                        WebSocketClientHandshakerFactory.newHandshaker(
                                uri, WebSocketVersion.V13, null, true, httpHeaders));
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap b = new Bootstrap();
            b.group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new HttpClientCodec());
                            p.addLast(new HttpObjectAggregator(8192));
                            p.addLast(WebSocketClientCompressionHandler.INSTANCE);
                            p.addLast(handler);
                        }
                    });

            Channel ch = b.connect(uri.getHost(), uri.getPort()).sync().channel();
            handler.handshakeFuture().sync();

            BufferedReader console = new BufferedReader(new InputStreamReader(System.in));
            while (ch.isOpen()) {
                System.out.println();
                String msg = console.readLine();
                if (msg == null) {
                    break;
                } else if ("再见".equalsIgnoreCase(msg)) {
                    ch.writeAndFlush(new CloseWebSocketFrame());
                    ch.closeFuture().sync();
                    break;
                } else if ("ping".equalsIgnoreCase(msg)) {
                    WebSocketFrame frame = new PingWebSocketFrame(Unpooled.wrappedBuffer(new byte[]{8, 1, 8, 1}));
                    ch.writeAndFlush(frame);
                } else {
                    WebSocketFrame frame = new TextWebSocketFrame(msg);
                    ch.writeAndFlush(frame);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of silently swallowing it.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            group.shutdownGracefully();
        }
    }


    /**
     * Drives the client-side handshake and prints incoming frames.
     * The handshake outcome is exposed through {@link #handshakeFuture()}.
     */
    public static class TestSocketClientHandler extends SimpleChannelInboundHandler<Object> {

        private final WebSocketClientHandshaker handshaker;
        private ChannelPromise handshakeFuture;

        public TestSocketClientHandler(WebSocketClientHandshaker handshaker) {
            this.handshaker = handshaker;
        }

        /**
         * @return future that completes when the handshake succeeds or fails;
         *         only available after the handler has been added to a pipeline
         */
        public ChannelFuture handshakeFuture() {
            return handshakeFuture;
        }

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) {
            handshakeFuture = ctx.newPromise();
        }

        @Override
        public void channelActive(ChannelHandlerContext ctx) {
            handshaker.handshake(ctx.channel());
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) {
            System.out.println("channelInactive!");
            // If the peer closes before answering the handshake, fail the promise;
            // otherwise a caller blocked in handshakeFuture().sync() would wait forever.
            if (!handshakeFuture.isDone()) {
                handshakeFuture.tryFailure(
                        new IllegalStateException("connection closed before the websocket handshake completed"));
            }
        }

        @Override
        public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
            Channel ch = ctx.channel();
            if (!handshaker.isHandshakeComplete()) {
                // The first inbound message is expected to be the HTTP handshake response.
                try {
                    handshaker.finishHandshake(ch, (FullHttpResponse) msg);
                    System.out.println("websocket Handshake 完成!");
                    handshakeFuture.setSuccess();
                } catch (WebSocketHandshakeException e) {
                    System.out.println("websocket连接失败!");
                    handshakeFuture.setFailure(e);
                }
                return;
            }

            if (msg instanceof FullHttpResponse) {
                FullHttpResponse response = (FullHttpResponse) msg;
                throw new IllegalStateException(
                        "Unexpected FullHttpResponse (getStatus=" + response.status() +
                                ", content=" + response.content().toString(CharsetUtil.UTF_8) + ')');
            }

            WebSocketFrame frame = (WebSocketFrame) msg;
            if (frame instanceof TextWebSocketFrame) {
                TextWebSocketFrame textFrame = (TextWebSocketFrame) frame;
                System.out.println("接收到TXT消息: " + textFrame.text());
            } else if (frame instanceof PongWebSocketFrame) {
                System.out.println("接收到pong消息");
            } else if (frame instanceof CloseWebSocketFrame) {
                System.out.println("接收到closing消息");
                ch.close();
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            // Report the cause, fail a still-pending handshake, then drop the connection.
            System.out.println("出现异常" + ": " + cause);
            if (!handshakeFuture.isDone()) {
                handshakeFuture.setFailure(cause);
            }
            ctx.close();
        }
    }
}

html:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>hello</title>
    <!-- The script lives inside <head>: a <script> element placed between </head> and <body> is invalid markup. -->
    <script>
        // Shared connection; stays undefined when the browser has no WebSocket support.
        var socket;
        if (window.WebSocket) {
            socket = new WebSocket("ws://localhost:8888/hello?token=1");
            // Append every message received from the server to the response box.
            socket.onmessage = function (p1) {
                var rt = document.getElementById('responseTxt');
                rt.value = rt.value + "\n" + p1.data + "\n";
            };
            // Connection established.
            socket.onopen = function (p1) {
                var rt = document.getElementById('responseTxt');
                rt.value = "链接开启\n";
            };

            // Connection closed.
            socket.onclose = function (p1) {
                var rt = document.getElementById('responseTxt');
                rt.value = rt.value + "链接关闭了\n";
            };
        } else {
            alert("当前浏览器不支持websocket");
        }

        // Send the content of the message box and clear it; warn if the socket is not open.
        function send() {
            var msg = document.getElementById("msg").value;

            if (!window.socket) {
                return;
            }
            // Strict comparison: readyState and WebSocket.OPEN are both numbers.
            if (socket.readyState === WebSocket.OPEN) {
                socket.send(msg);
                document.getElementById("msg").value = '';
            } else {
                alert("链接没有打开");
            }
        }
    </script>
</head>
<body>
<form onsubmit="return false">
    <textarea id="msg" name="message" style="height: 300px;width: 300px"></textarea>
    <input type="button" value="发送消息" onclick="send()">

    <textarea id="responseTxt" style="height: 300px;width: 300px"></textarea>
    <input type="button" value="清空"
           onclick="document.getElementById('responseTxt').value = ''">
</form>

</body>

</html>

二.使用websocket私有协议方式 (Sec-WebSocket-Protocol)

三.使用nginx作为反向代理配置

# Choose the Connection header forwarded upstream based on the client's Upgrade header:
# "upgrade" by default, "close" when the client sent no Upgrade header.
map $http_upgrade $connection_upgrade {
    default upgrade;
    ''      close;
}

server {
    listen 80;

    # Forward /hello to the backend listening on 127.0.0.1:8888.
    location /hello {
        proxy_pass         http://127.0.0.1:8888;
        proxy_read_timeout 300s;
        proxy_send_timeout 300s;

        # Pass the original host and client address information to the backend.
        proxy_set_header Host            $host;
        proxy_set_header X-Real-IP       $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

        # Use HTTP/1.1 and forward the Upgrade/Connection headers to the backend.
        proxy_http_version 1.1;
        proxy_set_header Upgrade    $http_upgrade;
        proxy_set_header Connection $connection_upgrade;
    }
}