From e1b69ab07562f9d7ac84787b733e713eccefa06f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrei=20Tudor=20C=C4=83lin?= Date: Sun, 31 Mar 2019 06:32:56 +0200 Subject: [PATCH] import standard library splice tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename "splice" to "transfer", since that's what we're testing. Enable and test unix <- tcp and unix <- unix cases. Updates #2. Signed-off-by: Andrei Tudor Călin --- zerocopy_linux.go | 47 ++++- zerocopy_linux_test.go | 409 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 396 insertions(+), 60 deletions(-) diff --git a/zerocopy_linux.go b/zerocopy_linux.go index c452413..d42bb9a 100644 --- a/zerocopy_linux.go +++ b/zerocopy_linux.go @@ -325,6 +325,11 @@ func (p *Pipe) readFrom(src io.Reader) (int64, error) { waitwrite = false waitreadagain = false ) + if lr != nil { + defer func(v *int64) { + lr.N -= *v + }(&moved) + } again: ok = false max := maxSpliceSize @@ -335,8 +340,10 @@ again: wrcerr = p.wrc.Write(func(pwfd uintptr) bool { var n int n, operr = splice(rfd, pwfd, max) - limit -= int64(n) - moved += int64(n) + if n > 0 { + limit -= int64(n) + moved += int64(n) + } if operr == unix.EINVAL { fallback = true return true @@ -383,8 +390,10 @@ again: rrcerr = rrc.Read(func(rfd uintptr) bool { var n int n, operr = splice(rfd, pwfd, max) - limit -= int64(n) - moved += int64(n) + if n > 0 { + limit -= int64(n) + moved += int64(n) + } if operr == unix.EAGAIN { if writeready { waitreadagain = true @@ -449,7 +458,9 @@ again: wrcerr = wrc.Write(func(pwfd uintptr) bool { var n int n, operr = splice(rfd, pwfd, maxSpliceSize) - moved += int64(n) + if n > 0 { + moved += int64(n) + } if operr == unix.EINVAL { fallback = true return true @@ -493,7 +504,9 @@ again: rrcerr = p.rrc.Read(func(rfd uintptr) bool { var n int n, operr = splice(rfd, pwfd, maxSpliceSize) - moved += int64(n) + if n > 0 { + moved += int64(n) + } if operr == unix.EAGAIN { if writeready { waitreadagain = true @@ -571,6 +584,11 @@ func transfer(dst io.Writer, src io.Reader) (int64, error) { } var moved int64 = 0 + if lr != nil { + defer func(v *int64) { + lr.N -= *v + }(&moved) + } for limit > 0 { max := maxSpliceSize if int64(max) > limit { @@ -584,8 +602,13 @@ func transfer(dst io.Writer, src io.Reader) (int64, error) { if inpipe == 0 && err == nil { return moved, nil } + if err != nil { + return moved, err + } n, fallback, err := splicePump(wrc, p, inpipe) - moved += int64(n) + if n > 0 { + moved += int64(n) + } if fallback { // dst doesn't support splicing, but we've already // read from src, so we need to empty the pipe, @@ -624,6 +647,7 @@ func spliceDrain(p *Pipe, rrc syscall.RawConn, max int) (int, bool, error) { if serr == unix.EAGAIN { return false } + serr = os.NewSyscallError("splice", serr) return true }) return true @@ -646,11 +670,13 @@ func splicePump(wrc syscall.RawConn, p *Pipe, inpipe int) (int, bool, error) { ) again: err := p.rrc.Read(func(prfd uintptr) bool { - wrcerr = wrc.Read(func(wfd uintptr) bool { + wrcerr = wrc.Write(func(wfd uintptr) bool { var n int n, serr = splice(prfd, wfd, inpipe) - moved += int(n) - inpipe -= int(n) + if n > 0 { + moved += int(n) + inpipe -= int(n) + } if serr == unix.EINVAL { fallback = true return true @@ -658,6 +684,7 @@ again: if serr == unix.EAGAIN { return false } + serr = os.NewSyscallError("splice", serr) return true }) return true diff --git a/zerocopy_linux_test.go b/zerocopy_linux_test.go index 93979c8..a8fdf5e 100644 --- a/zerocopy_linux_test.go +++ b/zerocopy_linux_test.go @@ -10,9 +10,14 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" + "os" + "os/exec" + "strconv" "sync" "testing" + "time" "acln.ro/zerocopy" ) @@ -291,97 +296,403 @@ func TestWriteTo(t *testing.T) { } } +// TODO(acln): add test cases where we transfer to and from pipes + +// TODO(acln): add test cases for ReadFrom and WriteTo with more combinations of +// blocking source and destination file descriptors. + func TestTransfer(t *testing.T) { - ln, err := net.Listen("tcp", "localhost:0") + t.Run("tcp-to-tcp", func(t *testing.T) { testTransfer(t, "tcp", "tcp") }) + t.Run("unix-to-tcp", func(t *testing.T) { testTransfer(t, "unix", "tcp") }) + t.Run("tcp-to-unix", func(t *testing.T) { testTransfer(t, "tcp", "unix") }) + t.Run("unix-to-unix", func(t *testing.T) { testTransfer(t, "unix", "unix") }) +} + +func testTransfer(t *testing.T, upNet, downNet string) { + t.Run("simple", transferTestCase{upNet, downNet, 128, 128, 0}.test) + t.Run("multipleWrite", transferTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) + t.Run("big", transferTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) + t.Run("honorsLimitedReader", transferTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) + t.Run("updatesLimitedReaderN", transferTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) + t.Run("limitedReaderAtLimit", transferTestCase{upNet, downNet, 32, 128, 128}.test) + t.Run("readerAtEOF", func(t *testing.T) { testTransferReaderAtEOF(t, upNet, downNet) }) + t.Run("issue25985", func(t *testing.T) { testTransferIssue25985(t, upNet, downNet) }) +} + +type transferTestCase struct { + upNet, downNet string + + chunkSize, totalSize int + limitReadSize int +} + +func (tc transferTestCase) test(t *testing.T) { + clientUp, serverUp, err := transferTestSocketPair(tc.upNet) if err != nil { t.Fatal(err) } - defer ln.Close() + defer serverUp.Close() + cleanup, err := startTransferClient(clientUp, "w", tc.chunkSize, tc.totalSize) + if err != nil { + t.Fatal(err) + } + defer cleanup() + clientDown, serverDown, err := transferTestSocketPair(tc.downNet) + if err != nil { + t.Fatal(err) + } + defer serverDown.Close() + cleanup, err = startTransferClient(clientDown, "r", tc.chunkSize, tc.totalSize) + if err != nil { + t.Fatal(err) + } + defer cleanup() + var ( + r io.Reader = serverUp + size = tc.totalSize + ) + if tc.limitReadSize > 0 { + if tc.limitReadSize < size { + size = tc.limitReadSize + } + + r = &io.LimitedReader{ + N: int64(tc.limitReadSize), + R: serverUp, + } + defer serverUp.Close() + } + n, err := zerocopy.Transfer(serverDown, r) + serverDown.Close() + if err != nil { + t.Fatal(err) + } + if want := int64(size); want != n { + t.Errorf("want %d bytes transfered, got %d", want, n) + } + + if tc.limitReadSize > 0 { + wantN := 0 + if tc.limitReadSize > size { + wantN = tc.limitReadSize - size + } + + if n := r.(*io.LimitedReader).N; n != int64(wantN) { + t.Errorf("r.N = %d, want %d", n, wantN) + } + } +} - var clientup net.Conn - var uperr error +func testTransferReaderAtEOF(t *testing.T, upNet, downNet string) { + clientUp, serverUp, err := transferTestSocketPair(upNet) + if err != nil { + t.Fatal(err) + } + defer clientUp.Close() + clientDown, serverDown, err := transferTestSocketPair(downNet) + if err != nil { + t.Fatal(err) + } + defer clientDown.Close() - updone := make(chan struct{}) + serverUp.Close() + msg := "bye" go func() { - clientup, uperr = net.Dial(ln.Addr().Network(), ln.Addr().String()) - close(updone) + zerocopy.Transfer(serverDown, serverUp) + io.WriteString(serverDown, msg) + serverDown.Close() }() - serverup, err := ln.Accept() + buf := make([]byte, 3) + _, err = io.ReadFull(clientDown, buf) if err != nil { - t.Fatal(err) + t.Errorf("clientDown: %v", err) + } + if string(buf) != msg { + t.Errorf("clientDown got %q, want %q", buf, msg) } - defer serverup.Close() +} - <-updone - if uperr != nil { +func testTransferIssue25985(t *testing.T, upNet, downNet string) { + front, err := newLocalListener(upNet) + if err != nil { + t.Fatal(err) + } + defer front.Close() + back, err := newLocalListener(downNet) + if err != nil { t.Fatal(err) } + defer back.Close() - var clientdown net.Conn - var downerr error + var wg sync.WaitGroup + wg.Add(2) - downdone := make(chan struct{}) + proxy := func() { + src, err := front.Accept() + if err != nil { + return + } + dst, err := net.Dial(downNet, back.Addr().String()) + if err != nil { + return + } + defer dst.Close() + defer src.Close() + go func() { + zerocopy.Transfer(src, dst) + wg.Done() + }() + go func() { + zerocopy.Transfer(dst, src) + wg.Done() + }() + } - go func() { - clientdown, downerr = net.Dial(ln.Addr().Network(), ln.Addr().String()) - close(downdone) - }() + go proxy() - serverdown, err := ln.Accept() + toFront, err := net.Dial(upNet, front.Addr().String()) if err != nil { t.Fatal(err) } - defer serverdown.Close() - <-downdone - if downerr != nil { + io.WriteString(toFront, "foo") + toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { t.Fatal(err) } + defer fromProxy.Close() - msg := "hello world" + _, err = ioutil.ReadAll(fromProxy) + if err != nil { + t.Fatal(err) + } - var werr error + wg.Wait() +} - wdone := make(chan struct{}) +func BenchmarkTransfer(b *testing.B) { + b.Run("tcp-to-tcp", func(b *testing.B) { benchTransfer(b, "tcp", "tcp") }) + b.Run("unix-to-tcp", func(b *testing.B) { benchTransfer(b, "unix", "tcp") }) + b.Run("tcp-to-unix", func(b *testing.B) { benchTransfer(b, "tcp", "unix") }) + b.Run("unix-to-unix", func(b *testing.B) { benchTransfer(b, "unix", "unix") }) +} - go func() { - _, werr = clientup.Write([]byte(msg)) - clientup.Close() - close(wdone) - }() +func benchTransfer(b *testing.B, upNet, downNet string) { + for i := 0; i <= 10; i++ { + chunkSize := 1 << uint(i+10) + tc := transferTestCase{ + upNet: upNet, + downNet: downNet, + chunkSize: chunkSize, + } + + b.Run(strconv.Itoa(chunkSize), tc.bench) + } +} + +func (tc transferTestCase) bench(b *testing.B) { + // To benchmark the generic transfer code path, set this to false. + useTransfer := true + + clientUp, serverUp, err := transferTestSocketPair(tc.upNet) + if err != nil { + b.Fatal(err) + } + defer serverUp.Close() + + cleanup, err := startTransferClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) + if err != nil { + b.Fatal(err) + } + defer cleanup() - var rerr error + clientDown, serverDown, err := transferTestSocketPair(tc.downNet) + if err != nil { + b.Fatal(err) + } + defer serverDown.Close() - rdone := make(chan struct{}) + cleanup, err = startTransferClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) + if err != nil { + b.Fatal(err) + } + defer cleanup() + + b.SetBytes(int64(tc.chunkSize)) + b.ResetTimer() + if useTransfer { + _, err := zerocopy.Transfer(serverDown, serverUp) + if err != nil { + b.Fatal(err) + } + } else { + type onlyReader struct { + io.Reader + } + _, err := io.Copy(serverDown, onlyReader{serverUp}) + if err != nil { + b.Fatal(err) + } + } +} + +func transferTestSocketPair(network string) (client, server net.Conn, err error) { + ln, err := newLocalListener(network) + if err != nil { + return nil, nil, err + } + defer ln.Close() + var cerr, serr error + acceptDone := make(chan struct{}) go func() { - defer close(rdone) - buf := make([]byte, len(msg)) - _, rerr = io.ReadFull(clientdown, buf) - if rerr != nil { - return + server, serr = ln.Accept() + acceptDone <- struct{}{} + }() + client, cerr = net.Dial(ln.Addr().Network(), ln.Addr().String()) + <-acceptDone + if cerr != nil { + if server != nil { + server.Close() } - if string(buf) != msg { - rerr = fmt.Errorf("got %q, want %q", string(buf), msg) + return nil, nil, cerr + } + if serr != nil { + if client != nil { + client.Close() } + return nil, nil, serr + } + return client, server, nil +} + +func startTransferClient(conn net.Conn, op string, chunkSize, totalSize int) (func(), error) { + f, err := conn.(interface{ File() (*os.File, error) }).File() + if err != nil { + return nil, err + } + + cmd := exec.Command(os.Args[0], os.Args[1:]...) + cmd.Env = []string{ + "ZEROCOPY_NET_TEST=1", + "ZEROCOPY_NET_TEST_OP=" + op, + "ZEROCOPY_NET_TEST_CHUNK_SIZE=" + strconv.Itoa(chunkSize), + "ZEROCOPY_NET_TEST_TOTAL_SIZE=" + strconv.Itoa(totalSize), + } + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, err + } + + donec := make(chan struct{}) + go func() { + cmd.Wait() + conn.Close() + f.Close() + close(donec) }() - _, err = zerocopy.Transfer(serverdown, serverup) + return func() { + select { + case <-donec: + case <-time.After(5 * time.Second): + log.Printf("killing transfer client after 5 second shutdown timeout") + cmd.Process.Kill() + select { + case <-donec: + case <-time.After(5 * time.Second): + log.Printf("transfer client didn't die after 10 seconds") + } + } + }, nil +} + +func init() { + if os.Getenv("ZEROCOPY_NET_TEST") == "" { + return + } + defer os.Exit(0) + + f := os.NewFile(uintptr(3), "transfer-test-conn") + defer f.Close() + + conn, err := net.FileConn(f) if err != nil { - t.Fatal(err) + log.Fatal(err) } - <-wdone - <-rdone - if werr != nil { - t.Error(werr) + + var chunkSize int + if chunkSize, err = strconv.Atoi(os.Getenv("ZEROCOPY_NET_TEST_CHUNK_SIZE")); err != nil { + log.Fatal(err) } - if rerr != nil { - t.Error(rerr) + buf := make([]byte, chunkSize) + + var totalSize int + if totalSize, err = strconv.Atoi(os.Getenv("ZEROCOPY_NET_TEST_TOTAL_SIZE")); err != nil { + log.Fatal(err) + } + + var fn func([]byte) (int, error) + switch op := os.Getenv("ZEROCOPY_NET_TEST_OP"); op { + case "r": + fn = conn.Read + case "w": + defer conn.Close() + + fn = conn.Write + default: + log.Fatalf("unknown op %q", op) + } + + var n int + for count := 0; count < totalSize; count += n { + if count+chunkSize > totalSize { + buf = buf[:totalSize-count] + } + + var err error + if n, err = fn(buf); err != nil { + return + } } } +func newLocalListener(network string) (net.Listener, error) { + switch network { + case "tcp": + if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil { + return ln, nil + } + return net.Listen("tcp6", "[::1]:0") + case "tcp4": + return net.Listen("tcp4", "127.0.0.1:0") + case "tcp6": + return net.Listen("tcp6", "[::1]:0") + case "unix", "unixpacket": + return net.Listen(network, testUnixAddr()) + } + return nil, fmt.Errorf("%s is not supported", network) +} + +// testUnixAddr uses ioutil.TempFile to get a name that is unique. +func testUnixAddr() string { + f, err := ioutil.TempFile("", "zerocopy-nettest") + if err != nil { + panic(err) + } + addr := f.Name() + f.Close() + os.Remove(addr) + return addr +} + func TestSetBufferSize(t *testing.T) { n := 32 * 4096 p, err := zerocopy.NewPipe() @@ -401,5 +712,3 @@ func TestSetBufferSize(t *testing.T) { t.Fatalf("got %d, want %d", got, n) } } - -// TODO(acln): grab splice tests and benchmarks from the stdlib, use them here