diff --git a/host/host.go b/host/host.go index 9cad584bf..89f60404a 100644 --- a/host/host.go +++ b/host/host.go @@ -25,6 +25,10 @@ import ( ) func NewPromptingHostKeyCallback(stdin io.Reader, stdout io.Writer, knownHostsFilename string) (ssh.HostKeyCallback, error) { + if err := createFileIfNotExist(knownHostsFilename); err != nil { + return nil, err + } + cb, err := knownhosts.New(knownHostsFilename) if err != nil { return nil, err @@ -299,3 +303,17 @@ func (c *Host) Run(ctx context.Context) error { func keyType(t string) string { return strings.ToUpper(strings.TrimPrefix(t, "ssh-")) } + +func createFileIfNotExist(file string) error { + _, err := os.Stat(file) + if os.IsNotExist(err) { + file, err := os.Create(file) + if err != nil { + return err + } + + defer file.Close() + } + + return nil +} diff --git a/host/host_test.go b/host/host_test.go index b97306a5b..07c38c100 100644 --- a/host/host_test.go +++ b/host/host_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net" "os" + "path/filepath" "strings" "testing" @@ -16,6 +17,41 @@ const ( testPublicKey = `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIN0EWrjdcHcuMfI8bGAyHPcGsAc/vd/gl5673pRkRBGY` ) +func Test_hostKeyCallbackKnowHostsFileNotExist(t *testing.T) { + dir, err := ioutil.TempDir("", "test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + knownHostsFile := filepath.Join(dir, "known_hosts") + + stdin := bytes.NewBufferString("yes\n") // Simulate typing "yes" in stdin + stdout := bytes.NewBuffer(nil) + + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey)) + if err != nil { + t.Fatal(err) + } + fp := utils.FingerprintSHA256(pk) + + cb, err := NewPromptingHostKeyCallback(stdin, stdout, knownHostsFile) + if err != nil { + t.Fatal(err) + } + + addr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 22, + } + if err := cb("127.0.0.1:22", addr, pk); err != nil { + t.Fatal(err) + } + if !strings.Contains(stdout.String(), "ED25519 key fingerprint is "+fp) { + t.Fatalf("stdout should contain fingerprint %s: %s", fp, stdout) + } +} + func Test_hostKeyCallback(t *testing.T) { tmpfile, err := ioutil.TempFile("", "known_hosts") if err != nil {