diff --git a/starttls/ftp.go b/starttls/ftp.go index fb23acf..b178d71 100644 --- a/starttls/ftp.go +++ b/starttls/ftp.go @@ -24,41 +24,81 @@ import ( "strconv" ) -func dumpTLSConnStateFromFTP(dialer Dialer, address string, config *tls.Config) (*tls.ConnectionState, error) { - c, err := dialer.Dial("tcp", address) +type FTPCtx struct { + tcpConn *net.TCPConn + tlsConn *tls.Conn + dialFunc func(dialer Dialer, address string) (net.Conn, error) +} + +func dumpTLSConnStateFromFTP(dialer Dialer, address string, config *tls.Config, explicitTLS bool) (*tls.ConnectionState, error) { + + ctx := FTPCtx{} + + if explicitTLS { + ctx.dialFunc = func(dialer Dialer, address string) (net.Conn, error) { + return dialer.Dial("tcp", address) + } + } else { + ctx.dialFunc = func(dialer Dialer, address string) (net.Conn, error) { + tlsDialer := &tls.Dialer{ + NetDialer: dialer.(*net.Dialer), + Config: config, + } + return tlsDialer.Dial("tcp", address) + } + } + + c, err := ctx.dialFunc(dialer, address) if err != nil { return nil, err } - conn := c.(*net.TCPConn) + if _, err = checkServiceReady(c); err != nil { + return nil, err + } + + if explicitTLS { + ctx.tcpConn = c.(*net.TCPConn) + } else { + ctx.tlsConn = c.(*tls.Conn) + } + + if explicitTLS { + if _, err := authTLS(ctx.tcpConn); err != nil { + return nil, err + } + ctx.tlsConn = tls.Client(ctx.tcpConn, config) + ctx.tlsConn.Handshake() + } + + state := ctx.tlsConn.ConnectionState() + return &state, nil +} + +func checkServiceReady(conn net.Conn) (int, error) { status, err := readFTP(conn) if err != nil { - return nil, err + return status, err } if status != 220 { - return nil, fmt.Errorf("FTP server responded with status %d, was expecting 220", status) + return status, fmt.Errorf("FTP server responded with status %d, was expecting 220", status) } + return status, nil +} +func authTLS(conn *net.TCPConn) (int, error) { fmt.Fprintf(conn, "AUTH TLS\r\n") - status, err = readFTP(conn) + status, err := readFTP(conn) if err != nil { - return nil, err + return status, err } if status != 234 { - return nil, fmt.Errorf("FTP server responded with status %d, was expecting 234", status) + return status, fmt.Errorf("FTP server responded with status %d, was expecting 234", status) } - - tlsConn := tls.Client(conn, config) - err = tlsConn.Handshake() - if err != nil { - return nil, err - } - - state := tlsConn.ConnectionState() - return &state, nil + return status, nil } -func readFTP(conn *net.TCPConn) (int, error) { +func readFTP(conn net.Conn) (int, error) { reader := bufio.NewReader(conn) response, err := reader.ReadString('\n') if err != nil { diff --git a/starttls/starttls.go b/starttls/starttls.go index 4f722e6..4c7279e 100644 --- a/starttls/starttls.go +++ b/starttls/starttls.go @@ -33,7 +33,7 @@ import ( ) // Protocols are the names of supported protocols -var Protocols = []string{"mysql", "postgres", "psql", "smtp", "ldap", "ftp", "imap"} +var Protocols = []string{"mysql", "postgres", "psql", "smtp", "ldap", "ftp", "ftps", "imap"} type connectResult struct { state *tls.ConnectionState @@ -230,7 +230,11 @@ func GetConnectionState(startTLSType, connectName, connectTo, identity, clientCe res <- connectResult{&state, nil} case "ftp": addr := withDefaultPort(connectTo, 21) - state, err = dumpTLSConnStateFromFTP(dialer, addr, tlsConfig) + state, err = dumpTLSConnStateFromFTP(dialer, addr, tlsConfig, true) + res <- connectResult{state, err} + case "ftps": + addr := withDefaultPort(connectTo, 990) + state, err = dumpTLSConnStateFromFTP(dialer, addr, tlsConfig, false) res <- connectResult{state, err} case "imap": addr := withDefaultPort(connectTo, 143)