如何寫一個RPC框架(四):網絡通信之客戶端篇

在後續一段時間裏, 我會寫一系列文章來講述如何實現一個RPC框架(我已經實現了一個示例框架, 代碼在我的github上)。 這是系列第四篇文章, 主要講述了客戶端和服務器之間的網絡通信問題。

模型定義

我們需要自己來定義RPC通信所傳遞的內容的模型, 也就是RPCRequest和RPCResponse。

@Data
@Builder
public class RPCRequest {
    private String requestId;
    private String interfaceName;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;
}

@Data
public class RPCResponse {
    private String requestId;
    private Exception exception;
    private Object result;

    public boolean hasException() {
        return exception != null;
    }
}

這裏唯一需要說明一下的是requestId, 你可能會疑惑爲什麼我們需要這個東西。

原因是,發送請求的順序和收到返回的順序可能是不一致的, 因此我們需要有一個標識符來表明某一個返回所對應的請求是什麼。 具體怎麼利用這個字段, 本文後續會揭曉。

選擇NIO還是IO?

NIO和IO的選擇要視具體情況而定。對於我們的RPC框架來說, 一個服務可能與多個服務保持連接, 且每次通信只發送少量信息,那麼在這種情況下,NIO可能更適合一些。

我選擇使用Netty來簡化具體的實現, 自然地,我們就引入了Channel, Handler這些相關的概念。如果對Netty沒有任何瞭解, 建議先去簡單瞭解下相關內容再回過頭看這篇文章。

如何複用Channel

既然使用了NIO, 我們自然希望服務和服務之間是使用長連接進行通信, 而不是每個請求都重新創建一個channel。

那麼我們怎麼去複用channel呢? 既然我們已經通過前文的服務發現獲取到了service地址,並且與其建立了channel, 那麼我們自然就可以建立一個service地址與channel之間的映射關係, 每次拿到地址之後先判斷有沒有對應channel, 如果有的話就複用。這種映射關係我建立了ChannelManager去管理:

public class ChannelManager {
    /**
     * Singleton
     */
    private static ChannelManager channelManager;

    private ChannelManager(){}

    public static ChannelManager getInstance() {
        if (channelManager == null) {
            synchronized (ChannelManager.class) {
                if (channelManager == null) {
                    channelManager = new ChannelManager();
                }
            }
        }
        return channelManager;
    }

    // Service地址與channel之間的映射
    private Map<InetSocketAddress, Channel> channels = new ConcurrentHashMap<>();

    public Channel getChannel(InetSocketAddress inetSocketAddress) {
        Channel channel = channels.get(inetSocketAddress);
        if (null == channel) {
            EventLoopGroup group = new NioEventLoopGroup();
            try {
                Bootstrap bootstrap = new Bootstrap();
                bootstrap.group(group)
                        .channel(NioSocketChannel.class)
                        .handler(new RPCChannelInitializer())
                        .option(ChannelOption.SO_KEEPALIVE, true);

                channel = bootstrap.connect(inetSocketAddress.getHostName(), inetSocketAddress.getPort()).sync()
                        .channel();
                registerChannel(inetSocketAddress, channel);

                channel.closeFuture().addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        removeChannel(inetSocketAddress);
                    }
                });
            } catch (Exception e) {
                log.warn("Fail to get channel for address: {}", inetSocketAddress);
            }
        }
        return channel;
    }

    private void registerChannel(InetSocketAddress inetSocketAddress, Channel channel) {
        channels.put(inetSocketAddress, channel);
    }

    private void removeChannel(InetSocketAddress inetSocketAddress) {
        channels.remove(inetSocketAddress);
    }

}

有幾個地方需要解釋一下:

  1. 這裏用單例的目的是, 所有的proxybean都使用同一個ChannelManager。
  2. 創建Channel的過程很簡單,就是最普通的Netty客戶端創建channel的方法。
  3. 在channel被關閉(比如服務器端宕機了)後,需要從map中刪除對應的channel
  4. RPCChannelInitializer是整個過程的核心所在, 用於處理請求和返回的編解碼、 收到返回之後的回調等。 下文詳細說這個。

編解碼

上文的RPCChannelInitializer代碼如下:

private class RPCChannelInitializer extends ChannelInitializer<SocketChannel> {

        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            ChannelPipeline pipeline = ch.pipeline();
            pipeline.addLast(new RPCEncoder(RPCRequest.class, new ProtobufSerializer()));
            pipeline.addLast(new RPCDecoder(RPCResponse.class, new ProtobufSerializer()));
            pipeline.addLast(new RPCResponseHandler());  //先不用管這個
        }
    }

這裏的Encoder和Decoder都很簡單, 繼承了Netty中的codec,做一些簡單的byte數組和Object對象之間的轉換工作:

@AllArgsConstructor
public class RPCDecoder extends ByteToMessageDecoder {

    private Class<?> genericClass;
    private Serializer serializer;

    @Override
    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (in.readableBytes() < 4) {
            return;
        }
        in.markReaderIndex();
        int dataLength = in.readInt();
        if (in.readableBytes() < dataLength) {
            in.resetReaderIndex();
            return;
        }
        byte[] data = new byte[dataLength];
        in.readBytes(data);
        out.add(serializer.deserialize(data, genericClass));
    }
}

@AllArgsConstructor
public class RPCEncoder extends MessageToByteEncoder {

    private Class<?> genericClass;
    private Serializer serializer;

    @Override
    public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception {
        if (genericClass.isInstance(in)) {
            byte[] data = serializer.serialize(in);
            out.writeInt(data.length);
            out.writeBytes(data);
        }
    }
}

這裏我選擇使用Protobuf序列化協議來做這件事(具體的ProtobufSerializer的實現因爲篇幅原因就不貼在這裏了, 需要的話請看項目的github)。 總的來說, 這一塊還是很簡單很好理解的。

發送請求與處理返回內容

請求的發送很簡單, 直接用channel.writeAndFlush(request) 就行了。

問題是, 發送之後, 怎麼獲取這個請求的返回呢?這裏,我引入了RPCResponseFuture和ResponseFutureManager來解決這個問題。

RPCResponseFuture實現了Future接口,所包含的值就是RPCResponse, 每個RPCResponseFuture都與一個requestId相關聯, 除此之外, 還利用了CountDownLatch來做get方法的阻塞處理:

public class RPCResponseFuture implements Future<Object> {
    private String requestId;

    private RPCResponse response;

    CountDownLatch latch = new CountDownLatch(1);

    public RPCResponseFuture(String requestId) {
        this.requestId = requestId;
    }

    public void done(RPCResponse response) {
        this.response = response;
        latch.countDown();
    }

    @Override
    public RPCResponse get() throws InterruptedException, ExecutionException {
        try {
            latch.await();
        } catch (InterruptedException e) {
            log.error(e.getMessage());
        }
        return response;
    }

  // ....
}

既然每個請求都會產生一個ResponseFuture, 那麼自然要有一個Manager來管理這些future:

public class ResponseFutureManager {
    /**
     * Singleton
     */
    private static ResponseFutureManager rpcFutureManager;

    private ResponseFutureManager(){}

    public static ResponseFutureManager getInstance() {
        if (rpcFutureManager == null) {
            synchronized (ChannelManager.class) {
                if (rpcFutureManager == null) {
                    rpcFutureManager = new ResponseFutureManager();
                }
            }
        }
        return rpcFutureManager;
    }

    private ConcurrentHashMap<String, RPCResponseFuture> rpcFutureMap = new ConcurrentHashMap<>();

    public void registerFuture(RPCResponseFuture rpcResponseFuture) {
        rpcFutureMap.put(rpcResponseFuture.getRequestId(), rpcResponseFuture);
    }

    public void futureDone(RPCResponse response) {
        rpcFutureMap.remove(response.getRequestId()).done(response);
    }
}

ResponseFutureManager很好看懂, 就是提供了註冊future、完成future的接口。

現在我們再回過頭看RPCChannelInitializer中的RPCResponseHandler就很好理解了: 拿到返回值, 把對應的ResponseFuture標記成done就可以了!

/**
* 處理收到返回後的回調
*/
    private class RPCResponseHandler extends SimpleChannelInboundHandler<RPCResponse> {

        @Override
        public void channelRead0(ChannelHandlerContext ctx, RPCResponse response) throws Exception {
            log.debug("Get response: {}", response);
            ResponseFutureManager.getInstance().futureDone(response);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            log.warn("RPC request exception: {}", cause);
        }
    }

前文的FactoryBean的邏輯填充

到這裏,我們已經實現了客戶端的網絡通信, 現在只需要把它加到前文的FactoryBean的doInvoke方法就好了!

    private Object doInvoke(Object proxy, Method method, Object[] args) throws Throwable {
        String targetServiceName = type.getName();

        // Create request
        RPCRequest request = RPCRequest.builder()
                .requestId(generateRequestId(targetServiceName))
                .interfaceName(method.getDeclaringClass().getName())
                .methodName(method.getName())
                .parameters(args)
                .parameterTypes(method.getParameterTypes()).build();

        // Get service address
        InetSocketAddress serviceAddress = getServiceAddress(targetServiceName);

        // Get channel by service address
        Channel channel = ChannelManager.getInstance().getChannel(serviceAddress);
        if (null == channel) {
            throw new RuntimeException("Cann't get channel for address" + serviceAddress);
        }

        // Send request
        RPCResponse response = sendRequest(channel, request);
        if (response == null) {
            throw new RuntimeException("response is null");
        }
        if (response.hasException()) {
            throw response.getException();
        } else {
            return response.getResult();
        }
    }

    private String generateRequestId(String targetServiceName) {
        return targetServiceName + "-" + UUID.randomUUID().toString();
    }

    private InetSocketAddress getServiceAddress(String targetServiceName) {
        String serviceAddress = "";
        if (serviceDiscovery != null) {
            serviceAddress = serviceDiscovery.discover(targetServiceName);
            log.debug("Get address: {} for service: {}", serviceAddress, targetServiceName);
        }
        if (StringUtils.isEmpty(serviceAddress)) {
            throw new RuntimeException("server address is empty");
        }
        String[] array = StringUtils.split(serviceAddress, ":");
        String host = array[0];
        int port = Integer.parseInt(array[1]);
        return new InetSocketAddress(host, port);
    }

    private RPCResponse sendRequest(Channel channel, RPCRequest request) {
        log.debug("Send request, channel: {}, request: {}", channel, request);
        CountDownLatch latch = new CountDownLatch(1);
        RPCResponseFuture rpcResponseFuture = new RPCResponseFuture(request.getRequestId());
        ResponseFutureManager.getInstance().registerFuture(rpcResponseFuture);
        channel.writeAndFlush(request).addListener((ChannelFutureListener) future -> {
            log.debug("Request sent.");
            latch.countDown();
        });
        try {
            latch.await();
        } catch (InterruptedException e) {
            log.error(e.getMessage());
        }

        try {
            return rpcResponseFuture.get(1, TimeUnit.SECONDS);
        } catch (Exception e) {
            log.warn("Exception:", e);
            return null;
        }
    }

就這樣, 一個簡單的RPC客戶端就實現了。 完整代碼請看我的github

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章