From ccf00d105db9b4ea1816dbaadda8681ee1db0824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ant=C3=B3nio=20Meireles?= Date: Thu, 14 Jan 2016 12:05:50 +0000 Subject: [PATCH] simplify a bit VM's boot handling. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: António Meireles --- globals.go | 3 -- halt.go | 5 ++-- helpers.go | 30 ++++++++----------- run.go | 86 +++++++++++++++++++----------------------------------- 4 files changed, 45 insertions(+), 79 deletions(-) diff --git a/globals.go b/globals.go index 36ecb0b..d3e7ae0 100644 --- a/globals.go +++ b/globals.go @@ -16,7 +16,6 @@ package main import ( - "sync" "time" "github.com/spf13/viper" @@ -47,10 +46,8 @@ type ( PublicIP string CreatedAt time.Time publicIP chan string - wg sync.WaitGroup errch chan error done chan bool - exitStatus error } // NetworkInterface ... NetworkInterface struct { diff --git a/halt.go b/halt.go index c8376bb..fd6b93d 100644 --- a/halt.go +++ b/halt.go @@ -81,7 +81,7 @@ func (vm VMInfo) halt() (err error) { } if p, ee := os.FindProcess(vm.Pid); ee == nil { log.Println("hard kill...") - if err = p.Kill(); err != nil { + if err = p.Signal(os.Interrupt); err != nil { return } } @@ -102,9 +102,10 @@ func (vm VMInfo) halt() (err error) { } // wait until it's _really_ dead, but not forever for { + timeout := time.After(30 * time.Second) select { // unmounts may take a bit in slow drives... - case <-time.After(30 * time.Second): + case <-timeout: return fmt.Errorf(fmt.Sprintf("'%s' didn't shutdown normally "+ "after 30s (!)... ", vm.Name)) case <-time.Tick(100 * time.Millisecond): diff --git a/helpers.go b/helpers.go index 5423e3e..c998b42 100644 --- a/helpers.go +++ b/helpers.go @@ -327,17 +327,12 @@ func (vm *VMInfo) isActive() bool { func (vm *VMInfo) metadataService() (endpoint string, err error) { var ( - free net.Listener - sentCC, sentSSHk, foundGuestIP sync.Once - mux, root = http.NewServeMux(), "/" + vm.Name - rIP = func(s string) string { - return strings.Split(s, ":")[0] - } - isAllowed = func(origin string, w http.ResponseWriter) bool { + free net.Listener + foundGuestIP sync.Once + mux, root = http.NewServeMux(), "/" + vm.Name + rIP = func(s string) string { return strings.Split(s, ":")[0] } + isAllowed = func(origin string, w http.ResponseWriter) bool { if strings.HasPrefix(origin, "192.168.64.") { - foundGuestIP.Do(func() { - vm.publicIP <- origin - }) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) return true @@ -352,21 +347,18 @@ func (vm *VMInfo) metadataService() (endpoint string, err error) { return } - vm.wg.Add(1) - if vm.CloudConfig != "" && vm.CClocation == Local { var txt []byte if txt, err = ioutil.ReadFile(vm.CloudConfig); err != nil { return } - vm.wg.Add(1) mux.HandleFunc(root+"/cloud-config", func(w http.ResponseWriter, r *http.Request) { if isAllowed(rIP(r.RemoteAddr), w) { w.Write(txt) - sentCC.Do(func() { - vm.wg.Done() + foundGuestIP.Do(func() { + vm.publicIP <- rIP(r.RemoteAddr) }) } }) @@ -376,9 +368,11 @@ func (vm *VMInfo) metadataService() (endpoint string, err error) { func(w http.ResponseWriter, r *http.Request) { if isAllowed(rIP(r.RemoteAddr), w) { w.Write([]byte(vm.InternalSSHauthKey)) - sentSSHk.Do(func() { - vm.wg.Done() - }) + if !(vm.CloudConfig != "" && vm.CClocation == Local) { + foundGuestIP.Do(func() { + vm.publicIP <- rIP(r.RemoteAddr) + }) + } } }) mux.HandleFunc(root+"/hostname", diff --git a/run.go b/run.go index 441e78b..989bb18 100644 --- a/run.go +++ b/run.go @@ -106,7 +106,7 @@ func vmBootstrap(args *viper.Viper) (vm *VMInfo, err error) { vm.Cpus = args.GetInt("cpus") vm.Extra = args.GetString("extra") vm.SSHkey = args.GetString("sshkey") - vm.Root, vm.Pid, vm.exitStatus = -1, -1, nil + vm.Root, vm.Pid = -1, -1 vm.Name, vm.UUID = args.GetString("name"), args.GetString("uuid") @@ -225,42 +225,25 @@ func (running *sessionContext) boot(slt int, rawArgs *viper.Viper) (err error) { } go func() { - for { - select { - case <-time.After(30 * time.Second): - if p, ee := os.FindProcess(c.Process.Pid); ee == nil { - p.Kill() - } - vm.errch <- fmt.Errorf("Unable to grab VM's pid and IP after " + - "30s (!)... Aborting") - return - case ip := <-vm.publicIP: - vm.Pid = c.Process.Pid - vm.PublicIP = ip - vm.storeConfig() - close(vm.publicIP) + timeout := time.After(30 * time.Second) + select { + case <-timeout: + if p, ee := os.FindProcess(c.Process.Pid); ee == nil { + p.Signal(os.Interrupt) + } + vm.errch <- fmt.Errorf("Unable to grab VM's IP after " + + "30s (!)... Aborting") + case ip := <-vm.publicIP: + vm.Pid, vm.PublicIP = c.Process.Pid, ip + if err = vm.storeConfig(); err != nil { + vm.errch <- err + } else { if vm.Detached { log.Printf("started '%s' in background with IP %v and "+ "PID %v\n", vm.Name, vm.PublicIP, c.Process.Pid) } - return - default: - time.Sleep(100 * time.Millisecond) - } - } - }() - - go func() { - defer close(vm.done) - for { - select { - case err := <-vm.errch: - if err != nil { - return - } - default: - vm.wg.Wait() - return + close(vm.publicIP) + close(vm.done) } } }() @@ -268,39 +251,30 @@ func (running *sessionContext) boot(slt int, rawArgs *viper.Viper) (err error) { go func() { if !vm.Detached { c.Stdout, c.Stdin, c.Stderr = os.Stdout, os.Stdin, os.Stderr - err = c.Run() + vm.errch <- c.Run() + } else if err = c.Start(); err != nil { + vm.errch <- err } else { - err = c.Start() - go func() { - for { - select { - case <-vm.done: - case <-vm.errch: - return - default: - if err = c.Wait(); err != nil { - log.Println(err) - vm.errch <- fmt.Errorf("VM exited with error" + - "while starting in background") - } - } + select { + default: + if err = c.Wait(); err != nil { + log.Println(err) + vm.errch <- fmt.Errorf("VM exited with error " + + "while attempting to start in background") } - }() + case <-vm.errch: + } } - vm.errch <- err }() for { select { case <-vm.done: if vm.Detached { - return vm.exitStatus - } - case exit := <-vm.errch: - vm.exitStatus = exit - if exit != nil || (vm.PublicIP != "" && vm.Pid != -1) { - return vm.exitStatus + return } + case err = <-vm.errch: + return } time.Sleep(250 * time.Millisecond) }