From 4750331007ad4a8465e01b8309b0211f33e3ea80 Mon Sep 17 00:00:00 2001 From: Lewis Bayfield Date: Fri, 11 Sep 2020 17:18:33 +0100 Subject: [PATCH] Fix for overriding struct when running associations --- example/feature_demo/demo_types.pb.gorm.go | 6 +- example/user/user.pb.gorm.go | 4 +- plugin/handlergen.go | 85 ++++++++++++---------- 3 files changed, 52 insertions(+), 43 deletions(-) diff --git a/example/feature_demo/demo_types.pb.gorm.go b/example/feature_demo/demo_types.pb.gorm.go index 97b88f00..405c63d1 100644 --- a/example/feature_demo/demo_types.pb.gorm.go +++ b/example/feature_demo/demo_types.pb.gorm.go @@ -4047,7 +4047,7 @@ func DefaultStrictUpdateTestAssocHandlerReplace(ctx context.Context, in *TestAss return nil, err } } - if err = db.Model(&ormObj).Association("TestTagAssoc").Replace(ormObj.TestTagAssoc).Error; err != nil { + if err = db.Model(&TestAssocHandlerReplaceORM{Id: ormObj.Id}).Association("TestTagAssoc").Replace(ormObj.TestTagAssoc).Error; err != nil { return nil, err } ormObj.TestTagAssoc = nil @@ -4404,7 +4404,7 @@ func DefaultStrictUpdateTestAssocHandlerClear(ctx context.Context, in *TestAssoc return nil, err } } - if err = db.Model(&ormObj).Association("TestTagAssoc").Clear().Error; err != nil { + if err = db.Model(&TestAssocHandlerClearORM{Id: ormObj.Id}).Association("TestTagAssoc").Clear().Error; err != nil { return nil, err } ormObj.TestTagAssoc = nil @@ -4761,7 +4761,7 @@ func DefaultStrictUpdateTestAssocHandlerAppend(ctx context.Context, in *TestAsso return nil, err } } - if err = db.Model(&ormObj).Association("TestTagAssoc").Append(ormObj.TestTagAssoc).Error; err != nil { + if err = db.Model(&TestAssocHandlerAppendORM{Id: ormObj.Id}).Association("TestTagAssoc").Append(ormObj.TestTagAssoc).Error; err != nil { return nil, err } ormObj.TestTagAssoc = nil diff --git a/example/user/user.pb.gorm.go b/example/user/user.pb.gorm.go index 5e1ca3cd..0edee7d2 100644 --- a/example/user/user.pb.gorm.go +++ b/example/user/user.pb.gorm.go @@ -1061,11 +1061,11 @@ func DefaultStrictUpdateUser(ctx context.Context, in *User, db *gorm1.DB) (*User if err = db.Where(filterEmails).Delete(EmailORM{}).Error; err != nil { return nil, err } - if err = db.Model(&ormObj).Association("Friends").Replace(ormObj.Friends).Error; err != nil { + if err = db.Model(&UserORM{Id: ormObj.Id}).Association("Friends").Replace(ormObj.Friends).Error; err != nil { return nil, err } ormObj.Friends = nil - if err = db.Model(&ormObj).Association("Languages").Replace(ormObj.Languages).Error; err != nil { + if err = db.Model(&UserORM{Id: ormObj.Id}).Association("Languages").Replace(ormObj.Languages).Error; err != nil { return nil, err } ormObj.Languages = nil diff --git a/plugin/handlergen.go b/plugin/handlergen.go index a088df4b..7ea25372 100644 --- a/plugin/handlergen.go +++ b/plugin/handlergen.go @@ -796,43 +796,8 @@ func (p *OrmPlugin) handleChildAssociationsByName(message *generator.Descriptor, return } - if field.GetHasMany() != nil || field.GetHasOne() != nil || field.GetManyToMany() != nil { - var assocHandler string - switch { - case field.GetHasMany() != nil: - switch { - case field.GetHasMany().GetClear(): - assocHandler = "Clear" - case field.GetHasMany().GetAppend(): - assocHandler = "Append" - case field.GetHasMany().GetReplace(): - assocHandler = "Replace" - default: - assocHandler = "Remove" - } - case field.GetHasOne() != nil: - switch { - case field.GetHasOne().GetClear(): - assocHandler = "Clear" - case field.GetHasOne().GetAppend(): - assocHandler = "Append" - case field.GetHasOne().GetReplace(): - assocHandler = "Replace" - default: - assocHandler = "Remove" - } - case field.GetManyToMany() != nil: - switch { - case field.GetManyToMany().GetClear(): - assocHandler = "Clear" - case field.GetManyToMany().GetAppend(): - assocHandler = "Append" - case field.GetManyToMany().GetReplace(): - assocHandler = "Replace" - default: - assocHandler = "Replace" - } - } + if getHasOneMany(field) { + assocHandler := getAssociationHandlerType(field) if assocHandler == "Remove" { p.removeChildAssociationsByName(message, fieldName) @@ -844,13 +809,57 @@ func (p *OrmPlugin) handleChildAssociationsByName(message *generator.Descriptor, action = fmt.Sprintf("%s()", assocHandler) } - p.P(`if err = db.Model(&ormObj).Association("`, fieldName, `").`, action, `.Error; err != nil {`) + k, _ := p.findPrimaryKey(ormable) + p.P(`if err = db.Model(&`, ormable.Name, `{`, k, ` : ormObj.`, k, `}).Association("`, fieldName, `").`, action, `.Error; err != nil {`) p.P(`return nil, err`) p.P(`}`) p.P(`ormObj.`, fieldName, ` = nil`) } } +func getHasOneMany(field *Field) bool { + return field.GetHasMany() != nil || field.GetHasOne() != nil || field.GetManyToMany() != nil +} + +func getAssociationHandlerType(field *Field) (assocHandler string) { + switch { + case field.GetHasMany() != nil: + switch { + case field.GetHasMany().GetClear(): + assocHandler = "Clear" + case field.GetHasMany().GetAppend(): + assocHandler = "Append" + case field.GetHasMany().GetReplace(): + assocHandler = "Replace" + default: + assocHandler = "Remove" + } + case field.GetHasOne() != nil: + switch { + case field.GetHasOne().GetClear(): + assocHandler = "Clear" + case field.GetHasOne().GetAppend(): + assocHandler = "Append" + case field.GetHasOne().GetReplace(): + assocHandler = "Replace" + default: + assocHandler = "Remove" + } + case field.GetManyToMany() != nil: + switch { + case field.GetManyToMany().GetClear(): + assocHandler = "Clear" + case field.GetManyToMany().GetAppend(): + assocHandler = "Append" + case field.GetManyToMany().GetReplace(): + assocHandler = "Replace" + default: + assocHandler = "Replace" + } + } + return +} + func (p *OrmPlugin) removeChildAssociationsByName(message *generator.Descriptor, fieldName string) { ormable := p.getOrmable(p.TypeName(message)) field := ormable.Fields[fieldName]