Skip to content

Commit

Permalink
Merge pull request #41 from jingweno/fix_no_such_file_host
Browse files Browse the repository at this point in the history
Create known_hosts file if it doesn't exist
  • Loading branch information
owenthereal authored May 28, 2020
2 parents 7aee1fe + 00476d2 commit 07063d3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
18 changes: 18 additions & 0 deletions host/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ import (
)

func NewPromptingHostKeyCallback(stdin io.Reader, stdout io.Writer, knownHostsFilename string) (ssh.HostKeyCallback, error) {
if err := createFileIfNotExist(knownHostsFilename); err != nil {
return nil, err
}

cb, err := knownhosts.New(knownHostsFilename)
if err != nil {
return nil, err
Expand Down Expand Up @@ -299,3 +303,17 @@ func (c *Host) Run(ctx context.Context) error {
func keyType(t string) string {
return strings.ToUpper(strings.TrimPrefix(t, "ssh-"))
}

func createFileIfNotExist(file string) error {
_, err := os.Stat(file)
if os.IsNotExist(err) {
file, err := os.Create(file)
if err != nil {
return err
}

defer file.Close()
}

return nil
}
36 changes: 36 additions & 0 deletions host/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/ioutil"
"net"
"os"
"path/filepath"
"strings"
"testing"

Expand All @@ -16,6 +17,41 @@ const (
testPublicKey = `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIN0EWrjdcHcuMfI8bGAyHPcGsAc/vd/gl5673pRkRBGY`
)

func Test_hostKeyCallbackKnowHostsFileNotExist(t *testing.T) {
dir, err := ioutil.TempDir("", "test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)

knownHostsFile := filepath.Join(dir, "known_hosts")

stdin := bytes.NewBufferString("yes\n") // Simulate typing "yes" in stdin
stdout := bytes.NewBuffer(nil)

pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
t.Fatal(err)
}
fp := utils.FingerprintSHA256(pk)

cb, err := NewPromptingHostKeyCallback(stdin, stdout, knownHostsFile)
if err != nil {
t.Fatal(err)
}

addr := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 22,
}
if err := cb("127.0.0.1:22", addr, pk); err != nil {
t.Fatal(err)
}
if !strings.Contains(stdout.String(), "ED25519 key fingerprint is "+fp) {
t.Fatalf("stdout should contain fingerprint %s: %s", fp, stdout)
}
}

func Test_hostKeyCallback(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "known_hosts")
if err != nil {
Expand Down

0 comments on commit 07063d3

Please sign in to comment.