diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClient.java index 50e132b3d059e..eda27b622d391 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClient.java @@ -49,6 +49,7 @@ import org.apache.flink.runtime.util.AtomicDisposableReferenceCounter; import com.ibm.disni.verbs.IbvWC; +import com.ibm.disni.verbs.StatefulVerbCall; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -218,6 +219,7 @@ class PartitionReaderClient implements Runnable { private final Map receivedBuffers = new HashMap<>(); // ArrayDeque receivedBuffers = new ArrayDeque<>(); Map inFlight = new HashMap(); + Map>> inFlightVerbs = new HashMap(); // private long workRequestId; public PartitionReaderClient(final ResultPartitionID partitionId, @@ -244,7 +246,7 @@ private int postBuffers(int credit) { receiveBuffer = (NetworkBuffer) inputChannel.requestBuffer(); if (receiveBuffer != null) { long id = clientEndpoint.workRequestId.incrementAndGet(); - RdmaSendReceiveUtil.postReceiveReqWithChannelBuf(clientEndpoint,id, receiveBuffer); + RdmaSendReceiveUtil.postReceiveReqWithChannelBuf(clientEndpoint,id, receiveBuffer,inFlightVerbs); receivedBuffers.put(id,receiveBuffer); } else { LOG.error("Buffer from the channel is null"); @@ -272,7 +274,7 @@ public void run() { try { buf = msg.write(clientEndpoint.getNettyBufferpool()); clientEndpoint.getSendBuffer().put(buf.nioBuffer()); - RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId.incrementAndGet()); + RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId.incrementAndGet(),inFlightVerbs); } catch (Exception ioe) { LOG.error("Failed to serialize partition request"); return; @@ -297,7 +299,7 @@ public void run() { clientEndpoint.getSendBuffer().put(message.nioBuffer()); long workID = clientEndpoint.workRequestId.incrementAndGet(); inFlight.put(workID,message); - RdmaSendReceiveUtil.postSendReq(clientEndpoint, workID); + RdmaSendReceiveUtil.postSendReq(clientEndpoint, workID,inFlightVerbs); } else { // LOG.info("No credit available on channel {}",availableCredit,inputChannel); // wait for the credit to be available, otherwise connection stucks in blocking @@ -309,6 +311,7 @@ public void run() { } IbvWC wc = clientEndpoint.getWcEvents().take(); + inFlightVerbs.remove(wc.getWr_id()).free(); // LOG.info("took client event with wr_id {} on endpoint {}", wc.getWr_id(), clientEndpoint.getEndpointStr()); if (IbvWC.IbvWcOpcode.valueOf(wc.getOpcode()) == IbvWC.IbvWcOpcode.IBV_WC_RECV) { if (wc.getStatus() != IbvWC.IbvWcStatus.IBV_WC_SUCCESS.ordinal()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClientFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClientFactory.java index b9aea1c4413b2..3e12932008297 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClientFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/PartitionRequestClientFactory.java @@ -78,7 +78,7 @@ PartitionRequestClientIf createPartitionRequestClient(ConnectionID connectionId, clientEndpoints.put(connectionId, endpoint); PartitionRequestClient client = new PartitionRequestClient( endpoint, clientHandler, connectionId, this); - rdmaClient.start(connectionId.getAddress()); +// rdmaClient.start(connectionId.getAddress()); clients.putIfAbsent(connectionId, client); LOG.info("creating partition client {} for connection id {}", endpoint.getEndpointStr(), connectionId.toString()); // } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaSendReceiveUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaSendReceiveUtil.java index 7ab7d1d2e76b0..5d3218b670628 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaSendReceiveUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaSendReceiveUtil.java @@ -22,26 +22,30 @@ import com.ibm.disni.verbs.IbvRecvWR; import com.ibm.disni.verbs.IbvSendWR; import com.ibm.disni.verbs.IbvSge; +import com.ibm.disni.verbs.SVCPostRecv; +import com.ibm.disni.verbs.SVCPostSend; +import com.ibm.disni.verbs.StatefulVerbCall; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import sun.nio.ch.DirectBuffer; import java.io.IOException; import java.util.LinkedList; +import java.util.Map; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import org.apache.flink.runtime.io.network.buffer.ReadOnlySlicedNetworkBuffer; -import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; public class RdmaSendReceiveUtil { private static final Logger LOG = LoggerFactory.getLogger(RdmaSendReceiveUtil.class); public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, long workReqId, NettyMessage - .BufferResponse response) throws IOException { + .BufferResponse response,Map>> + inFlightVerbs) throws IOException { if (endpoint instanceof RdmaShuffleServerEndpoint) { RdmaShuffleServerEndpoint clientEndpoint = (RdmaShuffleServerEndpoint) endpoint; @@ -70,7 +74,8 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon } // header is at the end of segment int start = segment.size() - RdmaConnectionManager.DATA_MSG_HEADER_SIZE; -// LOG.info("SRUtil: Header start address {}, end address {} buffer length {} sent magic {} byte order {}", segment.getAddress() + start, +// LOG.info("SRUtil: Header start address {}, end address {} buffer length {} sent magic {} byte order {}", +// segment.getAddress() + start, // segment.getAddress() + segment.size(),buf.writerIndex(),segment.getIntBigEndian(start+4),buf.order()); IbvSge headerSGE = new IbvSge(); @@ -81,7 +86,7 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon // actual data IbvSge dataSGE = new IbvSge(); dataSGE.setAddr(dataAddress); - dataSGE.setLength(dataLen); + dataSGE.setLength(dataLen); // dataSGE.setAddr(segment.getAddress()); // dataSGE.setLength(segment.size()); dataSGE.setLkey(clientEndpoint.getRegisteredMRs().get(segment.getAddress()).getLkey()); @@ -118,11 +123,16 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon LinkedList sendWRs = new LinkedList<>(); sendWRs.add(sendWR); - clientEndpoint.postSend(sendWRs).execute().free(); + SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, sendVerb); + } } } - public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) throws IOException { + public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId, Map>> + inFlightVerbs) throws IOException { if (endpoint instanceof RdmaShuffleServerEndpoint) { RdmaShuffleServerEndpoint clientEndpoint = (RdmaShuffleServerEndpoint) endpoint; @@ -148,7 +158,10 @@ public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) thro LinkedList sendWRs = new LinkedList<>(); sendWRs.add(sendWR); - clientEndpoint.postSend(sendWRs).execute().free(); + SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, sendVerb); + } } else if (endpoint instanceof RdmaShuffleClientEndpoint) { RdmaShuffleClientEndpoint clientEndpoint = (RdmaShuffleClientEndpoint) endpoint; // LOG.info("posting client send wr_id " + workReqId+ " against src: " + endpoint.getSrcAddr() + " dest: " @@ -171,12 +184,16 @@ public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) thro sendWR.setSend_flags(IbvSendWR.IBV_SEND_SIGNALED); LinkedList sendWRs = new LinkedList<>(); - sendWRs.add(sendWR); - clientEndpoint.postSend(sendWRs).execute().free(); + SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, sendVerb); + } } } - public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) throws IOException { + public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId, Map>> + inFlightVerbs) throws IOException { if (endpoint instanceof RdmaShuffleServerEndpoint) { // LOG.info("posting server receive wr_id " + workReqId + " against src: " + endpoint.getSrcAddr() + " dest: @@ -195,7 +212,10 @@ public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) t LinkedList recvWRs = new LinkedList<>(); recvWRs.add(recvWR); - endpoint.postRecv(recvWRs).execute().free(); + SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, recvVerb); + } } else if (endpoint instanceof RdmaShuffleClientEndpoint) { // LOG.info("posting client receive wr_id " + workReqId + " against src: " + endpoint.getSrcAddr() + " dest: // " +endpoint.getDstAddr()); @@ -213,11 +233,16 @@ public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) t LinkedList recvWRs = new LinkedList<>(); recvWRs.add(recvWR); - endpoint.postRecv(recvWRs).execute().free(); + SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, recvVerb); + } } } - public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, long workReqId, ByteBuf buffer) + public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, long workReqId, ByteBuf buffer, + Map>> + inFlightVerbs) throws IOException { if (endpoint instanceof RdmaShuffleClientEndpoint) { @@ -237,7 +262,10 @@ public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, lon LinkedList recvWRs = new LinkedList<>(); recvWRs.add(recvWR); - endpoint.postRecv(recvWRs).execute().free(); + SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute(); + synchronized (inFlightVerbs) { + inFlightVerbs.put(workReqId, recvVerb); + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaServerRequestHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaServerRequestHandler.java index 5c78270ffa0ed..42764187d6d71 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaServerRequestHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/rdma/RdmaServerRequestHandler.java @@ -4,6 +4,7 @@ import com.ibm.disni.RdmaServerEndpoint; import com.ibm.disni.verbs.IbvMr; import com.ibm.disni.verbs.IbvWC; +import com.ibm.disni.verbs.StatefulVerbCall; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,7 +53,6 @@ public void run() { while (!stopped) { try { RdmaShuffleServerEndpoint clientEndpoint = serverEndpoint.accept(); - RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, 0); clientEndpoint.setRegisteredMRs(registerdMRs); // TODO (venkat): Handle accepted connection, not using thread pool as it is only proto-type with 4 // servers @@ -95,11 +95,11 @@ private NettyMessage readPartition(NetworkSequenceViewReader reader) throws } else { // This channel was now removed from the available reader queue. // We re-add it into the queue if it is still available - if (next.moreAvailable()) { - reader.setRegisteredAsAvailable(true); - } else { - reader.setRegisteredAsAvailable(false); - } + if (next.moreAvailable()) { + reader.setRegisteredAsAvailable(true); + } else { + reader.setRegisteredAsAvailable(false); + } NettyMessage.BufferResponse msg = new NettyMessage.BufferResponse( next.buffer(), reader.getSequenceNumber(), @@ -111,9 +111,12 @@ private NettyMessage readPartition(NetworkSequenceViewReader reader) throws private class HandleClientConnection implements Runnable { RdmaShuffleServerEndpoint clientEndpoint; + Map>> inFlightVerbs = new HashMap(); - HandleClientConnection(RdmaShuffleServerEndpoint clientEndpoint) { + + HandleClientConnection(RdmaShuffleServerEndpoint clientEndpoint) throws IOException { this.clientEndpoint = clientEndpoint; + RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, 0, inFlightVerbs); } private final ConcurrentHashMap inFlight = new ConcurrentHashMap<>(); @@ -136,6 +139,9 @@ public void run() { // LOG.info("Server: Did not get expected wr_id {} on endpoint {}", wc.getWr_id(), clientEndpoint // .getEndpointStr()); // } + synchronized (inFlightVerbs){ + inFlightVerbs.remove(wc.getWr_id()).free(); + } if (IbvWC.IbvWcOpcode.valueOf(wc.getOpcode()) == IbvWC.IbvWcOpcode.IBV_WC_RECV) { if (wc.getStatus() != IbvWC.IbvWcStatus.IBV_WC_SUCCESS.ordinal()) { LOG.error("Receive posting failed. reposting new receive request"); @@ -162,15 +168,15 @@ public void run() { partitionRequest.queueIndex); // we need to post receive for next message. for example credit // TODO: We should do executor service here - new Thread(new RDMAWriter(reader, clientEndpoint, inFlight)).start(); + new Thread(new RDMAWriter(reader, clientEndpoint, inFlight,inFlightVerbs)).start(); RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); + .incrementAndGet(), inFlightVerbs); // TODO(venkat): do something better here, we should not poll reader } else if (msgClazz == NettyMessage.TaskEventRequest.class) { NettyMessage.TaskEventRequest request = (NettyMessage.TaskEventRequest) clientRequest; LOG.error("Unhandled request type TaskEventRequest"); RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); // post next + .incrementAndGet(), inFlightVerbs); // post next // receive // TODO (venkat): Handle it if (!taskEventDispatcher.publish(request.partitionId, request.event)) { @@ -182,7 +188,7 @@ public void run() { clientRequest; LOG.error("Unhandled request type CancelPartitionRequest"); RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); // post next + .incrementAndGet(), inFlightVerbs); // post next // receive // TODO (venkat): Handle it // outboundQueue.cancel(request.receiverId); @@ -194,13 +200,13 @@ public void run() { reader.addCredit(request.credit); } RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); // post next + .incrementAndGet(), inFlightVerbs); // post next // receive // TODO (venkat): Handle it // outboundQueue.addCredit(request.receiverId, request.credit); } else { RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); // post next + .incrementAndGet(), inFlightVerbs); // post next // receive LOG.warn("Received unexpected client request: {}", clientRequest); } @@ -245,12 +251,16 @@ private class RDMAWriter implements Runnable { private NetworkSequenceViewReader reader; private RdmaShuffleServerEndpoint clientEndpoint; private ConcurrentHashMap inFlight; + private Map>> inFlightVerbs; + public RDMAWriter(NetworkSequenceViewReader reader, RdmaShuffleServerEndpoint clientEndpoint, - ConcurrentHashMap inFlight) { + ConcurrentHashMap inFlight, Map>> inFlightVerbs) { this.reader = reader; this.clientEndpoint = clientEndpoint; this.inFlight = inFlight; + this.inFlightVerbs = inFlightVerbs; } @Override @@ -296,12 +306,12 @@ public void run() { // (NettyMessage.BufferResponse) response).getBuffer().memoryAddress()); inFlight.put(workRequestId, (NettyMessage.BufferResponse) response); RdmaSendReceiveUtil.postSendReqForBufferResponse(clientEndpoint, workRequestId, - (NettyMessage.BufferResponse) response); + (NettyMessage.BufferResponse) response,inFlightVerbs); } } else { clientEndpoint.getSendBuffer().put(response.write(bufferPool).nioBuffer()); RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId - .incrementAndGet()); + .incrementAndGet(),inFlightVerbs); } } else { synchronized (reader) {