Skip to content

Commit

Permalink
Merge pull request #11 from gsanchietti/db-refactor
Browse files Browse the repository at this point in the history
Db refactor
  • Loading branch information
gsanchietti authored Apr 16, 2018
2 parents 8ecf770 + 26cb49c commit 424fa20
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 114 deletions.
3 changes: 1 addition & 2 deletions athos/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ func BulkSetValidSystems() (bool, []string) {
func getValidSystems() []models.System {
var systems []models.System

db := database.Database()
db := database.Instance()
db.Preload("Subscription.SubscriptionPlan").Joins("JOIN subscriptions ON systems.subscription_id = subscriptions.id").Where("valid_until > NOW()").Find(&systems)
db.Close()

return systems
}
14 changes: 12 additions & 2 deletions athos/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,18 @@ import (
"github.com/nethesis/dartagnan/athos/configuration"
)

func Database() *gorm.DB {
db, err := gorm.Open("postgres", "sslmode=disable dbname="+configuration.Config.Database.Name+" host="+configuration.Config.Database.Host+" port="+configuration.Config.Database.Port+" user="+configuration.Config.Database.User+" password="+configuration.Config.Database.Password)
var db *gorm.DB
var err error

func Instance() *gorm.DB {
if db == nil {
Init()
}
return db
}

func Init() *gorm.DB {
db, err = gorm.Open("postgres", "sslmode=disable dbname="+configuration.Config.Database.Name+" host="+configuration.Config.Database.Host+" port="+configuration.Config.Database.Port+" user="+configuration.Config.Database.User+" password="+configuration.Config.Database.Password)
if configuration.Config.Log.Level == "debug" {
db.LogMode(true)
}
Expand Down
5 changes: 5 additions & 0 deletions athos/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/gin-gonic/gin"

"github.com/nethesis/dartagnan/athos/configuration"
"github.com/nethesis/dartagnan/athos/database"
"github.com/nethesis/dartagnan/athos/methods"
"github.com/nethesis/dartagnan/athos/middleware"
)
Expand All @@ -40,6 +41,10 @@ func main() {
flag.Parse()
configuration.Init(ConfigFilePtr)

// Init database
db := database.Init()
defer db.Close()

// init routers
router := gin.Default()
if configuration.Config.Log.Level == "debug" {
Expand Down
37 changes: 13 additions & 24 deletions athos/methods/alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ import (

func alertExists(SystemID int, AlertID string) (bool, models.Alert) {
var alert models.Alert
db := database.Database()
db := database.Instance()
db.Where("alert_id = ? AND system_id = ?", AlertID, SystemID).First(&alert)
db.Close()

if alert.ID == 0 {
return false, models.Alert{}
Expand All @@ -53,9 +52,8 @@ func alertExists(SystemID int, AlertID string) (bool, models.Alert) {

func cleanupStaleAlerts(creatorID string, systemID string) {
var alerts []models.Alert
db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).Find(&alerts)
db.Close()

for _, alert := range alerts {
// do not reset backup, raid and wan alerts
Expand Down Expand Up @@ -83,7 +81,7 @@ func cleanupStaleAlerts(creatorID string, systemID string) {
notifications.AlertNotification(alert, false)

// save to history
db := database.Database()
db := database.Instance()
if err := db.Save(&alertHistory).Error; err != nil {
fmt.Printf("[ERROR] Alert not moved to history: %d\n", alert.AlertID)
}
Expand Down Expand Up @@ -143,7 +141,7 @@ func SetAlert(c *gin.Context) {
notifications.AlertNotification(toSend, false)

// save to history
db := database.Database()
db := database.Instance()
if err := db.Save(&alertHistory).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "alert not moved to history", "error": err.Error()})
return
Expand All @@ -161,7 +159,7 @@ func SetAlert(c *gin.Context) {
alert.Status = json.Status

// save alert
db := database.Database()
db := database.Instance()
if err := db.Save(&alert).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "alert not updated", "error": err.Error()})
return
Expand All @@ -184,7 +182,6 @@ func SetAlert(c *gin.Context) {
return
}

db.Close()
}
} else {
if json.Status == "INIT" {
Expand All @@ -210,7 +207,7 @@ func SetAlert(c *gin.Context) {
}

// save alert
db := database.Database()
db := database.Instance()
if err := db.Save(&alert).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "alert not saved", "error": err.Error()})
return
Expand All @@ -220,7 +217,6 @@ func SetAlert(c *gin.Context) {
alert.NameI18n = utils.GetAlertHumanName(alert.AlertID, "en-US")
notifications.AlertNotification(alert, true)

db.Close()
}

c.JSON(http.StatusOK, gin.H{"status": "success"})
Expand All @@ -242,7 +238,7 @@ func UpdateAlertNote(c *gin.Context) {
return
}

db := database.Database()
db := database.Instance()
db.Where("id = ? AND system_id = ?", alertID, json.SystemID).First(&alert)

if alert.ID == 0 {
Expand All @@ -252,7 +248,6 @@ func UpdateAlertNote(c *gin.Context) {

alert.Note = json.Note
db.Save(&alert)
db.Close()

c.JSON(http.StatusOK, gin.H{"status": "success"})
}
Expand All @@ -267,9 +262,8 @@ func GetAlerts(c *gin.Context) {
limit := c.Query("limit")
offsets := utils.OffsetCalc(page, limit)

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).Find(&alerts)
db.Close()

for _, alert := range alerts {
if utils.CanAccessAlerts(alert.System.Subscription.SubscriptionPlan) {
Expand All @@ -291,10 +285,9 @@ func GetAlerts(c *gin.Context) {

func getSystemsByCreator(creatorID string) []models.System {
var systems []models.System
db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", false)
db.Select("systems.id").Where("creator_id = ?", creatorID).Find(&systems)
db.Close()

return systems
}
Expand All @@ -305,9 +298,8 @@ func getSystemHostname(systemID int) string {
}

var result Result
db := database.Database()
db := database.Instance()
db.Raw("SELECT inventories.data->'networking'->>'fqdn' AS hostname FROM inventories WHERE system_id = ?", systemID).Scan(&result)
db.Close()

return result.Hostname
}
Expand All @@ -328,9 +320,8 @@ func GetAllAlerts(c *gin.Context) {
systemIds = append(systemIds, system.ID)
}

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id IN (?)", systemIds).Find(&alerts)
db.Close()

for _, alert := range alerts {
if utils.CanAccessAlerts(alert.System.Subscription.SubscriptionPlan) {
Expand Down Expand Up @@ -360,9 +351,8 @@ func GetAlertHistories(c *gin.Context) {
limit := c.Query("limit")
offsets := utils.OffsetCalc(page, limit)

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).Offset(offsets[0]).Limit(offsets[1]).Find(&alertHistories)
db.Close()

c.JSON(http.StatusOK, alertHistories)
}
Expand All @@ -373,7 +363,7 @@ func DeleteAlert(c *gin.Context) {

alertID := c.Param("alert_id")

db := database.Database()
db := database.Instance()
db.Where("id = ?", alertID).First(&alert)

if alert.ID == 0 {
Expand All @@ -391,7 +381,6 @@ func DeleteAlert(c *gin.Context) {
return
}

db.Close()

c.JSON(http.StatusOK, gin.H{"status": "success"})
}
14 changes: 4 additions & 10 deletions athos/methods/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ func GetVatPercentage(customerCountry string, customerVat string) int {
return 0
}

db := database.Database()
db := database.Instance()
db.Where("country = ?", customerCountry).First(&tax)
defer db.Close()

// Customer is from non-UE countries, no VAT applied
if tax.Country == "Other" {
Expand All @@ -67,12 +66,10 @@ func GetBilling(c *gin.Context) {
var billing models.Billing
creatorID := c.MustGet("authUser").(string)

db := database.Database()
db := database.Instance()
db.Where("creator_id = ?", creatorID).First(&billing)
defer db.Close()

if billing.ID == 0 {
db.Close()
c.JSON(http.StatusNotFound, gin.H{"message": "no billing information found!"})
return
}
Expand Down Expand Up @@ -101,8 +98,7 @@ func CreateBilling(c *gin.Context) {
Vat: json.Vat,
}

db := database.Database()
defer db.Close()
db := database.Instance()
if err := db.Create(&billing).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "billing not saved", "error": err.Error()})
return
Expand All @@ -121,12 +117,10 @@ func UpdateBilling(c *gin.Context) {
return
}

db := database.Database()
db := database.Instance()
db.Where("creator_id = ?", creatorID).First(&billing)
defer db.Close()

if billing.ID == 0 {
db.Close()
c.JSON(http.StatusNotFound, gin.H{"message": "no billing found!"})
return
}
Expand Down
12 changes: 4 additions & 8 deletions athos/methods/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ import (

func heartbeatExists(SystemID int) (bool, models.Heartbeat) {
var heartbeat models.Heartbeat
db := database.Database()
db := database.Instance()
db.Where("system_id = ?", SystemID).First(&heartbeat)
db.Close()

if heartbeat.ID == 0 {
return false, models.Heartbeat{}
Expand All @@ -64,13 +63,12 @@ func SetHeartbeat(c *gin.Context) {
heartbeat.Timestamp = time.Now().UTC()

// save current heartbeat
db := database.Database()
db := database.Instance()
if err := db.Save(&heartbeat).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "heartbeat not updated", "error": err.Error()})
return
}

db.Close()
} else {
// create heartbeat
heartbeat := models.Heartbeat{
Expand All @@ -79,12 +77,11 @@ func SetHeartbeat(c *gin.Context) {
}

// save new heartbeat
db := database.Database()
db := database.Instance()
if err := db.Save(&heartbeat).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "heartbeat not saved", "error": err.Error()})
return
}
db.Close()
}

c.JSON(http.StatusOK, gin.H{"status": "success"})
Expand All @@ -95,9 +92,8 @@ func GetHeartbeat(c *gin.Context) {
creatorID := c.MustGet("authUser").(string)
systemID := c.Param("system_id")

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).First(&heartbeat)
db.Close()

if heartbeat.ID == 0 {
c.JSON(http.StatusNotFound, gin.H{"message": "no heartbeat found!"})
Expand Down
15 changes: 4 additions & 11 deletions athos/methods/inventory.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ import (

func inventoryExists(SystemID int) (bool, models.Inventory) {
var inventory models.Inventory
db := database.Database()
db := database.Instance()
db.Where("system_id = ?", SystemID).First(&inventory)
db.Close()

if inventory.ID == 0 {
return false, models.Inventory{}
Expand All @@ -58,11 +57,10 @@ func SetInventory(c *gin.Context) {
system := utils.GetSystemFromUUID(json.Data.SystemID)

// prepare the db for all queries
db := database.Database()
db := database.Instance()

if err := db.Model(&system).Where("uuid = ?", json.Data.SystemID).Update("PublicIP", c.ClientIP()).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "PublicIP not updated", "error": err.Error()})
db.Close()
return
}

Expand All @@ -89,7 +87,6 @@ func SetInventory(c *gin.Context) {
// save current inventory
if err := db.Save(&inventory).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "inventory not updated", "error": err.Error()})
db.Close()
return
}

Expand All @@ -104,12 +101,10 @@ func SetInventory(c *gin.Context) {
// save new inventory
if err := db.Save(&inventory).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": "inventory not saved", "error": err.Error()})
db.Close()
return
}
}

db.Close()
c.JSON(http.StatusOK, gin.H{"status": "success"})
}

Expand All @@ -118,9 +113,8 @@ func GetInventory(c *gin.Context) {
creatorID := c.MustGet("authUser").(string)
systemID := c.Param("system_id")

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).First(&inventory)
db.Close()

if inventory.ID == 0 {
c.JSON(http.StatusNotFound, gin.H{"message": "no inventory found!"})
Expand All @@ -139,9 +133,8 @@ func GetInventoryHistories(c *gin.Context) {
limit := c.Query("limit")
offsets := utils.OffsetCalc(page, limit)

db := database.Database()
db := database.Instance()
db.Set("gorm:auto_preload", true).Preload("System", "creator_id = ?", creatorID).Where("system_id = ?", systemID).Offset(offsets[0]).Limit(offsets[1]).Find(&inventoryHistories)
db.Close()

if len(inventoryHistories) <= 0 {
c.JSON(http.StatusNotFound, gin.H{"message": "no inventory histories found!"})
Expand Down
3 changes: 1 addition & 2 deletions athos/methods/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ import (
func GetSubscriptionPlans(c *gin.Context) {
var subscriptionPlans []models.SubscriptionPlan

db := database.Database()
db := database.Instance()
db.Find(&subscriptionPlans)
defer db.Close()

if len(subscriptionPlans) <= 0 {
c.JSON(http.StatusNotFound, gin.H{"message": "no subscription plans found!"})
Expand Down
Loading

0 comments on commit 424fa20

Please sign in to comment.