diff --git a/Android/app/src/go/intra/split/direct_split.go b/Android/app/src/go/intra/split/direct_split.go index 16e3e730..88b31f8c 100644 --- a/Android/app/src/go/intra/split/direct_split.go +++ b/Android/app/src/go/intra/split/direct_split.go @@ -54,13 +54,8 @@ func (s *splitter) Write(b []byte) (int, error) { // Setting `used` to true ensures that this code only runs once per socket. s.used = true - b1, b2 := splitHello(b) - n1, err := conn.Write(b1) - if err != nil { - return n1, err - } - n2, err := conn.Write(b2) - return n1 + n2, err + n, _, err := splitHello(b, conn) + return n, err } func (s *splitter) ReadFrom(reader io.Reader) (bytes int64, err error) { diff --git a/Android/app/src/go/intra/split/retrier.go b/Android/app/src/go/intra/split/retrier.go index 7367ec9e..07e9a208 100644 --- a/Android/app/src/go/intra/split/retrier.go +++ b/Android/app/src/go/intra/split/retrier.go @@ -16,6 +16,7 @@ package split import ( "context" + "encoding/binary" "errors" "io" "localhost/Intra/Android/app/src/go/logging" @@ -165,8 +166,9 @@ func (r *retrier) Read(buf []byte) (n int, err error) { } // Read failed. Retry. n, err = r.retry(buf) + } else { + logging.Debug("SplitRetry(retrier.Read) - direct conn succeeded, no need to split") } - logging.Debug("SplitRetry(retrier.Read) - direct conn succeeded, no need to split") close(r.retryCompleteFlag) // Unset read deadline. r.conn.SetReadDeadline(time.Time{}) @@ -186,14 +188,12 @@ func (r *retrier) retry(buf []byte) (n int, err error) { return } r.conn = newConn.(*net.TCPConn) - first, second := splitHello(r.hello) - r.stats.Split = int16(len(first)) - if _, err = r.conn.Write(first); err != nil { - return - } - if _, err = r.conn.Write(second); err != nil { + _, split, err := splitHello(r.hello, r.conn) + r.stats.Split = int16(split) + if err != nil { return } + // While we were creating the new socket, the caller might have called CloseRead // or CloseWrite on the old socket. Copy that state to the new socket. // CloseRead and CloseWrite are idempotent, so this is safe even if the user's @@ -219,22 +219,77 @@ func (r *retrier) CloseRead() error { return r.conn.CloseRead() } -func splitHello(hello []byte) ([]byte, []byte) { - if len(hello) == 0 { - return hello, hello +func getTLSClientHelloRecordLen(h []byte) (uint16, bool) { + if len(h) < 5 { + return 0, false + } + + const ( + TYPE_HANDSHAKE byte = 22 + VERSION_TLS10 uint16 = 0x0301 + VERSION_TLS11 uint16 = 0x0302 + VERSION_TLS12 uint16 = 0x0303 + VERSION_TLS13 uint16 = 0x0304 + ) + + if h[0] != TYPE_HANDSHAKE { + return 0, false + } + + ver := binary.BigEndian.Uint16(h[1:3]) + if ver != VERSION_TLS10 && ver != VERSION_TLS11 && + ver != VERSION_TLS12 && ver != VERSION_TLS13 { + return 0, false + } + + return binary.BigEndian.Uint16(h[3:5]), true +} + +func splitHello(hello []byte, w io.Writer) (n int, splitLen int, err error) { + if len(hello) <= 1 { + n, err = w.Write(hello) + return } + const ( - MIN_SPLIT int = 32 + MIN_SPLIT int = 6 MAX_SPLIT int = 64 ) // Random number in the range [MIN_SPLIT, MAX_SPLIT] - s := MIN_SPLIT + rand.Intn(MAX_SPLIT+1-MIN_SPLIT) + // splitLen includes 5 bytes of TLS header + splitLen = MIN_SPLIT + rand.Intn(MAX_SPLIT+1-MIN_SPLIT) limit := len(hello) / 2 - if s > limit { - s = limit + if splitLen > limit { + splitLen = limit } - return hello[:s], hello[s:] + + recordLen, ok := getTLSClientHelloRecordLen(hello) + recordSplitLen := splitLen - 5 + if !ok || recordSplitLen <= 0 || recordSplitLen >= int(recordLen) { + // Do TCP split if hello is not a valid TLS Client Hello, or cannot be fragmented + n, err = w.Write(hello[:splitLen]) + if err == nil { + var m int + m, err = w.Write(hello[splitLen:]) + n += m + } + return + } + + pkt := hello[:splitLen] + binary.BigEndian.PutUint16(pkt[3:5], uint16(recordSplitLen)) + if n, err = w.Write(pkt); err != nil { + return + } + + pkt = hello[splitLen-5:] + copy(pkt, hello[:5]) + binary.BigEndian.PutUint16(pkt[3:5], recordLen-uint16(recordSplitLen)) + var m int + m, err = w.Write(pkt) + n += m + return } // Write-related functions diff --git a/Android/app/src/go/intra/split/retrier_test.go b/Android/app/src/go/intra/split/retrier_test.go index c68c542a..d9de633e 100644 --- a/Android/app/src/go/intra/split/retrier_test.go +++ b/Android/app/src/go/intra/split/retrier_test.go @@ -194,7 +194,7 @@ func (s *setup) checkStats(bytes int32, chunks int16, timeout bool) { if r.Timeout != timeout { s.t.Errorf("Expected timeout to be %t", timeout) } - if r.Split < 32 || r.Split > 64 { + if r.Split < 6 || r.Split > 64 { s.t.Errorf("Unexpected split: %d", r.Split) } }