Skip to content

Commit

Permalink
Reset the sequence numbers on Session disconnect to support reconnection
Browse files Browse the repository at this point in the history
  • Loading branch information
mvegter committed Oct 21, 2023
1 parent b0f2080 commit 751029b
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/main/java/com/jcraft/jsch/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public class Session {
private byte[] MACc2s;
private byte[] MACs2c;

// RFC 4253 6.4. each direction must run independently hence we have an incoming and outgoing sequence
private int seqi = 0;
private int seqo = 0;

Expand Down Expand Up @@ -1998,6 +1999,12 @@ public void disconnect() {
}
io = null;
socket = null;

// RFC 4253 6.4. the 'sequence_number' is never reset, even if keys/algorithms are renegotiated later. Hence we only
// reset these on session disconnect as the sequence has to start at zero for the first packet during (re)connect.
seqi = 0;
seqo = 0;

// synchronized(jsch.pool){
// jsch.pool.removeElement(this);
// }
Expand Down
161 changes: 161 additions & 0 deletions src/test/java/com/jcraft/jsch/SessionReconnectIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package com.jcraft.jsch;

import com.github.valfirst.slf4jtest.LoggingEvent;
import com.github.valfirst.slf4jtest.TestLogger;
import com.github.valfirst.slf4jtest.TestLoggerFactory;
import org.apache.commons.codec.digest.DigestUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.output.Slf4jLogConsumer;
import org.testcontainers.images.builder.ImageFromDockerfile;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Base64;
import java.util.List;
import java.util.Random;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

@Testcontainers
public class SessionReconnectIT {

private static final int timeout = 2000;
private static final DigestUtils sha256sum = new DigestUtils(DigestUtils.getSha256Digest());
private static final TestLogger jschLogger = TestLoggerFactory.getTestLogger(JSch.class);
private static final TestLogger sshdLogger =
TestLoggerFactory.getTestLogger(SessionReconnectIT.class);

@TempDir
public Path tmpDir;
private Path in;
private Path out;
private String hash;
private Slf4jLogConsumer sshdLogConsumer;

@Container
public GenericContainer<?> sshd = new GenericContainer<>(new ImageFromDockerfile()
.withFileFromClasspath("dropbear_rsa_host_key", "docker/dropbear_rsa_host_key")
.withFileFromClasspath("authorized_keys", "docker/authorized_keys")
.withFileFromClasspath("Dockerfile", "docker/Dockerfile.dropbear")).withExposedPorts(22);

@BeforeAll
public static void beforeAll() {
JSch.setLogger(new Slf4jLogger());
}

@BeforeEach
public void beforeEach() throws IOException {
if (sshdLogConsumer == null) {
sshdLogConsumer = new Slf4jLogConsumer(sshdLogger);
sshd.followOutput(sshdLogConsumer);
}

in = tmpDir.resolve("in");
out = tmpDir.resolve("out");
Files.createFile(in);
try (OutputStream os = Files.newOutputStream(in)) {
byte[] data = new byte[1024];
for (int i = 0; i < 1024 * 100; i += 1024) {
new Random().nextBytes(data);
os.write(data);
}
}
hash = sha256sum.digestAsHex(in);

jschLogger.clearAll();
sshdLogger.clearAll();
}

@AfterAll
public static void afterAll() {
JSch.setLogger(null);
jschLogger.clearAll();
sshdLogger.clearAll();
}

@Test
public void testReconnectWithExtraAlgorithms() throws Exception {
JSch ssh = createRSAIdentity();
Session session = createSession(ssh);
try {
doSftp(session, false);
fail("exception expected");
} catch (JSchAlgoNegoFailException e) {
// Dropbear does not support rsa-sha2-512/rsa-sha2-256, so add ssh-rsa
String serverHostKey = session.getConfig("server_host_key") + ",ssh-rsa";
String pubkeyAcceptedAlgorithms = session.getConfig("PubkeyAcceptedAlgorithms") + ",ssh-rsa";
session.setConfig("server_host_key", serverHostKey);
session.setConfig("PubkeyAcceptedAlgorithms", pubkeyAcceptedAlgorithms);
doSftp(session, true);
}
}

private JSch createRSAIdentity() throws Exception {
HostKey hostKey = readHostKey(getResourceFile("docker/ssh_host_rsa_key.pub"));
JSch ssh = new JSch();
ssh.addIdentity(getResourceFile("docker/id_rsa"), getResourceFile("docker/id_rsa.pub"), null);
ssh.getHostKeyRepository().add(hostKey, null);
return ssh;
}

private HostKey readHostKey(String fileName) throws Exception {
List<String> lines = Files.readAllLines(Paths.get(fileName), UTF_8);
String[] split = lines.get(0).split("\\s+");
String hostname = String.format("[%s]:%d", sshd.getHost(), sshd.getFirstMappedPort());
return new HostKey(hostname, Base64.getDecoder().decode(split[1]));
}

private Session createSession(JSch ssh) throws Exception {
Session session = ssh.getSession("root", sshd.getHost(), sshd.getFirstMappedPort());
session.setConfig("StrictHostKeyChecking", "yes");
session.setConfig("PreferredAuthentications", "publickey");
return session;
}

private void doSftp(Session session, boolean debugException) throws Exception {
try {
session.setTimeout(timeout);
session.connect();
ChannelSftp sftp = (ChannelSftp) session.openChannel("sftp");
sftp.connect(timeout);
sftp.put(in.toString(), "/root/test");
sftp.get("/root/test", out.toString());
sftp.disconnect();
session.disconnect();
} catch (Exception e) {
if (debugException) {
printInfo();
}
throw e;
}

assertEquals(1024L * 100L, Files.size(out));
assertEquals(hash, sha256sum.digestAsHex(out));
}

private void printInfo() {
jschLogger.getAllLoggingEvents().stream().map(LoggingEvent::getFormattedMessage)
.forEach(System.out::println);
sshdLogger.getAllLoggingEvents().stream().map(LoggingEvent::getFormattedMessage)
.forEach(System.out::println);
System.out.println("");
System.out.println("");
System.out.println("");
}

private String getResourceFile(String fileName) {
return ResourceUtil.getResourceFile(getClass(), fileName);
}
}

0 comments on commit 751029b

Please sign in to comment.