diff --git a/.changelog/123.txt b/.changelog/123.txt new file mode 100644 index 0000000..ad06534 --- /dev/null +++ b/.changelog/123.txt @@ -0,0 +1,3 @@ +```release-note:bug +TCP, TLS outputs: Close connections gracefully to avoid data loss. +``` diff --git a/internal/output/tcp/tcp.go b/internal/output/tcp/tcp.go index 602922b..4228310 100644 --- a/internal/output/tcp/tcp.go +++ b/internal/output/tcp/tcp.go @@ -6,6 +6,8 @@ package tcp import ( "context" + "errors" + "io" "net" "time" @@ -18,7 +20,7 @@ func init() { type Output struct { opts *output.Options - conn net.Conn + conn *net.TCPConn } func New(opts *output.Options) (output.Output, error) { @@ -33,7 +35,7 @@ func (o *Output) DialContext(ctx context.Context) error { return err } - o.conn = conn + o.conn = conn.(*net.TCPConn) return nil } @@ -42,10 +44,29 @@ func (o *Output) Conn() net.Conn { } func (o *Output) Close() error { - if o.conn == nil { - return nil + if o.conn != nil { + if err := o.conn.CloseWrite(); err != nil { + return err + } + + // drain to facilitate graceful close on the other side + deadline := time.Now().Add(5 * time.Second) + if err := o.conn.SetReadDeadline(deadline); err != nil { + return err + } + buffer := make([]byte, 1024) + for { + _, err := o.conn.Read(buffer) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + } + + return o.conn.Close() } - return o.conn.Close() + return nil } func (o *Output) Write(b []byte) (int, error) { diff --git a/internal/output/tls/tls.go b/internal/output/tls/tls.go index ada2fc6..f146ade 100644 --- a/internal/output/tls/tls.go +++ b/internal/output/tls/tls.go @@ -7,6 +7,8 @@ package tcp import ( "context" "crypto/tls" + "errors" + "io" "net" "time" @@ -19,7 +21,7 @@ func init() { type Output struct { opts *output.Options - conn net.Conn + conn *tls.Conn } func New(opts *output.Options) (output.Output, error) { @@ -39,12 +41,31 @@ func (o *Output) DialContext(ctx context.Context) error { return err } - o.conn = conn + o.conn = conn.(*tls.Conn) return nil } func (o *Output) Close() error { if o.conn != nil { + if err := o.conn.CloseWrite(); err != nil { + return err + } + + // drain to facilitate graceful close on the other side + deadline := time.Now().Add(5 * time.Second) + if err := o.conn.SetReadDeadline(deadline); err != nil { + return err + } + buffer := make([]byte, 1024) + for { + _, err := o.conn.Read(buffer) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + } + return o.conn.Close() } return nil