Skip to content

Commit

Permalink
Free should not be called untill the WR is complete
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatsc committed Nov 1, 2019
1 parent 9a22e4b commit 4d7516c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.flink.runtime.util.AtomicDisposableReferenceCounter;

import com.ibm.disni.verbs.IbvWC;
import com.ibm.disni.verbs.StatefulVerbCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -218,6 +219,7 @@ class PartitionReaderClient implements Runnable {
private final Map<Long,ByteBuf> receivedBuffers = new HashMap<>();
// ArrayDeque<ByteBuf> receivedBuffers = new ArrayDeque<>();
Map<Long,ByteBuf> inFlight = new HashMap<Long,ByteBuf>();
Map<Long,StatefulVerbCall<? extends StatefulVerbCall<?>>> inFlightVerbs = new HashMap();
// private long workRequestId;

public PartitionReaderClient(final ResultPartitionID partitionId,
Expand All @@ -244,7 +246,7 @@ private int postBuffers(int credit) {
receiveBuffer = (NetworkBuffer) inputChannel.requestBuffer();
if (receiveBuffer != null) {
long id = clientEndpoint.workRequestId.incrementAndGet();
RdmaSendReceiveUtil.postReceiveReqWithChannelBuf(clientEndpoint,id, receiveBuffer);
RdmaSendReceiveUtil.postReceiveReqWithChannelBuf(clientEndpoint,id, receiveBuffer,inFlightVerbs);
receivedBuffers.put(id,receiveBuffer);
} else {
LOG.error("Buffer from the channel is null");
Expand Down Expand Up @@ -272,7 +274,7 @@ public void run() {
try {
buf = msg.write(clientEndpoint.getNettyBufferpool());
clientEndpoint.getSendBuffer().put(buf.nioBuffer());
RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId.incrementAndGet());
RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId.incrementAndGet(),inFlightVerbs);
} catch (Exception ioe) {
LOG.error("Failed to serialize partition request");
return;
Expand All @@ -297,7 +299,7 @@ public void run() {
clientEndpoint.getSendBuffer().put(message.nioBuffer());
long workID = clientEndpoint.workRequestId.incrementAndGet();
inFlight.put(workID,message);
RdmaSendReceiveUtil.postSendReq(clientEndpoint, workID);
RdmaSendReceiveUtil.postSendReq(clientEndpoint, workID,inFlightVerbs);
} else {
// LOG.info("No credit available on channel {}",availableCredit,inputChannel);
// wait for the credit to be available, otherwise connection stucks in blocking
Expand All @@ -309,6 +311,7 @@ public void run() {
}

IbvWC wc = clientEndpoint.getWcEvents().take();
inFlightVerbs.remove(wc.getWr_id()).free();
// LOG.info("took client event with wr_id {} on endpoint {}", wc.getWr_id(), clientEndpoint.getEndpointStr());
if (IbvWC.IbvWcOpcode.valueOf(wc.getOpcode()) == IbvWC.IbvWcOpcode.IBV_WC_RECV) {
if (wc.getStatus() != IbvWC.IbvWcStatus.IBV_WC_SUCCESS.ordinal()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ PartitionRequestClientIf createPartitionRequestClient(ConnectionID connectionId,
clientEndpoints.put(connectionId, endpoint);
PartitionRequestClient client = new PartitionRequestClient(
endpoint, clientHandler, connectionId, this);
rdmaClient.start(connectionId.getAddress());
// rdmaClient.start(connectionId.getAddress());
clients.putIfAbsent(connectionId, client);
LOG.info("creating partition client {} for connection id {}", endpoint.getEndpointStr(), connectionId.toString());
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,30 @@
import com.ibm.disni.verbs.IbvRecvWR;
import com.ibm.disni.verbs.IbvSendWR;
import com.ibm.disni.verbs.IbvSge;
import com.ibm.disni.verbs.SVCPostRecv;
import com.ibm.disni.verbs.SVCPostSend;
import com.ibm.disni.verbs.StatefulVerbCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sun.nio.ch.DirectBuffer;

import java.io.IOException;
import java.util.LinkedList;
import java.util.Map;

import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.ReadOnlySlicedNetworkBuffer;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;

public class RdmaSendReceiveUtil {
private static final Logger LOG = LoggerFactory.getLogger(RdmaSendReceiveUtil.class);

public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, long workReqId, NettyMessage
.BufferResponse response) throws IOException {
.BufferResponse response,Map<Long, StatefulVerbCall<? extends
StatefulVerbCall<?>>>
inFlightVerbs) throws IOException {

if (endpoint instanceof RdmaShuffleServerEndpoint) {
RdmaShuffleServerEndpoint clientEndpoint = (RdmaShuffleServerEndpoint) endpoint;
Expand Down Expand Up @@ -70,7 +74,8 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon
}
// header is at the end of segment
int start = segment.size() - RdmaConnectionManager.DATA_MSG_HEADER_SIZE;
// LOG.info("SRUtil: Header start address {}, end address {} buffer length {} sent magic {} byte order {}", segment.getAddress() + start,
// LOG.info("SRUtil: Header start address {}, end address {} buffer length {} sent magic {} byte order {}",
// segment.getAddress() + start,
// segment.getAddress() + segment.size(),buf.writerIndex(),segment.getIntBigEndian(start+4),buf.order());

IbvSge headerSGE = new IbvSge();
Expand All @@ -81,7 +86,7 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon
// actual data
IbvSge dataSGE = new IbvSge();
dataSGE.setAddr(dataAddress);
dataSGE.setLength(dataLen);
dataSGE.setLength(dataLen);
// dataSGE.setAddr(segment.getAddress());
// dataSGE.setLength(segment.size());
dataSGE.setLkey(clientEndpoint.getRegisteredMRs().get(segment.getAddress()).getLkey());
Expand Down Expand Up @@ -118,11 +123,16 @@ public static void postSendReqForBufferResponse(RdmaActiveEndpoint endpoint, lon

LinkedList<IbvSendWR> sendWRs = new LinkedList<>();
sendWRs.add(sendWR);
clientEndpoint.postSend(sendWRs).execute().free();
SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, sendVerb);
}
}
}

public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) throws IOException {
public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId, Map<Long, StatefulVerbCall<? extends
StatefulVerbCall<?>>>
inFlightVerbs) throws IOException {

if (endpoint instanceof RdmaShuffleServerEndpoint) {
RdmaShuffleServerEndpoint clientEndpoint = (RdmaShuffleServerEndpoint) endpoint;
Expand All @@ -148,7 +158,10 @@ public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) thro

LinkedList<IbvSendWR> sendWRs = new LinkedList<>();
sendWRs.add(sendWR);
clientEndpoint.postSend(sendWRs).execute().free();
SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, sendVerb);
}
} else if (endpoint instanceof RdmaShuffleClientEndpoint) {
RdmaShuffleClientEndpoint clientEndpoint = (RdmaShuffleClientEndpoint) endpoint;
// LOG.info("posting client send wr_id " + workReqId+ " against src: " + endpoint.getSrcAddr() + " dest: "
Expand All @@ -171,12 +184,16 @@ public static void postSendReq(RdmaActiveEndpoint endpoint, long workReqId) thro
sendWR.setSend_flags(IbvSendWR.IBV_SEND_SIGNALED);

LinkedList<IbvSendWR> sendWRs = new LinkedList<>();
sendWRs.add(sendWR);
clientEndpoint.postSend(sendWRs).execute().free();
SVCPostSend sendVerb = clientEndpoint.postSend(sendWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, sendVerb);
}
}
}

public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) throws IOException {
public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId, Map<Long, StatefulVerbCall<?
extends StatefulVerbCall<?>>>
inFlightVerbs) throws IOException {

if (endpoint instanceof RdmaShuffleServerEndpoint) {
// LOG.info("posting server receive wr_id " + workReqId + " against src: " + endpoint.getSrcAddr() + " dest:
Expand All @@ -195,7 +212,10 @@ public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) t

LinkedList<IbvRecvWR> recvWRs = new LinkedList<>();
recvWRs.add(recvWR);
endpoint.postRecv(recvWRs).execute().free();
SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, recvVerb);
}
} else if (endpoint instanceof RdmaShuffleClientEndpoint) {
// LOG.info("posting client receive wr_id " + workReqId + " against src: " + endpoint.getSrcAddr() + " dest:
// " +endpoint.getDstAddr());
Expand All @@ -213,11 +233,16 @@ public static void postReceiveReq(RdmaActiveEndpoint endpoint, long workReqId) t

LinkedList<IbvRecvWR> recvWRs = new LinkedList<>();
recvWRs.add(recvWR);
endpoint.postRecv(recvWRs).execute().free();
SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, recvVerb);
}
}
}

public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, long workReqId, ByteBuf buffer)
public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, long workReqId, ByteBuf buffer,
Map<Long, StatefulVerbCall<? extends StatefulVerbCall<?>>>
inFlightVerbs)
throws IOException {

if (endpoint instanceof RdmaShuffleClientEndpoint) {
Expand All @@ -237,7 +262,10 @@ public static void postReceiveReqWithChannelBuf(RdmaActiveEndpoint endpoint, lon

LinkedList<IbvRecvWR> recvWRs = new LinkedList<>();
recvWRs.add(recvWR);
endpoint.postRecv(recvWRs).execute().free();
SVCPostRecv recvVerb = endpoint.postRecv(recvWRs).execute();
synchronized (inFlightVerbs) {
inFlightVerbs.put(workReqId, recvVerb);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.ibm.disni.RdmaServerEndpoint;
import com.ibm.disni.verbs.IbvMr;
import com.ibm.disni.verbs.IbvWC;
import com.ibm.disni.verbs.StatefulVerbCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -52,7 +53,6 @@ public void run() {
while (!stopped) {
try {
RdmaShuffleServerEndpoint clientEndpoint = serverEndpoint.accept();
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, 0);
clientEndpoint.setRegisteredMRs(registerdMRs);
// TODO (venkat): Handle accepted connection, not using thread pool as it is only proto-type with 4
// servers
Expand Down Expand Up @@ -95,11 +95,11 @@ private NettyMessage readPartition(NetworkSequenceViewReader reader) throws
} else {
// This channel was now removed from the available reader queue.
// We re-add it into the queue if it is still available
if (next.moreAvailable()) {
reader.setRegisteredAsAvailable(true);
} else {
reader.setRegisteredAsAvailable(false);
}
if (next.moreAvailable()) {
reader.setRegisteredAsAvailable(true);
} else {
reader.setRegisteredAsAvailable(false);
}
NettyMessage.BufferResponse msg = new NettyMessage.BufferResponse(
next.buffer(),
reader.getSequenceNumber(),
Expand All @@ -111,9 +111,12 @@ private NettyMessage readPartition(NetworkSequenceViewReader reader) throws

private class HandleClientConnection implements Runnable {
RdmaShuffleServerEndpoint clientEndpoint;
Map<Long, StatefulVerbCall<? extends StatefulVerbCall<?>>> inFlightVerbs = new HashMap();

HandleClientConnection(RdmaShuffleServerEndpoint clientEndpoint) {

HandleClientConnection(RdmaShuffleServerEndpoint clientEndpoint) throws IOException {
this.clientEndpoint = clientEndpoint;
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, 0, inFlightVerbs);
}

private final ConcurrentHashMap<Long, NettyMessage.BufferResponse> inFlight = new ConcurrentHashMap<>();
Expand All @@ -136,6 +139,9 @@ public void run() {
// LOG.info("Server: Did not get expected wr_id {} on endpoint {}", wc.getWr_id(), clientEndpoint
// .getEndpointStr());
// }
synchronized (inFlightVerbs){
inFlightVerbs.remove(wc.getWr_id()).free();
}
if (IbvWC.IbvWcOpcode.valueOf(wc.getOpcode()) == IbvWC.IbvWcOpcode.IBV_WC_RECV) {
if (wc.getStatus() != IbvWC.IbvWcStatus.IBV_WC_SUCCESS.ordinal()) {
LOG.error("Receive posting failed. reposting new receive request");
Expand All @@ -162,15 +168,15 @@ public void run() {
partitionRequest.queueIndex);
// we need to post receive for next message. for example credit
// TODO: We should do executor service here
new Thread(new RDMAWriter(reader, clientEndpoint, inFlight)).start();
new Thread(new RDMAWriter(reader, clientEndpoint, inFlight,inFlightVerbs)).start();
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet());
.incrementAndGet(), inFlightVerbs);
// TODO(venkat): do something better here, we should not poll reader
} else if (msgClazz == NettyMessage.TaskEventRequest.class) {
NettyMessage.TaskEventRequest request = (NettyMessage.TaskEventRequest) clientRequest;
LOG.error("Unhandled request type TaskEventRequest");
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet()); // post next
.incrementAndGet(), inFlightVerbs); // post next
// receive
// TODO (venkat): Handle it
if (!taskEventDispatcher.publish(request.partitionId, request.event)) {
Expand All @@ -182,7 +188,7 @@ public void run() {
clientRequest;
LOG.error("Unhandled request type CancelPartitionRequest");
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet()); // post next
.incrementAndGet(), inFlightVerbs); // post next
// receive
// TODO (venkat): Handle it
// outboundQueue.cancel(request.receiverId);
Expand All @@ -194,13 +200,13 @@ public void run() {
reader.addCredit(request.credit);
}
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet()); // post next
.incrementAndGet(), inFlightVerbs); // post next
// receive
// TODO (venkat): Handle it
// outboundQueue.addCredit(request.receiverId, request.credit);
} else {
RdmaSendReceiveUtil.postReceiveReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet()); // post next
.incrementAndGet(), inFlightVerbs); // post next
// receive
LOG.warn("Received unexpected client request: {}", clientRequest);
}
Expand Down Expand Up @@ -245,12 +251,16 @@ private class RDMAWriter implements Runnable {
private NetworkSequenceViewReader reader;
private RdmaShuffleServerEndpoint clientEndpoint;
private ConcurrentHashMap<Long, NettyMessage.BufferResponse> inFlight;
private Map<Long, StatefulVerbCall<? extends StatefulVerbCall<?>>> inFlightVerbs;


public RDMAWriter(NetworkSequenceViewReader reader, RdmaShuffleServerEndpoint clientEndpoint,
ConcurrentHashMap<Long, NettyMessage.BufferResponse> inFlight) {
ConcurrentHashMap<Long, NettyMessage.BufferResponse> inFlight, Map<Long, StatefulVerbCall<?
extends StatefulVerbCall<?>>> inFlightVerbs) {
this.reader = reader;
this.clientEndpoint = clientEndpoint;
this.inFlight = inFlight;
this.inFlightVerbs = inFlightVerbs;
}

@Override
Expand Down Expand Up @@ -296,12 +306,12 @@ public void run() {
// (NettyMessage.BufferResponse) response).getBuffer().memoryAddress());
inFlight.put(workRequestId, (NettyMessage.BufferResponse) response);
RdmaSendReceiveUtil.postSendReqForBufferResponse(clientEndpoint, workRequestId,
(NettyMessage.BufferResponse) response);
(NettyMessage.BufferResponse) response,inFlightVerbs);
}
} else {
clientEndpoint.getSendBuffer().put(response.write(bufferPool).nioBuffer());
RdmaSendReceiveUtil.postSendReq(clientEndpoint, clientEndpoint.workRequestId
.incrementAndGet());
.incrementAndGet(),inFlightVerbs);
}
} else {
synchronized (reader) {
Expand Down

0 comments on commit 4d7516c

Please sign in to comment.