diff --git a/src/main/java/com/actiontech/dble/backend/mysql/nio/handler/transaction/xa/XAAnalysisHandler.java b/src/main/java/com/actiontech/dble/backend/mysql/nio/handler/transaction/xa/XAAnalysisHandler.java index b80d3f2bd4..f5613984d8 100644 --- a/src/main/java/com/actiontech/dble/backend/mysql/nio/handler/transaction/xa/XAAnalysisHandler.java +++ b/src/main/java/com/actiontech/dble/backend/mysql/nio/handler/transaction/xa/XAAnalysisHandler.java @@ -73,6 +73,9 @@ public boolean isExistXid(String xaId) { } private void checkResidualXid(boolean isStartup) { + if (SystemConfig.getInstance().getBackendMode() == SystemConfig.BackendMode.OB) { + return; + } Set usedXaid = getCurrentUsedXaids(); usedXaid.add(DbleServer.getInstance().getXaIDInc()); if (LOGGER.isDebugEnabled()) { diff --git a/src/main/java/com/actiontech/dble/config/helper/GetAndSyncDbInstanceKeyVariables.java b/src/main/java/com/actiontech/dble/config/helper/GetAndSyncDbInstanceKeyVariables.java index f4ace9bb74..d9f12a0267 100644 --- a/src/main/java/com/actiontech/dble/config/helper/GetAndSyncDbInstanceKeyVariables.java +++ b/src/main/java/com/actiontech/dble/config/helper/GetAndSyncDbInstanceKeyVariables.java @@ -53,7 +53,14 @@ public KeyVariables call() { if (columnIsolation == null) { return keyVariables; } - String[] columns = new String[]{COLUMN_LOWER_CASE, COLUMN_AUTOCOMMIT, COLUMN_READONLY, COLUMN_MAX_PACKET, columnIsolation, COLUMN_VERSION, COLUMN_BACK_LOG}; + + String[] columns; + if (SystemConfig.getInstance().getBackendMode() == SystemConfig.BackendMode.OB) { + columns = new String[]{COLUMN_LOWER_CASE, COLUMN_AUTOCOMMIT, COLUMN_READONLY, COLUMN_MAX_PACKET, columnIsolation, COLUMN_VERSION}; + } else { + columns = new String[]{COLUMN_LOWER_CASE, COLUMN_AUTOCOMMIT, COLUMN_READONLY, COLUMN_MAX_PACKET, columnIsolation, COLUMN_VERSION, COLUMN_BACK_LOG}; + } + StringBuilder sql = new StringBuilder("select "); for (int i = 0; i < columns.length; i++) { if (i != 0) { @@ -125,7 +132,9 @@ public void onResult(SQLQueryResult> result) { keyVariablesTmp.setTargetMaxPacketSize(SystemConfig.getInstance().getMaxPacketSize() + KeyVariables.MARGIN_PACKET_SIZE); keyVariablesTmp.setReadOnly(result.getResult().get(COLUMN_READONLY).equals("1")); keyVariablesTmp.setVersion(result.getResult().get(COLUMN_VERSION)); - keyVariablesTmp.setBackLog(Integer.parseInt(result.getResult().get(COLUMN_BACK_LOG))); + if (SystemConfig.getInstance().getBackendMode() != SystemConfig.BackendMode.OB) { + keyVariablesTmp.setBackLog(Integer.parseInt(result.getResult().get(COLUMN_BACK_LOG))); + } if (needSync) { boolean checkNeedSync = false; diff --git a/src/main/java/com/actiontech/dble/config/model/SystemConfig.java b/src/main/java/com/actiontech/dble/config/model/SystemConfig.java index b05ed74408..d550760c9b 100644 --- a/src/main/java/com/actiontech/dble/config/model/SystemConfig.java +++ b/src/main/java/com/actiontech/dble/config/model/SystemConfig.java @@ -245,6 +245,10 @@ private SystemConfig() { private String serverCertificateKeyStoreUrl = null; private String serverCertificateKeyStorePwd = null; + private String clientCertificateKeyStoreUrl = null; + private String clientCertificateKeyStorePwd = null; + + private String trustCertificateKeyStoreUrl = null; private String trustCertificateKeyStorePwd = null; @@ -253,8 +257,10 @@ private SystemConfig() { private String gmsslBothPfxPwd = null; private String gmsslRcaPem = null; private String gmsslOcaPem = null; - private boolean supportSSL = false; + private boolean supportFrontSSL = false; + private boolean supportBackSSL = false; + private BackendMode backendMode = BackendMode.MYSQL; private int enableAsyncRelease = 1; //unit: ms private long releaseTimeout = 10L; @@ -309,6 +315,22 @@ public void setServerCertificateKeyStorePwd(String serverCertificateKeyStorePwd) } } + public String getClientCertificateKeyStoreUrl() { + return clientCertificateKeyStoreUrl; + } + + public void setClientCertificateKeyStoreUrl(String clientCertificateKeyStoreUrl) { + this.clientCertificateKeyStoreUrl = clientCertificateKeyStoreUrl; + } + + public String getClientCertificateKeyStorePwd() { + return clientCertificateKeyStorePwd; + } + + public void setClientCertificateKeyStorePwd(String clientCertificateKeyStorePwd) { + this.clientCertificateKeyStorePwd = clientCertificateKeyStorePwd; + } + public String getTrustCertificateKeyStoreUrl() { return trustCertificateKeyStoreUrl; } @@ -1842,15 +1864,30 @@ public void setDataCenter(String dataCenter) { this.dataCenter = dataCenter; } - public boolean isSupportSSL() { - return supportSSL; + public boolean isSupportFrontSSL() { + return supportFrontSSL; } @SuppressWarnings("unused") - public void setSupportSSL(boolean supportSSL) { - this.supportSSL = supportSSL; + public void setSupportFrontSSL(boolean supportFrontSSL) { + this.supportFrontSSL = supportFrontSSL; } + public boolean isSupportBackSSL() { + return supportBackSSL; + } + + public void setSupportBackSSL(boolean supportBackSSL) { + this.supportBackSSL = supportBackSSL; + } + + public BackendMode getBackendMode() { + return backendMode; + } + + public void setBackendMode(BackendMode backendMode) { + this.backendMode = backendMode; + } public int getEnableMemoryBufferMonitor() { return enableMemoryBufferMonitor; @@ -2039,13 +2076,16 @@ public String toString() { ", closeHeartBeatRecord=" + closeHeartBeatRecord + ", serverCertificateKeyStoreUrl=" + serverCertificateKeyStoreUrl + ", serverCertificateKeyStorePwd=" + serverCertificateKeyStorePwd + + ", clientCertificateKeyStoreUrl=" + clientCertificateKeyStoreUrl + + ", clientCertificateKeyStorePwd=" + clientCertificateKeyStorePwd + + ", supportBackSSL=" + supportBackSSL + ", trustCertificateKeyStoreUrl=" + trustCertificateKeyStoreUrl + ", trustCertificateKeyStorePwd=" + trustCertificateKeyStorePwd + ", gmsslBothPfx=" + gmsslBothPfx + ", gmsslBothPfxPwd=" + gmsslBothPfxPwd + ", gmsslRcaPem=" + gmsslRcaPem + ", gmsslOcaPem=" + gmsslOcaPem + - ", supportSSL=" + supportSSL + + ", supportFrontSSL=" + supportFrontSSL + ", enableRoutePenetration=" + enableRoutePenetration + ", routePenetrationRules='" + routePenetrationRules + '\'' + ", enableSessionActiveRatioStat=" + enableSessionActiveRatioStat + @@ -2102,4 +2142,8 @@ private void checkChineseProperty(String val, String name) { problemReporter.warn("Property [ " + name + " ] " + val + " in bootstrap.cnf is illegal,the " + Charset.defaultCharset().name() + " encoding is recommended, Property [ " + name + " ] show be use u4E00-u9FA5a-zA-Z_0-9\\-\\."); } } + + public enum BackendMode { + MYSQL, OB + } } diff --git a/src/main/java/com/actiontech/dble/config/util/ParameterMapping.java b/src/main/java/com/actiontech/dble/config/util/ParameterMapping.java index f9625929ab..4151bb3893 100644 --- a/src/main/java/com/actiontech/dble/config/util/ParameterMapping.java +++ b/src/main/java/com/actiontech/dble/config/util/ParameterMapping.java @@ -82,6 +82,21 @@ public static void mapping(Object target, Properties src, ProblemReporter proble src.remove(propertyName); continue; } + } else if (cls.isEnum()) { + try { + value = Enum.valueOf((Class) cls, (valStr).toUpperCase()); + } catch (IllegalArgumentException e) { + String propertyName = pd.getName(); + String message = getEnumErrorMessage(propertyName, valStr, cls); + if (problemReporter != null) { + problemReporter.warn(message); + errorParameters.add(message); + } else { + LOGGER.warn(message); + } + src.remove(propertyName); + continue; + } } if (value != null) { Method method = pd.getWriteMethod(); @@ -135,6 +150,19 @@ public static Properties mappingFromSystemProperty(Object target, ProblemReporte systemProperties.remove(propertyName); continue; } + } else if (cls.isEnum()) { + try { + value = Enum.valueOf((Class) cls, valStr); + } catch (IllegalArgumentException e) { + String msg = getEnumErrorMessage(propertyName, valStr, cls); + if (problemReporter != null) { + problemReporter.warn(msg); + } else { + LOGGER.warn(msg); + } + systemProperties.remove(propertyName); + continue; + } } if (value != null) { Method method = pd.getWriteMethod(); @@ -276,4 +304,12 @@ private static String getTypeErrorMessage(String name, String values, Class c return sb.toString(); } + private static String getEnumErrorMessage(String name, String values, Class cls) { + String message = getErrorCompatibleMessage(name); + StringBuilder sb = new StringBuilder(message); + sb.append("property [ ").append(name).append(" ] '").append(values).append("' isn't a valid value."); + return sb.toString(); + } + + } diff --git a/src/main/java/com/actiontech/dble/net/connection/AbstractConnection.java b/src/main/java/com/actiontech/dble/net/connection/AbstractConnection.java index 553e1c3a9c..038da1880e 100644 --- a/src/main/java/com/actiontech/dble/net/connection/AbstractConnection.java +++ b/src/main/java/com/actiontech/dble/net/connection/AbstractConnection.java @@ -6,6 +6,7 @@ package com.actiontech.dble.net.connection; import com.actiontech.dble.backend.mysql.proto.handler.Impl.MySQLProtoHandlerImpl; +import com.actiontech.dble.backend.mysql.proto.handler.Impl.SSLProtoHandler; import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandler; import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult; import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResultCode; @@ -18,6 +19,8 @@ import com.actiontech.dble.net.SocketWR; import com.actiontech.dble.net.WriteOutTask; import com.actiontech.dble.net.service.*; +import com.actiontech.dble.net.ssl.IOpenSSLWrapper; +import com.actiontech.dble.net.ssl.SSLWrapperRegistry; import com.actiontech.dble.services.BusinessService; import com.actiontech.dble.services.TransactionOperate; import com.actiontech.dble.services.mysqlauthenticate.MySQLFrontAuthService; @@ -26,6 +29,7 @@ import com.actiontech.dble.statistic.stat.FrontActiveRatioStat; import com.actiontech.dble.util.CompressUtil; import com.actiontech.dble.util.TimeUtil; +import com.actiontech.dble.util.exception.NotSupportException; import com.google.common.base.Strings; import com.google.common.collect.Sets; import org.slf4j.Logger; @@ -93,6 +97,12 @@ public abstract class AbstractConnection implements Connection { protected volatile Boolean requestSSL; + protected volatile boolean isSupportSSL; + protected volatile SSLHandler sslHandler; + protected String sslName; + + protected volatile ByteBuffer netReadBuffer; + public AbstractConnection(NetworkChannel channel, SocketWR socketWR) { this.channel = channel; this.socketWR = socketWR; @@ -102,6 +112,11 @@ public AbstractConnection(NetworkChannel channel, SocketWR socketWR) { this.lastWriteTime = startupTime; this.proto = new MySQLProtoHandlerImpl(this); FrontActiveRatioStat.getInstance().register(this, startupTime); + if (this instanceof BackendConnection) { + this.isSupportSSL = SystemConfig.getInstance().isSupportBackSSL(); + } else if (this instanceof FrontendConnection) { + this.isSupportSSL = SystemConfig.getInstance().isSupportFrontSSL(); + } } public void onReadData(int got) throws IOException { @@ -813,7 +828,177 @@ public void setBottomReadBuffer(ByteBuffer bottomReadBuffer) { } + public void initSSLContext(int protocol) { + if (sslHandler != null) { + return; + } + sslHandler = new SSLHandler(this); + IOpenSSLWrapper sslWrapper = SSLWrapperRegistry.getInstance(protocol); + if (sslWrapper == null) { + throw new NotSupportException("not support " + SSLWrapperRegistry.SSLProtocol.nameOf(protocol)); + } + sslName = SSLWrapperRegistry.SSLProtocol.nameOf(protocol); + sslHandler.setSslWrapper(sslWrapper); + } + + public ByteBuffer ensureReadBufferFree(ByteBuffer oldBuffer, int expectSize) { + ByteBuffer newBuffer = allocate(expectSize < 0 ? processor.getBufferPool().getChunkSize() : expectSize, generateBufferRecordBuilder().withType(BufferType.POOL)); + oldBuffer.flip(); + newBuffer.put(oldBuffer); + setBottomReadBuffer(newBuffer); + + oldBuffer.clear(); + recycle(oldBuffer); + + return newBuffer; + } + + + public void handleSSLData(ByteBuffer dataBuffer) throws IOException { + if (dataBuffer == null) { + return; + } + int offset = 0; + SSLProtoHandler proto = new SSLProtoHandler(this); + boolean hasRemaining = true; + while (hasRemaining) { + ProtoHandlerResult result = proto.handle(dataBuffer, offset, false, true); + switch (result.getCode()) { + case SSL_PROTO_PACKET: + case SSL_CLOSE_PACKET: + if (!result.isHasMorePacket()) { + netReadReachEnd(); + final ByteBuffer tmpReadBuffer = getBottomReadBuffer(); + if (tmpReadBuffer != null) { + tmpReadBuffer.clear(); + } + } + processSSLProto(result.getPacketData(), result.getCode()); + if (!result.isHasMorePacket()) { + dataBuffer.clear(); + } + break; + case SSL_APP_PACKET: + if (!result.isHasMorePacket()) { + netReadReachEnd(); + } + processSSLAppData(result.getPacketData()); + if (!result.isHasMorePacket()) { + dataBuffer.clear(); + } + break; + case BUFFER_PACKET_UNCOMPLETE: + processSSLPacketUnComplete(dataBuffer, offset); + break; + case SSL_BUFFER_NOT_BIG_ENOUGH: + processSSLPacketNotBigEnough(dataBuffer, result.getOffset(), result.getPacketLength()); + break; + default: + break; + } + hasRemaining = result.isHasMorePacket(); + if (hasRemaining) { + offset = result.getOffset(); + } + } + } + + + private void netReadReachEnd() { + // if cur buffer is temper none direct byte buffer and not + // received large message in recent 30 seconds + // then change to direct buffer for performance + ByteBuffer localReadBuffer = netReadBuffer; + if (localReadBuffer != null && !localReadBuffer.isDirect() && lastLargeMessageTime < lastReadTime - 30 * 1000L) { // used temp heap + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("change to direct con read buffer ,cur temp buf size :" + localReadBuffer.capacity()); + } + recycle(localReadBuffer); + netReadBuffer = allocate(readBufferChunk, generateBufferRecordBuilder().withType(BufferType.POOL)); + } else { + if (localReadBuffer != null) { + IODelayProvider.inReadReachEnd(); + localReadBuffer.clear(); + } + } + } + + private void processSSLAppData(byte[] packetData) throws IOException { + if (packetData == null) return; + sslHandler.unwrapAppData(packetData); + handleNonSSL(getBottomReadBuffer()); + } + + public void processSSLPacketNotBigEnough(ByteBuffer buffer, int offset, final int pkgLength) { + ByteBuffer newBuffer = allocate(pkgLength, generateBufferRecordBuilder().withType(BufferType.POOL)); + buffer.position(offset); + newBuffer.put(buffer); + this.netReadBuffer = newBuffer; + recycle(buffer); + } + + private void processSSLPacketUnComplete(ByteBuffer buffer, int offset) { + if (buffer == null) { + return; + } + buffer.limit(buffer.position()); + buffer.position(offset); + netReadBuffer = buffer.compact(); + } + protected abstract void handleNonSSL(ByteBuffer dataBuffer) throws IOException; + + public void doSSLHandShake(byte[] data) { + try { + if (!isUseSSL()) { + close("SSL not initialized"); + return; + } + if (!sslHandler.isCreateEngine()) { + sslHandler.createEngine(); + } + sslHandler.handShake(data); + } catch (SSLException e) { + LOGGER.warn("SSL handshake failed, exception: ", e); + close("SSL handshake failed"); + } catch (IOException e) { + LOGGER.warn("SSL initialization failed, exception: ", e); + close("SSL initialization failed"); + } } + + public void sendSSLHandShake(int protocol) { + try { + this.initSSLContext(protocol); + if (!isUseSSL()) { + close("SSL not initialized"); + return; + } + if (!sslHandler.isCreateEngine()) { + sslHandler.createEngine(); + } + sslHandler.sendhandShake(); + } catch (SSLException e) { + LOGGER.warn("SSL handshake failed, exception: ", e); + close("SSL handshake failed"); + } catch (IOException e) { + LOGGER.warn("SSL initialization failed, exception: ", e); + close("SSL initialization failed"); + } + } + + + public boolean isUseSSL() { + return sslHandler != null; + } + + public boolean isSupportSSL() { + return isSupportSSL; + } + + public boolean isSSLHandshakeSuccess() { + return sslHandler != null && sslHandler.isHandshakeSuccess(); + } + } diff --git a/src/main/java/com/actiontech/dble/net/connection/BackendConnection.java b/src/main/java/com/actiontech/dble/net/connection/BackendConnection.java index fb44880706..6e8c18d292 100644 --- a/src/main/java/com/actiontech/dble/net/connection/BackendConnection.java +++ b/src/main/java/com/actiontech/dble/net/connection/BackendConnection.java @@ -9,6 +9,8 @@ import com.actiontech.dble.backend.mysql.nio.handler.ResponseHandler; import com.actiontech.dble.backend.pool.PooledConnectionListener; import com.actiontech.dble.backend.pool.ReadTimeStatusInstance; +import com.actiontech.dble.buffer.BufferType; +import com.actiontech.dble.config.model.SystemConfig; import com.actiontech.dble.config.model.db.DbInstanceConfig; import com.actiontech.dble.net.IOProcessor; import com.actiontech.dble.net.SocketWR; @@ -21,6 +23,8 @@ import com.actiontech.dble.singleton.FlowController; import com.actiontech.dble.util.TimeUtil; +import javax.net.ssl.SSLException; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.NetworkChannel; @@ -128,6 +132,8 @@ public void release() { @Override public synchronized void close(final String reason) { + if (isUseSSL()) sslHandler.close(); + if (getCloseReason() == null || !getCloseReason().equals(reason)) LOGGER.info("connection id " + id + " mysqlId " + threadId + " close for reason " + reason); boolean isAuthed = !this.getService().isFakeClosed() && !(this.getService() instanceof AuthService); @@ -217,6 +223,12 @@ private void writeClose(ByteBuffer buffer) { this.socketWR.doNextWriteCheck(); } + + protected void handleNonSSL(ByteBuffer dataBuffer) throws IOException { + super.handle(dataBuffer, false); + } + + public long getThreadId() { return threadId; } @@ -251,4 +263,96 @@ public String toString() { // show all public String toString2() { return "BackendConnection[id = " + id + " host = " + host + " port = " + port + " localPort = " + localPort + " mysqlId = " + threadId + " db config = " + instance + "]"; } + + + @Override + public void compactReadBuffer(ByteBuffer dataBuffer, int offset, boolean isSSL) throws IOException { + if (dataBuffer == null) { + return; + } + if (isSupportSSL && isSSL) { + dataBuffer.flip(); + dataBuffer.position(offset); + int len = netReadBuffer.position() + (dataBuffer.limit() - dataBuffer.position()); + if (netReadBuffer.capacity() < len) { + processSSLPacketNotBigEnough(netReadBuffer, 0, len); + } + this.netReadBuffer.put(dataBuffer); + dataBuffer.clear(); + handleSSLData(netReadBuffer); + } else { + dataBuffer.limit(dataBuffer.position()); + dataBuffer.position(offset); + setBottomReadBuffer(dataBuffer.compact()); + } + } + + @Override + public ByteBuffer wrap(ByteBuffer orgBuffer) throws SSLException { + if (!isUseSSL()) return orgBuffer; + return sslHandler.wrapAppData(orgBuffer); + } + + @Override + public ByteBuffer findReadBuffer() { + if (isSupportSSL && maybeUseSSL()) { + if (this.netReadBuffer == null) { + netReadBuffer = allocate(processor.getBufferPool().getChunkSize(), generateBufferRecordBuilder().withType(BufferType.POOL)); + } + return netReadBuffer; + } else { + //only recycle this read buffer + recycleNetReadBuffer(); + return super.findReadBuffer(); + } + } + + private void recycleNetReadBuffer() { + if (this.netReadBuffer != null) { + this.recycle(this.netReadBuffer); + this.netReadBuffer = null; + } + } + + @Override + ByteBuffer getReadBuffer() { + if (isSupportSSL && maybeUseSSL()) { + return netReadBuffer; + } else { + return super.getReadBuffer(); + } + } + + private void transferToReadBuffer(ByteBuffer dataBuffer) { + if (!isSupportSSL || !maybeUseSSL()) return; + dataBuffer.flip(); + ByteBuffer readBuffer = findBottomReadBuffer(); + int len = readBuffer.position() + dataBuffer.limit(); + if (readBuffer.capacity() < len) { + readBuffer = ensureReadBufferFree(readBuffer, len); + } + readBuffer.put(dataBuffer); + dataBuffer.clear(); + } + + + @Override + protected void handle(ByteBuffer dataBuffer, boolean isContainSSLData) throws IOException { + if (this.isSupportSSL && isUseSSL() && isSSLHandshakeSuccess()) { + //after ssl-client hello + handleSSLData(dataBuffer); + } else { + //ssl buffer -> bottomRead buffer + transferToReadBuffer(dataBuffer); + if (maybeUseSSL()) { + //ssl login request(non ssl)&client hello(ssl) + super.handle(getBottomReadBuffer(), true); + } else { + //no ssl + handleNonSSL(getBottomReadBuffer()); + } + } + } + + } diff --git a/src/main/java/com/actiontech/dble/net/connection/FrontendConnection.java b/src/main/java/com/actiontech/dble/net/connection/FrontendConnection.java index fc73a4b724..c10f404523 100644 --- a/src/main/java/com/actiontech/dble/net/connection/FrontendConnection.java +++ b/src/main/java/com/actiontech/dble/net/connection/FrontendConnection.java @@ -5,23 +5,17 @@ package com.actiontech.dble.net.connection; -import com.actiontech.dble.backend.mysql.proto.handler.Impl.SSLProtoHandler; -import com.actiontech.dble.backend.mysql.proto.handler.ProtoHandlerResult; -import com.actiontech.dble.btrace.provider.IODelayProvider; import com.actiontech.dble.buffer.BufferType; import com.actiontech.dble.config.model.SystemConfig; import com.actiontech.dble.net.IOProcessor; import com.actiontech.dble.net.SocketWR; import com.actiontech.dble.net.service.AbstractService; import com.actiontech.dble.net.service.AuthService; -import com.actiontech.dble.net.ssl.OpenSSLWrapper; -import com.actiontech.dble.net.ssl.SSLWrapperRegistry; import com.actiontech.dble.services.BusinessService; import com.actiontech.dble.services.FrontendService; import com.actiontech.dble.services.mysqlauthenticate.MySQLChangeUserService; import com.actiontech.dble.singleton.FlowController; import com.actiontech.dble.util.TimeUtil; -import com.actiontech.dble.util.exception.NotSupportException; import javax.net.ssl.SSLException; import java.io.IOException; @@ -45,10 +39,6 @@ public class FrontendConnection extends AbstractConnection { //skip idleTimeout checks private boolean skipCheck; - private final boolean isSupportSSL; - protected volatile ByteBuffer netReadBuffer; - private volatile SSLHandler sslHandler; - private String sslName; public FrontendConnection(NetworkChannel channel, SocketWR socketWR, boolean isManager) throws IOException { super(channel, socketWR); @@ -67,7 +57,6 @@ public FrontendConnection(NetworkChannel channel, SocketWR socketWR, boolean isM this.localPort = remoteAddress.getPort(); this.idleTimeout = SystemConfig.getInstance().getIdleTimeout(); this.isCleanUp = new AtomicBoolean(false); - this.isSupportSSL = SystemConfig.getInstance().isSupportSSL(); } @Override @@ -92,38 +81,6 @@ protected void handleNonSSL(ByteBuffer dataBuffer) throws IOException { super.handle(dataBuffer, false); } - @Override - public void initSSLContext(int protocol) { - if (sslHandler != null) { - return; - } - sslHandler = new SSLHandler(this); - OpenSSLWrapper sslWrapper = SSLWrapperRegistry.getInstance(protocol); - if (sslWrapper == null) { - throw new NotSupportException("not support " + SSLWrapperRegistry.SSLProtocol.nameOf(protocol)); - } - sslName = SSLWrapperRegistry.SSLProtocol.nameOf(protocol); - sslHandler.setSslWrapper(sslWrapper); - } - - public void doSSLHandShake(byte[] data) { - try { - if (!isUseSSL()) { - close("SSL not initialized"); - return; - } - if (!sslHandler.isCreateEngine()) { - sslHandler.createEngine(); - } - sslHandler.handShake(data); - } catch (SSLException e) { - LOGGER.warn("SSL handshake failed, exception: ", e); - close("SSL handshake failed"); - } catch (IOException e) { - LOGGER.warn("SSL initialization failed, exception: ", e); - close("SSL initialization failed"); - } - } private void transferToReadBuffer(ByteBuffer dataBuffer) { if (!isSupportSSL || !maybeUseSSL()) return; @@ -137,96 +94,7 @@ private void transferToReadBuffer(ByteBuffer dataBuffer) { dataBuffer.clear(); } - public void handleSSLData(ByteBuffer dataBuffer) throws IOException { - if (dataBuffer == null) { - return; - } - int offset = 0; - SSLProtoHandler proto = new SSLProtoHandler(this); - boolean hasRemaining = true; - while (hasRemaining) { - ProtoHandlerResult result = proto.handle(dataBuffer, offset, false, true); - switch (result.getCode()) { - case SSL_PROTO_PACKET: - case SSL_CLOSE_PACKET: - if (!result.isHasMorePacket()) { - netReadReachEnd(); - final ByteBuffer tmpReadBuffer = getBottomReadBuffer(); - if (tmpReadBuffer != null) { - tmpReadBuffer.clear(); - } - } - processSSLProto(result.getPacketData(), result.getCode()); - if (!result.isHasMorePacket()) { - dataBuffer.clear(); - } - break; - case SSL_APP_PACKET: - if (!result.isHasMorePacket()) { - netReadReachEnd(); - } - processSSLAppData(result.getPacketData()); - if (!result.isHasMorePacket()) { - dataBuffer.clear(); - } - break; - case BUFFER_PACKET_UNCOMPLETE: - processSSLPacketUnComplete(dataBuffer, offset); - break; - case SSL_BUFFER_NOT_BIG_ENOUGH: - processSSLPacketNotBigEnough(dataBuffer, result.getOffset(), result.getPacketLength()); - break; - default: - break; - } - hasRemaining = result.isHasMorePacket(); - if (hasRemaining) { - offset = result.getOffset(); - } - } - } - - private void netReadReachEnd() { - // if cur buffer is temper none direct byte buffer and not - // received large message in recent 30 seconds - // then change to direct buffer for performance - ByteBuffer localReadBuffer = netReadBuffer; - if (localReadBuffer != null && !localReadBuffer.isDirect() && lastLargeMessageTime < lastReadTime - 30 * 1000L) { // used temp heap - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("change to direct con read buffer ,cur temp buf size :" + localReadBuffer.capacity()); - } - recycle(localReadBuffer); - netReadBuffer = allocate(readBufferChunk, generateBufferRecordBuilder().withType(BufferType.POOL)); - } else { - if (localReadBuffer != null) { - IODelayProvider.inReadReachEnd(); - localReadBuffer.clear(); - } - } - } - - private void processSSLAppData(byte[] packetData) throws IOException { - if (packetData == null) return; - sslHandler.unwrapAppData(packetData); - handleNonSSL(getBottomReadBuffer()); - } - public void processSSLPacketNotBigEnough(ByteBuffer buffer, int offset, final int pkgLength) { - ByteBuffer newBuffer = allocate(pkgLength, generateBufferRecordBuilder().withType(BufferType.POOL)); - buffer.position(offset); - newBuffer.put(buffer); - this.netReadBuffer = newBuffer; - recycle(buffer); - } - - private void processSSLPacketUnComplete(ByteBuffer buffer, int offset) { - if (buffer == null) { - return; - } - buffer.limit(buffer.position()); - buffer.position(offset); - netReadBuffer = buffer.compact(); - } @Override public void businessClose(String reason) { @@ -313,17 +181,6 @@ public void compactReadBuffer(ByteBuffer dataBuffer, int offset, boolean isSSL) } - public ByteBuffer ensureReadBufferFree(ByteBuffer oldBuffer, int expectSize) { - ByteBuffer newBuffer = allocate(expectSize < 0 ? processor.getBufferPool().getChunkSize() : expectSize, generateBufferRecordBuilder().withType(BufferType.POOL)); - oldBuffer.flip(); - newBuffer.put(oldBuffer); - setBottomReadBuffer(newBuffer); - - oldBuffer.clear(); - recycle(oldBuffer); - - return newBuffer; - } public boolean isIdleTimeout() { if (!(getService() instanceof AuthService)) { @@ -384,9 +241,6 @@ public void setSkipCheck(boolean skipCheck) { this.skipCheck = skipCheck; } - public boolean isUseSSL() { - return sslHandler != null; - } public String toString() { return "FrontendConnection[id = " + id + " port = " + port + " host = " + host + " local_port = " + localPort + " isManager = " + isManager() + " startupTime = " + startupTime + " skipCheck = " + isSkipCheck() + " isFlowControl = " + isFrontWriteFlowControlled() + " onlyTcpConnect = " + isOnlyFrontTcpConnected() + " ssl = " + (isUseSSL() ? sslName : "no") + "]"; diff --git a/src/main/java/com/actiontech/dble/net/connection/SSLHandler.java b/src/main/java/com/actiontech/dble/net/connection/SSLHandler.java index 4814d68162..72040d98ad 100644 --- a/src/main/java/com/actiontech/dble/net/connection/SSLHandler.java +++ b/src/main/java/com/actiontech/dble/net/connection/SSLHandler.java @@ -6,7 +6,7 @@ package com.actiontech.dble.net.connection; import com.actiontech.dble.net.service.WriteFlags; -import com.actiontech.dble.net.ssl.OpenSSLWrapper; +import com.actiontech.dble.net.ssl.IOpenSSLWrapper; import com.actiontech.dble.util.ByteBufferUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,16 +23,17 @@ public class SSLHandler { protected static final Logger LOGGER = LoggerFactory.getLogger(SSLHandler.class); - private final FrontendConnection con; + private final AbstractConnection con; private final NetworkChannel channel; private volatile ByteBuffer decryptOut; - private OpenSSLWrapper sslWrapper; + private IOpenSSLWrapper sslWrapper; private volatile SSLEngine engine; private volatile boolean isHandshakeSuccess = false; - public SSLHandler(FrontendConnection con) { + + public SSLHandler(AbstractConnection con) { this.con = con; this.channel = con.getChannel(); } @@ -41,7 +42,14 @@ public void createEngine() throws IOException { if (sslWrapper == null) { return; } - this.engine = sslWrapper.appleSSLEngine(true); + + if (con instanceof BackendConnection) { + this.engine = sslWrapper.createClientSSLEngine(); + engine.beginHandshake(); + } else if (con instanceof FrontendConnection) { + this.engine = sslWrapper.createServerSSLEngine(true); + } + if (this.channel instanceof SocketChannel) { ((SocketChannel) this.channel).configureBlocking(false); } @@ -52,6 +60,26 @@ public void handShake(byte[] data) throws SSLException { unwrapNonAppData(data); } + + public void sendhandShake() throws SSLException { + + try { + final SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); + + switch (handshakeStatus) { + case NEED_WRAP: + wrapNonAppData(); + break; + default: + throw new IllegalStateException("unknown handshake status: " + handshakeStatus); + } + + } catch (SSLException e) { + LOGGER.warn("during the handshake, unwrap data exception: ", e); + con.close("during the handshake, unwrap data fail"); + } + } + /** * receive and process the SSL handshake protocol initiated by the client */ @@ -61,16 +89,17 @@ private void unwrapNonAppData(byte[] data) { in.flip(); try { - for (; ; ) { + final SSLEngineResult result = unwrap(engine, in); final Status status = result.getStatus(); - final SSLEngineResult.HandshakeStatus handshakeStatus = result.getHandshakeStatus(); + final int produced = result.bytesProduced(); final int consumed = result.bytesConsumed(); if (status == Status.CLOSED) { return; } - + for (; ; ) { + final SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); switch (handshakeStatus) { case NEED_WRAP: wrapNonAppData(); @@ -79,14 +108,24 @@ private void unwrapNonAppData(byte[] data) { runDelegatedTasks(); break; case FINISHED: - /*setHandshakeSuccess(); - continue;*/ + setHandshakeSuccess(); + break; case NEED_UNWRAP: + break; case NOT_HANDSHAKING: + LOGGER.info("connection {} migrate status to NOT_HANDSHAKING", con); + setHandshakeSuccess(); break; default: throw new IllegalStateException("unknown handshake status: " + handshakeStatus); } + + if (handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED || handshakeStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + break; + } + if (handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { + break; + } if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { break; } @@ -181,6 +220,9 @@ private void wrapNonAppData() throws SSLException { default: throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); } + if (result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_WRAP) { + break; + } if (result.bytesProduced() == 0) { break; } @@ -294,7 +336,7 @@ private ByteBuffer ensure(ByteBuffer oldBuffer, int size) { return newBuffer; } - public void setSslWrapper(OpenSSLWrapper sslWrapper) { + public void setSslWrapper(IOpenSSLWrapper sslWrapper) { this.sslWrapper = sslWrapper; } diff --git a/src/main/java/com/actiontech/dble/net/mysql/SSLRequestPacket.java b/src/main/java/com/actiontech/dble/net/mysql/SSLRequestPacket.java new file mode 100644 index 0000000000..461efe61e1 --- /dev/null +++ b/src/main/java/com/actiontech/dble/net/mysql/SSLRequestPacket.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2016-2023 ActionTech. + * based on code by MyCATCopyrightHolder Copyright (c) 2013, OpenCloudDB/MyCAT. + * License: http://www.gnu.org/licenses/gpl.html GPL version 2 or higher. + */ +package com.actiontech.dble.net.mysql; + +import com.actiontech.dble.backend.mysql.BufferUtil; +import com.actiontech.dble.backend.mysql.MySQLMessage; +import com.actiontech.dble.backend.mysql.StreamUtil; +import com.actiontech.dble.config.Capabilities; +import com.actiontech.dble.net.connection.AbstractConnection; +import com.actiontech.dble.services.mysqlsharding.MySQLResponseService; +import org.apache.commons.lang.NotImplementedException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + + +/** + * From client to server during initial handshake. + *

+ *

+ * Bytes                        Name
+ * -----                        ----
+ * 4                            client_flags
+ * 4                            max_packet_size
+ * 1                            charset_number
+ * 23                           (filler) always 0x00...
+ *
+ * @see https://dev.mysql.com/doc/dev/mysql-server/9.1.0/page_protocol_connection_phase_packets_protocol_ssl_request.html
+ * 
+ * + * @author mycat + */ +public class SSLRequestPacket extends MySQLPacket { + private static final Logger LOGGER = LoggerFactory.getLogger(SSLRequestPacket.class); + private static final byte[] FILLER = new byte[23]; + + private long clientFlags; + private long maxPacketSize; + private int charsetIndex; + + private byte[] extra; // from FILLER(23) + private String tenant = ""; + private boolean multStatementAllow = false; + + private boolean isSSLRequest = false; + + + public void read(byte[] data) { + throw new NotImplementedException(); + } + + public void write(OutputStream out) throws IOException { + // StreamUtil.writeUB3(out, calcPacketSize() - 1); //todo:?存疑 + StreamUtil.writeUB3(out, calcPacketSize()); + StreamUtil.write(out, packetId); + StreamUtil.writeUB4(out, clientFlags); // capability flags + StreamUtil.writeUB4(out, maxPacketSize); + StreamUtil.write(out, (byte) charsetIndex); + out.write(FILLER); + } + + @Override + public void write(MySQLResponseService service) { + ByteBuffer buffer = service.allocate(); + BufferUtil.writeUB3(buffer, calcPacketSize()); + buffer.put(packetId); + BufferUtil.writeUB4(buffer, clientFlags); // capability flags + BufferUtil.writeUB4(buffer, maxPacketSize); // max-packet size + buffer.put((byte) charsetIndex); //character set + buffer = service.writeToBuffer(FILLER, buffer); // reserved (all [0]) + + if ((clientFlags & Capabilities.CLIENT_PLUGIN_AUTH) != 0) { + //if use the mysql_native_password is used for auth this need be replay + BufferUtil.writeWithNull(buffer, HandshakeV10Packet.NATIVE_PASSWORD_PLUGIN); + } + + service.writeDirectly(buffer, getLastWriteFlag()); + } + + + public void bufferWrite(OutputStream out) throws IOException { + // if (database != null) { + StreamUtil.writeUB3(out, calcPacketSizeWithKey()); //todo:?存疑 + // } else { + // StreamUtil.writeUB3(out, calcPacketSizeWithKey() - 1); + // } + StreamUtil.write(out, packetId); + StreamUtil.writeUB4(out, clientFlags); + StreamUtil.writeUB4(out, maxPacketSize); + StreamUtil.write(out, (byte) charsetIndex); + out.write(FILLER); + + } + + @Override + public void bufferWrite(AbstractConnection c) { + + ByteBuffer buffer = c.allocate(); + BufferUtil.writeUB3(buffer, calcPacketSizeWithKey()); + buffer.put(packetId); + BufferUtil.writeUB4(buffer, clientFlags); // capability flags + BufferUtil.writeUB4(buffer, maxPacketSize); // max-packet size + buffer.put((byte) charsetIndex); //character set + buffer = c.getService().writeToBuffer(FILLER, buffer); // reserved (all [0]) + + + c.getService().writeDirectly(buffer, getLastWriteFlag()); + + } + + + @Override + public int calcPacketSize() { + int size = 32; // 4+4+1+23; + + return size; + } + + public int calcPacketSizeWithKey() { + int size = 32; // 4+4+1+23; + return size; + } + + @Override + protected String getPacketInfo() { + return "MySQL Authentication Packet"; + } + + public long getClientFlags() { + return clientFlags; + } + + public void setClientFlags(long clientFlags) { + this.clientFlags = clientFlags; + } + + public long getMaxPacketSize() { + return maxPacketSize; + } + + public void setMaxPacketSize(long maxPacketSize) { + this.maxPacketSize = maxPacketSize; + } + + public int getCharsetIndex() { + return charsetIndex; + } + + public void setCharsetIndex(int charsetIndex) { + this.charsetIndex = charsetIndex; + } + + public byte[] getExtra() { + return extra; + } + + + public String getTenant() { + return tenant; + } + + + public boolean isMultStatementAllow() { + return multStatementAllow; + } + + public boolean getIsSSLRequest() { + return isSSLRequest; + } + + public boolean checkSSLRequest(MySQLMessage mm) { + if (mm.position() == mm.length() && (clientFlags & Capabilities.CLIENT_SSL) != 0) { + isSSLRequest = true; + return true; + } else { + return false; + } + } + + + @Override + public boolean isEndOfQuery() { + return true; + } +} diff --git a/src/main/java/com/actiontech/dble/net/ssl/GMSslWrapper.java b/src/main/java/com/actiontech/dble/net/ssl/GMSslWrapper.java index 624eef3116..2a262f027e 100644 --- a/src/main/java/com/actiontech/dble/net/ssl/GMSslWrapper.java +++ b/src/main/java/com/actiontech/dble/net/ssl/GMSslWrapper.java @@ -8,6 +8,7 @@ import com.actiontech.dble.config.model.SystemConfig; import com.actiontech.dble.net.factory.TrustAllManager; import com.actiontech.dble.util.StringUtil; +import org.apache.commons.lang.NotImplementedException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,7 +20,7 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; -public class GMSslWrapper extends OpenSSLWrapper { +public class GMSslWrapper implements IOpenSSLWrapper { private static final Logger LOGGER = LoggerFactory.getLogger(GMSslWrapper.class); public static final Integer PROTOCOL = 2; @@ -27,6 +28,7 @@ public class GMSslWrapper extends OpenSSLWrapper { private SSLContext context; + @Override public boolean initContext() { try { @@ -105,7 +107,8 @@ private static TrustManager[] createTrustManagers(String gmsslRcaPem, String gms } - public SSLEngine appleSSLEngine(boolean isAuthClient) { + @Override + public SSLEngine createServerSSLEngine(boolean isAuthClient) { SSLEngine engine = context.createSSLEngine(); engine.setUseClientMode(false); engine.setEnabledProtocols("GMSSLv1.1".split(",")); @@ -116,4 +119,8 @@ public SSLEngine appleSSLEngine(boolean isAuthClient) { return engine; } + @Override + public SSLEngine createClientSSLEngine() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/com/actiontech/dble/net/ssl/IOpenSSLWrapper.java b/src/main/java/com/actiontech/dble/net/ssl/IOpenSSLWrapper.java new file mode 100644 index 0000000000..217a245ad5 --- /dev/null +++ b/src/main/java/com/actiontech/dble/net/ssl/IOpenSSLWrapper.java @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2016-2023 ActionTech. + * based on code by MyCATCopyrightHolder Copyright (c) 2013, OpenCloudDB/MyCAT. + * License: http://www.gnu.org/licenses/gpl.html GPL version 2 or higher. + */ + +package com.actiontech.dble.net.ssl; + +import javax.net.ssl.SSLEngine; + +/** + * @author dcy + * Create Date: 2024-12-05 + */ +public interface IOpenSSLWrapper { + public SSLEngine createClientSSLEngine(); + + public boolean initContext(); + + public SSLEngine createServerSSLEngine(boolean isAuthClient); +} diff --git a/src/main/java/com/actiontech/dble/net/ssl/OpenSSLWrapper.java b/src/main/java/com/actiontech/dble/net/ssl/OpenSSLWrapper.java index fffe64c6da..3133ad3f32 100644 --- a/src/main/java/com/actiontech/dble/net/ssl/OpenSSLWrapper.java +++ b/src/main/java/com/actiontech/dble/net/ssl/OpenSSLWrapper.java @@ -20,17 +20,25 @@ import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; -public class OpenSSLWrapper { +public class OpenSSLWrapper implements IOpenSSLWrapper { private static final Logger LOGGER = LoggerFactory.getLogger(OpenSSLWrapper.class); private static final String PROTO = "TLS"; - private SSLContext context; + private SSLContext clientContext; + private SSLContext serverContext; public static final Integer PROTOCOL = 1; protected static final String NAME = "OpenSSL"; + @Override public boolean initContext() { + final boolean a = initClientContext(); + final boolean b = initServerContext(); + return a || b; + } + + private boolean initServerContext() { String serverCertificateKeyStoreUrl = SystemConfig.getInstance().getServerCertificateKeyStoreUrl(); String serverCertificateKeyStorePwd = SystemConfig.getInstance().getServerCertificateKeyStorePwd(); String trustCertificateKeyStoreUrl = SystemConfig.getInstance().getTrustCertificateKeyStoreUrl(); @@ -50,12 +58,12 @@ public boolean initContext() { return false; } - context = SSLContext.getInstance(PROTO); + serverContext = SSLContext.getInstance(PROTO); KeyManager[] keyM = createKeyManagers(serverCertificateKeyStoreUrl, serverCertificateKeyStorePwd); TrustManager[] trustM = StringUtil.isBlank(trustCertificateKeyStoreUrl) ? null : createTrustManagers(trustCertificateKeyStoreUrl, trustCertificateKeyStorePwd); - context.init(keyM, trustM, null); + serverContext.init(keyM, trustM, null); return true; } catch (Exception e) { LOGGER.error("OpenSSL initialization exception: ", e); @@ -63,6 +71,39 @@ public boolean initContext() { return false; } + private boolean initClientContext() { + final String clientCertificateKeyStoreUrl = SystemConfig.getInstance().getClientCertificateKeyStoreUrl(); + final String clientCertificateKeyStorePwd = SystemConfig.getInstance().getClientCertificateKeyStorePwd(); + String trustCertificateKeyStoreUrl = SystemConfig.getInstance().getTrustCertificateKeyStoreUrl(); + String trustCertificateKeyStorePwd = SystemConfig.getInstance().getTrustCertificateKeyStorePwd(); + try { + + if (!StringUtil.isBlank(trustCertificateKeyStoreUrl) && StringUtil.isBlank(trustCertificateKeyStorePwd)) { + LOGGER.warn("Please set the correct [trustCertificateKeyStoreUrl] value."); + return false; + } + + clientContext = SSLContext.getInstance(PROTO); + + + TrustManager[] trustM = StringUtil.isBlank(trustCertificateKeyStoreUrl) ? null : createTrustManagers(trustCertificateKeyStoreUrl, trustCertificateKeyStorePwd); + + if (StringUtil.isBlank(clientCertificateKeyStorePwd) && StringUtil.isBlank(clientCertificateKeyStoreUrl)) { + LOGGER.warn("doesn't detect client Certificate for server ssl, use One-way Authentication instead."); + clientContext.init(null, trustM, null); + } else { + KeyManager[] keyM = createKeyManagers(clientCertificateKeyStoreUrl, clientCertificateKeyStorePwd); + clientContext.init(keyM, trustM, null); + } + + return true; + } catch (Exception e) { + LOGGER.error("OpenSSL initialization exception: ", e); + } + return false; + } + + private static KeyManager[] createKeyManagers(String filepath, String keystorePassword) throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException { KeyStore keyStore = KeyStore.getInstance("JKS"); @@ -90,13 +131,14 @@ private static TrustManager[] createTrustManagers(String filepath, String keysto return trustFactory.getTrustManagers(); } - public SSLEngine appleSSLEngine(boolean isAuthClient) { - SSLEngine engine = context.createSSLEngine(); + @Override + public SSLEngine createServerSSLEngine(boolean isAuthClient) { + SSLEngine engine = serverContext.createSSLEngine(); engine.setUseClientMode(false); - /*engine.setEnabledCipherSuites(context.getServerSocketFactory().getSupportedCipherSuites()); - engine.setEnabledProtocols(new String[]{"TLSv1.1","TLSv1.2"});*/ + // engine.setEnabledCipherSuites(serverContext.getServerSocketFactory().getSupportedCipherSuites()); + engine.setEnabledProtocols(new String[]{"TLSv1.1", "TLSv1.2"}); if (isAuthClient) { engine.setWantClientAuth(true); // request the client authentication. // engine.setNeedClientAuth(true); // require client authentication. @@ -104,4 +146,15 @@ public SSLEngine appleSSLEngine(boolean isAuthClient) { return engine; } + @Override + public SSLEngine createClientSSLEngine() { + SSLEngine engine = clientContext.createSSLEngine(); + engine.setUseClientMode(true); + engine.setEnabledProtocols(new String[]{"TLSv1.1", "TLSv1.2"}); + + /*engine.setEnabledCipherSuites(context.getServerSocketFactory().getSupportedCipherSuites()); + engine.setEnabledProtocols(new String[]{"TLSv1.1","TLSv1.2"});*/ + + return engine; + } } diff --git a/src/main/java/com/actiontech/dble/net/ssl/SSLWrapperRegistry.java b/src/main/java/com/actiontech/dble/net/ssl/SSLWrapperRegistry.java index 20f57f09a2..c16924812b 100644 --- a/src/main/java/com/actiontech/dble/net/ssl/SSLWrapperRegistry.java +++ b/src/main/java/com/actiontech/dble/net/ssl/SSLWrapperRegistry.java @@ -11,7 +11,7 @@ public final class SSLWrapperRegistry { - protected static final Map SSL_CONTEXT_REGISTRY = Maps.newHashMap(); + protected static final Map SSL_CONTEXT_REGISTRY = Maps.newHashMap(); static { register(OpenSSLWrapper.PROTOCOL, new OpenSSLWrapper()); @@ -21,11 +21,11 @@ public final class SSLWrapperRegistry { private SSLWrapperRegistry() { } - static void register(int protocol, OpenSSLWrapper sslWrapper) { + static void register(int protocol, IOpenSSLWrapper sslWrapper) { SSL_CONTEXT_REGISTRY.put(protocol, sslWrapper); } - public static OpenSSLWrapper getInstance(int protocol) { + public static IOpenSSLWrapper getInstance(int protocol) { return SSL_CONTEXT_REGISTRY.get(protocol); } diff --git a/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLBackAuthService.java b/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLBackAuthService.java index 692a875326..67cb508a0f 100644 --- a/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLBackAuthService.java +++ b/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLBackAuthService.java @@ -17,10 +17,8 @@ import com.actiontech.dble.net.ConnectionException; import com.actiontech.dble.net.connection.BackendConnection; import com.actiontech.dble.net.mysql.*; -import com.actiontech.dble.net.service.AuthResultInfo; -import com.actiontech.dble.net.service.AuthService; -import com.actiontech.dble.net.service.ServiceTask; -import com.actiontech.dble.net.service.WriteFlags; +import com.actiontech.dble.net.service.*; +import com.actiontech.dble.net.ssl.OpenSSLWrapper; import com.actiontech.dble.services.BackendService; import com.actiontech.dble.services.factorys.BusinessServiceFactory; import com.actiontech.dble.services.mysqlsharding.MySQLResponseService; @@ -167,7 +165,13 @@ private void handleHandshake(byte[] data) { String serverPlugin = new String(handshakePacket.getAuthPluginName()); try { pluginName = PluginName.valueOf(serverPlugin); - sendAuthPacket(++data[3]); + if (connection.isSupportSSL()) {//todo:move config in dbinstance scope + sendSSLRequestPacket(++data[3]); + connection.sendSSLHandShake(OpenSSLWrapper.PROTOCOL); + } else { + sendAuthPacket(++data[3]); + } + } catch (IllegalArgumentException | NoSuchAlgorithmException e) { String authPluginErrorMessage = "Client don't support the password plugin " + serverPlugin + ",please check the default auth Plugin"; throw new RuntimeException(authPluginErrorMessage); @@ -188,6 +192,17 @@ private void sendAuthPacket(byte packetId) throws NoSuchAlgorithmException { packet.bufferWrite(connection); } + private void sendSSLRequestPacket(byte packetId) throws NoSuchAlgorithmException { + SSLRequestPacket packet = new SSLRequestPacket(); + packet.setPacketId(packetId); + packet.setMaxPacketSize(SystemConfig.getInstance().getMaxPacketSize()); + int charsetIndex = CharsetUtil.getCharsetDefaultIndex(SystemConfig.getInstance().getCharset()); + packet.setCharsetIndex(charsetIndex); + packet.setClientFlags(getClientFlagSha()); + packet.bufferWrite(connection); + } + + private void sendSwitchResponse(byte[] authData, byte packetId) { AuthSwitchResponsePackage packet = new AuthSwitchResponsePackage(); packet.setAuthPluginData(authData); @@ -255,7 +270,7 @@ protected void handleDataError(Exception e) { } } - private static long initClientFlags() { + private long initClientFlags() { int flag = 0; flag |= Capabilities.CLIENT_LONG_PASSWORD; boolean isEnableCapClientFoundRows = CapClientFoundRows.getInstance().isEnableCapClientFoundRows(); @@ -274,6 +289,9 @@ private static long initClientFlags() { flag |= Capabilities.CLIENT_IGNORE_SPACE; flag |= Capabilities.CLIENT_PROTOCOL_41; flag |= Capabilities.CLIENT_INTERACTIVE; + if (connection.isSupportSSL()) { + flag |= Capabilities.CLIENT_SSL; + } // flag |= Capabilities.CLIENT_SSL; flag |= Capabilities.CLIENT_IGNORE_SIGPIPE; flag |= Capabilities.CLIENT_TRANSACTIONS; @@ -323,4 +341,35 @@ protected Executor getExecutor() { return DbleServer.getInstance().getBackendExecutor(); } } + + @Override + public void consumeSingleTask(ServiceTask serviceTask) { + //The close packet can't be filtered + if (beforeHandlingTask(serviceTask) || (serviceTask.getType() == ServiceTaskType.CLOSE)) { + if (serviceTask.getType() == ServiceTaskType.NORMAL) { + final byte[] data = ((NormalServiceTask) serviceTask).getOrgData(); + handleInnerData(data); + } else if (serviceTask.getType() == ServiceTaskType.SSL) { + final byte[] data = ((SSLProtoServerTask) serviceTask).getOrgData(); + handleSSLProtoData(data); + } else { + handleSpecialInnerData((InnerServiceTask) serviceTask); + } + } + afterDispatchTask(serviceTask); + } + + private void handleSSLProtoData(byte[] data) { + final boolean prevSslHandshakeSuccess = this.connection.isSSLHandshakeSuccess(); + connection.doSSLHandShake(data); + if (this.connection.isSSLHandshakeSuccess() && !prevSslHandshakeSuccess) { + try { + final int sslHandshakeResponsePacketId = 2; + sendAuthPacket((byte) sslHandshakeResponsePacketId); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + } + } diff --git a/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLFrontAuthService.java b/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLFrontAuthService.java index 96f604da7f..f1627b7a7e 100644 --- a/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLFrontAuthService.java +++ b/src/main/java/com/actiontech/dble/services/mysqlauthenticate/MySQLFrontAuthService.java @@ -297,7 +297,7 @@ private int getServerCapabilities() { flag |= Capabilities.CLIENT_IGNORE_SPACE; flag |= Capabilities.CLIENT_PROTOCOL_41; flag |= Capabilities.CLIENT_INTERACTIVE; - if (SystemConfig.getInstance().isSupportSSL()) { + if (connection.isSupportSSL()) { flag |= Capabilities.CLIENT_SSL; } flag |= Capabilities.CLIENT_IGNORE_SIGPIPE; diff --git a/src/main/java/com/actiontech/dble/singleton/SystemParams.java b/src/main/java/com/actiontech/dble/singleton/SystemParams.java index f4b8de131a..563257afaf 100644 --- a/src/main/java/com/actiontech/dble/singleton/SystemParams.java +++ b/src/main/java/com/actiontech/dble/singleton/SystemParams.java @@ -152,9 +152,12 @@ private SystemParams() { readOnlyParams.add(new ParamInfo("routePenetrationRules", sysConfig.getRoutePenetrationRules() + "", "The config of route penetration.The default value is ''")); readOnlyParams.add(new ParamInfo("enableSessionActiveRatioStat", FrontActiveRatioStat.getInstance().isEnable() ? "1" : "0", "Whether frontend connection activity ratio statistics are enabled. The default value is 1.")); readOnlyParams.add(new ParamInfo("enableConnectionAssociateThread", ConnectionAssociateThreadManager.getInstance().isEnable() ? "1" : "0", "Whether to open frontend connection and backend connection are associated with threads. The default value is 1.")); - readOnlyParams.add(new ParamInfo("isSupportSSL", SystemConfig.getInstance().isSupportSSL() + "", "isSupportSSL in configuration")); + readOnlyParams.add(new ParamInfo("isSupportFrontSSL", SystemConfig.getInstance().isSupportFrontSSL() + "", "isSupportFrontSSL in configuration")); + readOnlyParams.add(new ParamInfo("isSupportBackSSL", SystemConfig.getInstance().isSupportBackSSL() + "", "isSupportBackSSL in configuration")); + readOnlyParams.add(new ParamInfo("getBackendMode", SystemConfig.getInstance().getBackendMode() + "", "getBackendMode in configuration")); readOnlyParams.add(new ParamInfo("isSupportOpenSSL", (SSLWrapperRegistry.getInstance(OpenSSLWrapper.PROTOCOL) != null) + "", "Whether OpenSSL is actually supported")); readOnlyParams.add(new ParamInfo("serverCertificateKeyStoreUrl", SystemConfig.getInstance().getServerCertificateKeyStoreUrl() + "", "Service certificate required of OpenSSL")); + readOnlyParams.add(new ParamInfo("clientCertificateKeyStoreUrl", SystemConfig.getInstance().getClientCertificateKeyStoreUrl() + "", "client certificate required of OpenSSL")); readOnlyParams.add(new ParamInfo("trustCertificateKeyStoreUrl", SystemConfig.getInstance().getTrustCertificateKeyStoreUrl() + "", "Trust certificate required of OpenSSL")); readOnlyParams.add(new ParamInfo("isSupportGMSSL", (SSLWrapperRegistry.getInstance(GMSslWrapper.PROTOCOL) != null) + "", "Whether GMSSL is actually supported")); readOnlyParams.add(new ParamInfo("gmsslBothPfx", SystemConfig.getInstance().getGmsslBothPfx() + "", "National secret dual certificate/private key file in PFX format")); diff --git a/src/test/java/com/actiontech/dble/optimizer/FakeConnection.java b/src/test/java/com/actiontech/dble/optimizer/FakeConnection.java index 235dd43804..7d309e52c7 100644 --- a/src/test/java/com/actiontech/dble/optimizer/FakeConnection.java +++ b/src/test/java/com/actiontech/dble/optimizer/FakeConnection.java @@ -11,6 +11,7 @@ import com.actiontech.dble.net.connection.AbstractConnection; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.NetworkChannel; /** @@ -53,4 +54,8 @@ public void businessClose(String reason) { } + @Override + protected void handleNonSSL(ByteBuffer dataBuffer) throws IOException { + + } }