diff --git a/cmd/clone.go b/cmd/clone.go index 478172b..fd1ede3 100644 --- a/cmd/clone.go +++ b/cmd/clone.go @@ -43,6 +43,7 @@ to quickly create a Cobra application.`, dump_name, _ := cmd.Flags().GetString("dump_name") databaseToCloneName, _ := cmd.Flags().GetString("from") importDatabaseName, _ := cmd.Flags().GetString("to") + shouldSwitchDb, _ := cmd.Flags().GetBool("switch") dump_name = strings.TrimRight(dump_name, ".sql") @@ -52,22 +53,49 @@ to quickly create a Cobra application.`, log.Fatalf("error opening file: %s", ferr) } - Clone(file, databaseToCloneName, importDatabaseName, dump_dir, dump_name) + cloner := &DBCloner{ + file: file, + cloneFrom: databaseToCloneName, + cloneTo: importDatabaseName, + dumpDir: dump_dir, + dumpName: dump_name, + } + + if (shouldSwitchDb) { + path, _ := os.Getwd() + cloner.CloneAndSwitch(path) + } else { + cloner.Clone() + } file.Close() }, } -func Clone(file *os.File, databaseToCloneName string, importDatabaseName string, dump_dir string, dump_name string) { +type DBCloner struct { + file *os.File + cloneFrom string + cloneTo string + dumpDir string + dumpName string +} + +func (this *DBCloner) CloneAndSwitch(environmentPath string) { + this.Clone() + + RunSwitch(environmentPath, this.cloneTo) +} + +func (this *DBCloner) Clone() { fmt.Println("Cloning database") - dumpDatabase(file, databaseToCloneName) + dumpDatabase(this.file, this.cloneFrom) - importDatabase(importDatabaseName, dump_dir, dump_name, file) + importDatabase(this.cloneTo, this.dumpDir, this.dumpName, this.file) fmt.Println("Cleaning up") - removeDumpFiles(dump_dir, dump_name) + removeDumpFiles(this.dumpDir, this.dumpName) - fmt.Println(fmt.Sprintf("%s successfully cloned from %s", importDatabaseName, databaseToCloneName)) + fmt.Println(fmt.Sprintf("%s successfully cloned from %s", this.cloneTo, this.cloneFrom)) } func removeDumpFiles(dump_dir string, dump_name string) { @@ -132,7 +160,7 @@ func addDumpToStdin(importCmd *exec.Cmd, file *os.File) { _, err = io.WriteString(stdin, string(bytes)) if err != nil { - log.Fatal(err.Error()) + log.Fatal(err) } stdin.Close() @@ -162,4 +190,5 @@ func init() { cloneCmd.Flags().String("dump_name", "dump", "Specify the name of the dump file") cloneCmd.Flags().String("to", "cloned", "Specify name of new database") cloneCmd.Flags().String("from", viper.GetString("database.database"), "Specify name of database to clone") + cloneCmd.Flags().Bool("switch", true, "Specify whether to switch databases in environment") } diff --git a/cmd/clone_test.go b/cmd/clone_test.go index 8df8f79..c64ca2e 100644 --- a/cmd/clone_test.go +++ b/cmd/clone_test.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/spf13/viper" + "io/ioutil" "log" "os" "testing" @@ -10,30 +11,73 @@ import ( ) func TestCopyExistingDatabase(t *testing.T) { - db, err := sql.Open("mysql", "root:secret@tcp(localhost:3310)/original") + setUpOriginalDatabase() + setUpConfiguration() + ioutil.WriteFile("../test_fixtures/.env", []byte("DB_DATABASE=fakedb1"), 0644) + + file, _ := os.OpenFile("./test_dump.sql", os.O_RDWR|os.O_CREATE, 0644) + + cloner := &DBCloner{ + file: file, + cloneFrom: "original", + cloneTo: "new_db", + dumpDir: ".", + dumpName: "test_dump", + } - if (err != nil) { - log.Fatalf("could not connect to database") + cloner.Clone() + + err, databases := getListOfDatabases() + + _, ok := databases["new_db"] + + if !ok { + log.Fatalln("Expected to find new_db in list of databases") } - _, err = db.Exec("drop table if exists test;") - _, err = db.Exec("drop database if exists new_db;") - _, err = db.Exec("create table test (column1 int not 1null, column2 int not null);") - _, err = db.Exec("insert into test (column1, column2) VALUE (1,2);") + _, err = os.Stat("./test_dump.sql") - db.Close() + if err == nil { + log.Fatalln("dump file was not removed") + } - file, _ := os.OpenFile("./test_dump.sql", os.O_RDWR|os.O_CREATE, 0644) + _, err = os.Stat("./test_dump.sql.bak") - viper.Set("database.database", "original") - viper.Set("database.host", "127.0.0.1") - viper.Set("database.username", "root") - viper.Set("database.password", "secret") - viper.Set("database.port", "3310") + if err == nil { + log.Fatalln("dump backup file was not removed") + } - Clone(file, "original", "new_db", ".", "test_dump") + contents, _ := ioutil.ReadFile("../test_fixtures/.env") + if string(contents) != "DB_DATABASE=fakedb1" { + log.Fatalf("content of env were incorrectly changed, received %s", string(contents)) + } +} - db, _ = sql.Open("mysql", "root:secret@tcp(localhost:3310)/") +func TestDatabaseIsSwitchedInEnvFile(t *testing.T) { + setUpOriginalDatabase() + setUpConfiguration() + ioutil.WriteFile("../test_fixtures/.env", []byte("DB_DATABASE=original"), 0644) + + file, _ := os.OpenFile("./test_dump.sql", os.O_RDWR|os.O_CREATE, 0644) + + cloner := &DBCloner{ + file: file, + cloneFrom: "original", + cloneTo: "new_db", + dumpDir: ".", + dumpName: "test_dump", + } + + cloner.CloneAndSwitch("../test_fixtures/") + + contents, _ := ioutil.ReadFile("../test_fixtures/.env") + if string(contents) != "DB_DATABASE=new_db" { + log.Fatalf("content of env were not changed, received %s", string(contents)) + } +} + +func getListOfDatabases() (error, map[string]int) { + db, err := sql.Open("mysql", "root:secret@tcp(localhost:3310)/") rows, _ := db.Query("Show databases") @@ -45,22 +89,28 @@ func TestCopyExistingDatabase(t *testing.T) { databases[db] = 1 } + return err, databases +} - _, ok := databases["new_db"] - - if (!ok) { - log.Fatalln("Expected to find new_db in list of databases") - } +func setUpConfiguration() { + viper.Set("database.database", "original") + viper.Set("database.host", "127.0.0.1") + viper.Set("database.username", "root") + viper.Set("database.password", "secret") + viper.Set("database.port", "3310") +} - _, err = os.Stat("./test_dump.sql") +func setUpOriginalDatabase() { + db, err := sql.Open("mysql", "root:secret@tcp(localhost:3310)/original") - if (err == nil) { - log.Fatalln("dump file was not removed") + if err != nil { + log.Fatalf("could not connect to database") } - _, err = os.Stat("./test_dump.sql.bak") + _, err = db.Exec("drop table if exists test;") + _, err = db.Exec("drop database if exists new_db;") + _, err = db.Exec("create table test (column1 int not null, column2 int not null);") + _, err = db.Exec("insert into test (column1, column2) VALUE (1,2);") - if (err == nil) { - log.Fatalln("dump backup file was not removed") - } + db.Close() }