最近有不少网友在咨询netty client中,netty的channel链接池应该如何设计。这是个稍微有些复杂的主题,牵扯到蛮多技术点,要想在网上找到相关的又相对完整的参考文章,确实不太容易。java
在本篇文章中,会给出其中一种解决方案,而且附带完整的可运行的代码。若是网友有更好的方案,能够回复本文,咱们一块儿讨论讨论,一块儿开阔思路和眼界。bootstrap
阅读本文以前须要具有一些基础知识api
一、知道netty的一些基础知识,好比ByteBuf类的相关api;
二、知道netty的执行流程;
三、 必须阅读过我以前写的netty实战-自定义解码器处理半包消息,由于本文部分代码来自这篇文章。数组
如今微服务很是的热门,也有不少公司在用。微服务框架中,若是是使用thrift、grpc来做为数据序列化框架的话,一般都会生成一个SDK给客户端用户使用。客户端只要使用这个SDK,就能够方便的调用服务端的微服务接口。本文讨论的就是使用SDK的netty客户端,它的netty channel链接池的设计方案。至于netty http client的channel链接池设计,基于http的,是另一个主题了,须要另外写文章来讨论的。缓存
DB链接池中,当某个线程获取到一个db connection后,在读取数据或者写数据时,若是线程没有操做完,这个db connection一直被该线程独占着,直到线程执行完任务。若是netty client的channel链接池设计也是使用这种独占的方式的话,有几个问题。服务器
一、netty中channel的writeAndFlush方法,调用完后是不用等待返回结果的,writeAndFlush一被调用,立刻返回。对于这种状况,是彻底不必让线程独占一个channel的。
二、使用相似DB pool的方式,从池子里拿链接,用完后返回,这里的一进一出,须要考虑并发锁的问题。另外,若是请求量很大的时候,链接会不够用,其余线程也只能等待其余线程释放链接。并发
所以不太建议使用上面的方式来设计netty channel链接池,channel独占的代价太大了。可使用Channel数组的形式, 复用netty的channel。当线程要须要Channel的时候,随机从数组选中一个Channel,若是Channel还未创建,则建立一个。若是线程选中的Channel已经创建了,则复用这个Channel。框架
假设channel数组的长度为4dom
private Channel[] channels = new Channel[4];复制代码
当外部系统请求client的时候,client从channels数组中随机挑选一个channel,若是该channel还没有创建,则触发创建channel的逻辑。不管有多少请求,都是复用这4个channel。假设有10个线程,那么部分线程可能会使用相同的channel来发送数据和接收数据。由于是随机选择一个channel的,多个线程命中同一个channel的机率仍是很大的。以下图异步
10个线程中,可能有3个线程都是使用channel2来发送数据的。这个会引入另一个问题。thread1经过channel2发送一条消息msg1到服务端,thread2也经过channel2发送一条消息msg2到服务端,当服务端处理完数据,经过channel2返回数据给客户端的时候,如何区分哪条消息是哪一个线程的呢?若是不作区分,万一thread1拿到的结果实际上是thread2要的结果,怎么办?
那么如何作到让thread1和thread2拿到它们本身想要的结果呢?
以前我在netty实战-自定义解码器处理半包消息一文中提到,自定义消息的时候,一般会在消息中加入一个序列号,用来惟一标识消息的。当thread1发送消息时,往消息中插入一个惟一的消息序列号,同时为thread1创建一个callback回调程序,当服务端返回消息的时候,根据消息中的序列号从对应的callback程序获取结果。这样就能够解决上面说到的问题。
消息格式
消息、消息seq以及callback对应关系
OK,下面就基于上面的设计来进行编码。
先来实现netty客户端,设置10个线程并发获取channel,为了达到真正的并发,利用CountDownLatch来作开关,同时channel链接池设置4个channel。
package nettyinaction.nettyclient.channelpool.client;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
import nettyinaction.nettyclient.channelpool.IntegerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
public class SocketClient {
public static void main(String[] args) throws InterruptedException {
//当全部线程都准备后,开闸,让全部线程并发的去获取netty的channel
final CountDownLatch countDownLatchBegin = new CountDownLatch(1);
//当全部线程都执行完任务后,释放主线程,让主线程继续执行下去
final CountDownLatch countDownLatchEnd = new CountDownLatch(10);
//netty channel池
final NettyChannelPool nettyChannelPool = new NettyChannelPool();
final Map<String, String> resultsMap = new HashMap<>();
//使用10个线程,并发的去获取netty channel
for (int i = 0; i < 10; i++) {
new Thread(new Runnable() {
@Override
public void run() {
try {
//先让线程block住
countDownLatchBegin.await();
Channel channel = null;
try {
channel = nettyChannelPool.syncGetChannel();
} catch (InterruptedException e) {
e.printStackTrace();
}
//为每一个线程创建一个callback,当消息返回的时候,在callback中获取结果
CallbackService callbackService = new CallbackService();
//给消息分配一个惟一的消息序列号
int seq = IntegerFactory.getInstance().incrementAndGet();
//利用Channel的attr方法,创建消息与callback的对应关系
ChannelUtils.putCallback2DataMap(channel,seq,callbackService);
synchronized (callbackService) {
UnpooledByteBufAllocator allocator = new UnpooledByteBufAllocator(false);
ByteBuf buffer = allocator.buffer(20);
buffer.writeInt(ChannelUtils.MESSAGE_LENGTH);
buffer.writeInt(seq);
String threadName = Thread.currentThread().getName();
buffer.writeBytes(threadName.getBytes());
buffer.writeBytes("body".getBytes());
//给netty 服务端发送消息,异步的,该方法会马上返回
channel.writeAndFlush(buffer);
//等待返回结果
callbackService.wait();
//解析结果,这个result在callback中已经解析到了。
ByteBuf result = callbackService.result;
int length = result.readInt();
int seqFromServer = result.readInt();
byte[] head = new byte[8];
result.readBytes(head);
String headString = new String(head);
byte[] body = new byte[4];
result.readBytes(body);
String bodyString = new String(body);
resultsMap.put(threadName, seqFromServer + headString + bodyString);
}
} catch (Exception e) {
e.printStackTrace();
}
finally {
countDownLatchEnd.countDown();
}
}
}).start();
}
//开闸,让10个线程并发获取netty channel
countDownLatchBegin.countDown();
//等10个线程执行完后,打印最终结果
countDownLatchEnd.await();
System.out.println("resultMap="+resultsMap);
}
public static class CallbackService{
public volatile ByteBuf result;
public void receiveMessage(ByteBuf receiveBuf) throws Exception {
synchronized (this) {
result = receiveBuf;
this.notify();
}
}
}
}复制代码
其中IntegerFactory类用于生成消息的惟一序列号
package nettyinaction.nettyclient.channelpool;
import java.util.concurrent.atomic.AtomicInteger;
public class IntegerFactory {
private static class SingletonHolder {
private static final AtomicInteger INSTANCE = new AtomicInteger();
}
private IntegerFactory(){}
public static final AtomicInteger getInstance() {
return SingletonHolder.INSTANCE;
}
}复制代码
而ChannelUtils类则用于创建channel、消息序列号和callback程序的对应关系。
package nettyinaction.nettyclient.channelpool;
import io.netty.channel.Channel;
import io.netty.util.AttributeKey;
import java.util.Map;
public class ChannelUtils {
public static final int MESSAGE_LENGTH = 16;
public static final AttributeKey<Map<Integer, Object>> DATA_MAP_ATTRIBUTEKEY = AttributeKey.valueOf("dataMap");
public static <T> void putCallback2DataMap(Channel channel, int seq, T callback) {
channel.attr(DATA_MAP_ATTRIBUTEKEY).get().put(seq, callback);
}
public static <T> T removeCallback(Channel channel, int seq) {
return (T) channel.attr(DATA_MAP_ATTRIBUTEKEY).get().remove(seq);
}
}复制代码
NettyChannelPool则负责建立netty的channel。
package nettyinaction.nettyclient.channelpool.client;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.Attribute;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
import nettyinaction.nettyclient.channelpool.SelfDefineEncodeHandler;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
public class NettyChannelPool {
private Channel[] channels;
private Object [] locks;
private static final int MAX_CHANNEL_COUNT = 4;
public NettyChannelPool() {
this.channels = new Channel[MAX_CHANNEL_COUNT];
this.locks = new Object[MAX_CHANNEL_COUNT];
for (int i = 0; i < MAX_CHANNEL_COUNT; i++) {
this.locks[i] = new Object();
}
}
/** * 同步获取netty channel */
public Channel syncGetChannel() throws InterruptedException {
//产生一个随机数,随机的从数组中获取channel
int index = new Random().nextInt(MAX_CHANNEL_COUNT);
Channel channel = channels[index];
//若是能获取到,直接返回
if (channel != null && channel.isActive()) {
return channel;
}
synchronized (locks[index]) {
channel = channels[index];
//这里必须再次作判断,当锁被释放后,以前等待的线程已经能够直接拿到结果了。
if (channel != null && channel.isActive()) {
return channel;
}
//开始跟服务端交互,获取channel
channel = connectToServer();
channels[index] = channel;
}
return channel;
}
private Channel connectToServer() throws InterruptedException {
EventLoopGroup eventLoopGroup = new NioEventLoopGroup();
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(eventLoopGroup)
.channel(NioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, Boolean.TRUE)
.option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
.handler(new LoggingHandler(LogLevel.INFO))
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new SelfDefineEncodeHandler());
pipeline.addLast(new SocketClientHandler());
}
});
ChannelFuture channelFuture = bootstrap.connect("localhost", 8899);
Channel channel = channelFuture.sync().channel();
//为刚刚建立的channel,初始化channel属性
Attribute<Map<Integer,Object>> attribute = channel.attr(ChannelUtils.DATA_MAP_ATTRIBUTEKEY);
ConcurrentHashMap<Integer, Object> dataMap = new ConcurrentHashMap<>();
attribute.set(dataMap);
return channel;
}
}复制代码
先使用构造方法,初始化channels数组,长度为4。NettyChannelPool类有两个关键的地方。
第一个是获取channel的时候必须加上锁。另一个是当获取到channel后,利用channel的属性,建立一个Map,后面须要利用这个Map创建消息序列号和callback程序的对应关系。
//初始化channel属性
Attribute<Map<Integer,Object>> attribute = channel.attr(ChannelUtils.DATA_MAP_ATTRIBUTEKEY);
ConcurrentHashMap<Integer, Object> dataMap = new ConcurrentHashMap<>();
attribute.set(dataMap);复制代码
这个map就是咱们上面看到的
Map的put的动做,就是在SocketClient类中的
ChannelUtils.putCallback2DataMap(channel,seq,callbackService);复制代码
执行的。客户端处理消息还须要两个hanlder辅助,一个是处理半包问题,一个是接收服务端的返回的消息。
SelfDefineEncodeHandler类用于处理半包消息
package nettyinaction.nettyclient.channelpool;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.util.List;
public class SelfDefineEncodeHandler extends ByteToMessageDecoder {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf bufferIn, List<Object> out) throws Exception {
if (bufferIn.readableBytes() < 4) {
return;
}
int beginIndex = bufferIn.readerIndex();
int length = bufferIn.readInt();
if (bufferIn.readableBytes() < length) {
bufferIn.readerIndex(beginIndex);
return;
}
bufferIn.readerIndex(beginIndex + 4 + length);
ByteBuf otherByteBufRef = bufferIn.slice(beginIndex, 4 + length);
otherByteBufRef.retain();
out.add(otherByteBufRef);
}
}复制代码
SocketClientHandler类用于接收服务端返回的消息,而且根据消息序列号获取对应的callback程序。
package nettyinaction.nettyclient.channelpool.client;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
public class SocketClientHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
Channel channel = ctx.channel();
ByteBuf responseBuf = (ByteBuf)msg;
responseBuf.markReaderIndex();
int length = responseBuf.readInt();
int seq = responseBuf.readInt();
responseBuf.resetReaderIndex();
//获取消息对应的callback
SocketClient.CallbackService callbackService = ChannelUtils.<SocketClient.CallbackService>removeCallback(channel, seq);
callbackService.receiveMessage(responseBuf);
}
}复制代码
到此客户端程序编写完毕。至于服务端的代码,则比较简单,这里直接贴上代码。
package nettyinaction.nettyclient.channelpool.server;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import nettyinaction.nettyclient.channelpool.SelfDefineEncodeHandler;
public class SocketServer {
public static void main(String[] args) throws InterruptedException {
EventLoopGroup parentGroup = new NioEventLoopGroup();
EventLoopGroup childGroup = new NioEventLoopGroup();
try {
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(parentGroup, childGroup)
.channel(NioServerSocketChannel.class)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new SelfDefineEncodeHandler());
pipeline.addLast(new BusinessServerHandler());
}
});
ChannelFuture channelFuture = serverBootstrap.bind(8899).sync();
channelFuture.channel().closeFuture().sync();
}
finally {
parentGroup.shutdownGracefully();
childGroup.shutdownGracefully();
}
}
}
package nettyinaction.nettyclient.channelpool.server;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
public class BusinessServerHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
Channel channel = ctx.channel();
ByteBuf buf = (ByteBuf)msg;
//一、读取消息长度
int length = buf.readInt();
//二、读取消息序列号
int seq = buf.readInt();
//三、读取消息头部
byte[] head = new byte[8];
buf.readBytes(head);
String headString = new String(head);
//四、读取消息体
byte[] body = new byte[4];
buf.readBytes(body);
String bodyString = new String(body);
//五、新创建一个缓存区,写入内容,返回给客户端
UnpooledByteBufAllocator allocator = new UnpooledByteBufAllocator(false);
ByteBuf responseBuf = allocator.buffer(20);
responseBuf.writeInt(ChannelUtils.MESSAGE_LENGTH);
responseBuf.writeInt(seq);
responseBuf.writeBytes(headString.getBytes());
responseBuf.writeBytes(bodyString.getBytes());
//六、将数据写回到客户端
channel.writeAndFlush(responseBuf);
}
}复制代码
运行服务端代码和客户端代码,指望的结果是
10个线程发送消息后,能从服务端获取到正确的对应的返回信息,这些信息不会发生错乱,各个线程都能拿到本身想要的结果,不会发生错读的状况。
运行后的结果以下
一、等待服务端的返回
因为 channel.writeAndFlush是异步的,必须有一种机制来让线程等待服务端返回结果。这里采用最原始的wait和notify方法。当writeAndFlush调用后,马上让当前线程wait住,放置在callbackservice对象的等待列表中,当服务器端返回消息时,客户端的SocketClientHandler类中的channelRead方法会被执行,解析完数据后,从channel的attr属性中获取DATA_MAP_ATTRIBUTEKEY 这个key对应的map。并根据解析出来的seq从map中获取事先放置好的callbackservice对象,执行它的receiveMessage方法。将receiveBuf这个存放结果的缓存区对象赋值到callbackservice的result属性中。并调用callbackservice对象的notify方法,唤醒wait在callbackservice对象的线程,让其继续往下执行。
二、产生消息序列号
int seq = IntegerFactory.getInstance().incrementAndGet();复制代码
为了演示的方便,这里是产生单服务器全局惟一的序列号。若是请求量大的话,就算是AtomicInteger是CAS操做,也会产生不少的竞争。建议产生channel级别的惟一序列号,下降竞争。只要保证在一个channel内的消息的序列号是不重复的便可。
至于其余的一些代码细节,读者能够本身再细看。