diff --git a/controller/questionnaire.go b/controller/questionnaire.go index f36d2bfc..fe15a376 100644 --- a/controller/questionnaire.go +++ b/controller/questionnaire.go @@ -171,6 +171,12 @@ func (q Questionnaire) PostQuestionnaire(c echo.Context, userID string, params o return err } + Jq.PushReminder(questionnaireID, params.ResponseDueDateTime) + if err != nil { + c.Logger().Errorf("failed to push reminder: %+v", err) + return err + } + return nil }) if err != nil { @@ -358,6 +364,17 @@ func (q Questionnaire) EditQuestionnaire(c echo.Context, questionnaireID int, pa return err } + err = Jq.DeleteReminder(questionnaireID) + if err != nil { + c.Logger().Errorf("failed to delete reminder: %+v", err) + return err + } + err = Jq.PushReminder(questionnaireID, params.ResponseDueDateTime) + if err != nil { + c.Logger().Errorf("failed to push reminder: %+v", err) + return err + } + return nil }) if err != nil { @@ -483,6 +500,12 @@ func (q Questionnaire) DeleteQuestionnaire(c echo.Context, questionnaireID int) return err } + err = Jq.DeleteReminder(questionnaireID) + if err != nil { + c.Logger().Errorf("failed to delete reminder: %+v", err) + return err + } + return nil }) if err != nil { @@ -498,12 +521,48 @@ func (q Questionnaire) DeleteQuestionnaire(c echo.Context, questionnaireID int) } func (q Questionnaire) GetQuestionnaireMyRemindStatus(c echo.Context, questionnaireID int) (bool, error) { - // todo: check remind status - return false, nil + status, err := Jq.CheckRemindStatus(questionnaireID) + if err != nil { + c.Logger().Errorf("failed to check remind status: %+v", err) + return false, echo.NewHTTPError(http.StatusInternalServerError, "failed to check remind status") + } + + return status, nil } -func (q Questionnaire) EditQuestionnaireMyRemindStatus(c echo.Context, questionnaireID int) error { - // todo: edit remind status +func (q Questionnaire) EditQuestionnaireMyRemindStatus(c echo.Context, questionnaireID int, isRemindEnabled bool) error { + if isRemindEnabled { + status, err := Jq.CheckRemindStatus(questionnaireID) + if err != nil { + c.Logger().Errorf("failed to check remind status: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, "failed to check remind status") + } + if status { + return nil + } + + questionnaire, _, _, _, _, _, err := q.GetQuestionnaireInfo(c.Request().Context(), questionnaireID) + if err != nil { + if errors.Is(err, model.ErrRecordNotFound) { + c.Logger().Info("questionnaire not found") + return echo.NewHTTPError(http.StatusNotFound, "questionnaire not found") + } + c.Logger().Errorf("failed to get questionnaire: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, "failed to get questionnaire") + } + + err = Jq.PushReminder(questionnaireID, &questionnaire.ResTimeLimit.Time) + if err != nil { + c.Logger().Errorf("failed to push reminder: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, "failed to push reminder") + } + } else { + err := Jq.DeleteReminder(questionnaireID) + if err != nil { + c.Logger().Errorf("failed to delete reminder: %+v", err) + return echo.NewHTTPError(http.StatusInternalServerError, "failed to delete reminder") + } + } return nil } @@ -705,3 +764,60 @@ https://anke-to.trap.jp/responses/new/%d`, questionnaireID, ) } + +func createReminderMessage(questionnaireID int, title string, description string, administrators []string, resTimeLimit time.Time, targets []string, leftTimeText string) string { + resTimeLimitText := resTimeLimit.Local().Format("2006/01/02 15:04") + targetsMentionText := "@" + strings.Join(targets, " @") + + return fmt.Sprintf( + `### アンケート『[%s](https://anke-to.trap.jp/questionnaires/%d)』の回答期限が迫っています! +==残り%sです!== +#### 管理者 +%s +#### 説明 +%s +#### 回答期限 +%s +#### 対象者 +%s +#### 回答リンク +https://anke-to.trap.jp/responses/new/%d +`, + title, + questionnaireID, + leftTimeText, + strings.Join(administrators, ","), + description, + resTimeLimitText, + targetsMentionText, + questionnaireID, + ) +} + +func (q Questionnaire) GetQuestionnaireResult(ctx echo.Context, questionnaireID int, userID string) (openapi.Result, error) { + res := openapi.Result{} + + params := openapi.GetQuestionnaireResponsesParams{} + responses, err := q.GetQuestionnaireResponses(ctx, questionnaireID, params, userID) + if err != nil { + if errors.Is(echo.ErrNotFound, err) { + return openapi.Result{}, err + } + ctx.Logger().Errorf("failed to get questionnaire responses: %+v", err) + return openapi.Result{}, echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get questionnaire responses: %w", err)) + } + + for _, response := range responses { + tmp := openapi.ResultItem{ + Body: response.Body, + IsDraft: response.IsDraft, + ModifiedAt: response.ModifiedAt, + QuestionnaireId: response.QuestionnaireId, + ResponseId: response.ResponseId, + SubmittedAt: response.SubmittedAt, + } + res = append(res, tmp) + } + + return res, nil +} diff --git a/controller/reminder.go b/controller/reminder.go new file mode 100644 index 00000000..3da0139f --- /dev/null +++ b/controller/reminder.go @@ -0,0 +1,158 @@ +package controller + +import ( + "context" + "slices" + "sort" + "sync" + "time" + + "github.com/traPtitech/anke-to/model" + "github.com/traPtitech/anke-to/traq" + "golang.org/x/sync/semaphore" +) + +type Job struct { + Timestamp time.Time + QuestionnaireID int + Action func() +} + +type JobQueue struct { + jobs []*Job + mu sync.Mutex +} + +var ( + sem = semaphore.NewWeighted(1) + Jq = &JobQueue{} + Wg = &sync.WaitGroup{} + reminderTimingMinutes = []int{5, 30, 60, 1440, 10080} + reminderTimingStrings = []string{"5分", "30分", "1時間", "1日", "1週間"} +) + +func (jq *JobQueue) Push(job *Job) { + jq.mu.Lock() + defer jq.mu.Unlock() + jq.jobs = append(jq.jobs, job) + sort.Slice(jq.jobs, func(i, j int) bool { + return jq.jobs[i].Timestamp.Before(jq.jobs[j].Timestamp) + }) +} + +func (jq *JobQueue) Pop() *Job { + jq.mu.Lock() + defer jq.mu.Unlock() + if len(jq.jobs) == 0 { + return nil + } + job := jq.jobs[0] + jq.jobs = jq.jobs[1:] + return job +} + +func (jq *JobQueue) PushReminder(questionnaireID int, limit *time.Time) error { + + for i, timing := range reminderTimingMinutes { + remindTimeStamp := limit.Add(-time.Duration(timing) * time.Minute) + if remindTimeStamp.Before(time.Now()) { + Jq.Push(&Job{ + Timestamp: remindTimeStamp, + QuestionnaireID: questionnaireID, + Action: func() { + reminderAction(questionnaireID, reminderTimingStrings[i]) + }, + }) + } + } + + return nil +} + +func (jq *JobQueue) DeleteReminder(questionnaireID int) error { + jq.mu.Lock() + defer jq.mu.Unlock() + if len(jq.jobs) == 1 && jq.jobs[0].QuestionnaireID == questionnaireID { + jq.jobs = []*Job{} + } + for i, job := range jq.jobs { + if job.QuestionnaireID == questionnaireID { + jq.jobs = append(jq.jobs[:i], jq.jobs[i+1:]...) + } + } + + return nil +} + +func (jq *JobQueue) CheckRemindStatus(questionnaireID int) (bool, error) { + jq.mu.Lock() + defer jq.mu.Unlock() + for _, job := range jq.jobs { + if job.QuestionnaireID == questionnaireID { + return true, nil + } + } + return false, nil +} + +func reminderAction(questionnaireID int, leftTimeText string) error { + ctx := context.Background() + q := model.Questionnaire{} + questionnaire, _, _, administrators, _, respondants, err := q.GetQuestionnaireInfo(ctx, questionnaireID) + if err != nil { + return err + } + + var reminderTargets []string + for _, target := range questionnaire.Targets { + if target.IsCanceled { + continue + } + if slices.Contains(respondants, target.UserTraqid) { + continue + } + reminderTargets = append(reminderTargets, target.UserTraqid) + } + + reminderMessage := createReminderMessage(questionnaireID, questionnaire.Title, questionnaire.Description, administrators, questionnaire.ResTimeLimit.Time, reminderTargets, leftTimeText) + wh := traq.NewWebhook() + err = wh.PostMessage(reminderMessage) + if err != nil { + return err + } + + return nil +} + +func ReminderWorker() { + for { + job := Jq.Pop() + if job == nil { + time.Sleep(1 * time.Minute) + continue + } + + if time.Until(job.Timestamp) > 0 { + time.Sleep(time.Until(job.Timestamp)) + } + + Wg.Add(1) + go func() { + defer Wg.Done() + job.Action() + }() + } +} + +func ReminderInit() { + questionnaires, err := model.NewQuestionnaire().GetQuestionnairesInfoForReminder(context.Background()) + if err != nil { + panic(err) + } + for _, questionnaire := range questionnaires { + err := Jq.PushReminder(questionnaire.ID, &questionnaire.ResTimeLimit.Time) + if err != nil { + panic(err) + } + } +} \ No newline at end of file diff --git a/docs/db_schema.md b/docs/db_schema.md index fee0ed48..aafc17f7 100644 --- a/docs/db_schema.md +++ b/docs/db_schema.md @@ -108,3 +108,4 @@ | ---------------- | -------- | ---- | --- | ------- | ----- | -------- | | questionnaire_id | int(11) | NO | PRI | _NULL_ | | user_traqid | char(32) | NO | PRI | _NULL_ | +| is_canceled | boolean | NO | | false | | アンケートの対象者がキャンセルしたかどうか | diff --git a/go.mod b/go.mod index 1ff36d4d..a17726c7 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/oauth2 v0.10.0 - golang.org/x/sync v0.7.0 + golang.org/x/sync v0.8.0 golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/tools v0.21.0 // indirect diff --git a/go.sum b/go.sum index 2c369dc4..0c26be87 100644 --- a/go.sum +++ b/go.sum @@ -575,6 +575,8 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/handler/questionnaire.go b/handler/questionnaire.go index 7e7c0741..d30e9bb7 100644 --- a/handler/questionnaire.go +++ b/handler/questionnaire.go @@ -118,7 +118,7 @@ func (h Handler) GetQuestionnaireMyRemindStatus(ctx echo.Context, questionnaireI status, err := q.GetQuestionnaireMyRemindStatus(ctx, questionnaireID) if err != nil { ctx.Logger().Errorf("failed to get questionnaire my remind status: %+v", err) - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get questionnaire my remind status: %w", err)) + return err } res.IsRemindEnabled = status @@ -127,11 +127,17 @@ func (h Handler) GetQuestionnaireMyRemindStatus(ctx echo.Context, questionnaireI // (PATCH /questionnaires/{questionnaireID}/myRemindStatus) func (h Handler) EditQuestionnaireMyRemindStatus(ctx echo.Context, questionnaireID openapi.QuestionnaireIDInPath) error { + params := openapi.EditQuestionnaireMyRemindStatusJSONRequestBody{} + if err := ctx.Bind(¶ms); err != nil { + ctx.Logger().Errorf("failed to bind request body: %+v", err) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to bind request body: %w", err)) + } + q := controller.NewQuestionnaire() - err := q.EditQuestionnaireMyRemindStatus(ctx, questionnaireID) + err := q.EditQuestionnaireMyRemindStatus(ctx, questionnaireID, params.IsRemindEnabled) if err != nil { ctx.Logger().Errorf("failed to edit questionnaire my remind status: %+v", err) - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to edit questionnaire my remind status: %w", err)) + return err } return ctx.NoContent(200) } diff --git a/main.go b/main.go index adec2505..3778b60b 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" oapiMiddleware "github.com/oapi-codegen/echo-middleware" + "github.com/traPtitech/anke-to/controller" "github.com/traPtitech/anke-to/handler" "github.com/traPtitech/anke-to/model" "github.com/traPtitech/anke-to/openapi" @@ -57,32 +58,51 @@ func main() { panic("no PORT") } - e := echo.New() - swagger, err := openapi.GetSwagger() - if err != nil { - panic(err) - } - e.Use(oapiMiddleware.OapiRequestValidator(swagger)) - e.Use(handler.SetUserIDMiddleware) - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) - - mws := NewMiddlewareSwitcher() - mws.AddGroupConfig("", handler.TraPMemberAuthenticate) - - mws.AddRouteConfig("/questionnaires", http.MethodGet, handler.TrapRateLimitMiddlewareFunc()) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, handler.QuestionnaireReadAuthenticate) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, handler.QuestionnaireAdministratorAuthenticate) - mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, handler.QuestionnaireAdministratorAuthenticate) - - mws.AddRouteConfig("/responses/:responseID", http.MethodGet, handler.ResponseReadAuthenticate) - mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, handler.RespondentAuthenticate) - mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, handler.RespondentAuthenticate) - - openapi.RegisterHandlers(e, handler.Handler{}) - - e.Use(mws.ApplyMiddlewares) - e.Logger.Fatal(e.Start(port)) + controller.Wg.Add(1) + go func() { + e := echo.New() + swagger, err := openapi.GetSwagger() + if err != nil { + panic(err) + } + e.Use(oapiMiddleware.OapiRequestValidator(swagger)) + e.Use(handler.SetUserIDMiddleware) + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + + mws := NewMiddlewareSwitcher() + mws.AddGroupConfig("", handler.TraPMemberAuthenticate) + + mws.AddRouteConfig("/questionnaires", http.MethodGet, handler.TrapRateLimitMiddlewareFunc()) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodGet, handler.QuestionnaireReadAuthenticate) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, handler.QuestionnaireAdministratorAuthenticate) + mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, handler.QuestionnaireAdministratorAuthenticate) + + mws.AddRouteConfig("/responses/:responseID", http.MethodGet, handler.ResponseReadAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, handler.RespondentAuthenticate) + mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, handler.RespondentAuthenticate) + + openapi.RegisterHandlers(e, handler.Handler{}) + + e.Use(mws.ApplyMiddlewares) + e.Logger.Fatal(e.Start(port)) + + controller.Wg.Done() + }() + + controller.Wg.Add(1) + go func () { + controller.ReminderInit() + controller.Wg.Done() + }() + + controller.Wg.Add(1) + go func() { + controller.ReminderWorker() + controller.Wg.Done() + }() + + controller.Wg.Wait() // SetRouting(port) } diff --git a/model/questionnaires.go b/model/questionnaires.go index c9f04c87..225ac674 100644 --- a/model/questionnaires.go +++ b/model/questionnaires.go @@ -21,4 +21,5 @@ type IQuestionnaire interface { GetQuestionnaireLimitByResponseID(ctx context.Context, responseID int) (null.Time, error) GetResponseReadPrivilegeInfoByResponseID(ctx context.Context, userID string, responseID int) (*ResponseReadPrivilegeInfo, error) GetResponseReadPrivilegeInfoByQuestionnaireID(ctx context.Context, userID string, questionnaireID int) (*ResponseReadPrivilegeInfo, error) + GetQuestionnairesInfoForReminder(ctx context.Context) ([]Questionnaires, error) } diff --git a/model/questionnaires_impl.go b/model/questionnaires_impl.go index 25e20cb9..e194e049 100755 --- a/model/questionnaires_impl.go +++ b/model/questionnaires_impl.go @@ -296,6 +296,7 @@ func (*Questionnaire) GetQuestionnaireInfo(ctx context.Context, questionnaireID err = db. Where("questionnaires.id = ?", questionnaireID). + Preload("Targets"). First(&questionnaire).Error if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil, nil, nil, nil, nil, ErrRecordNotFound @@ -378,6 +379,24 @@ func (*Questionnaire) GetTargettedQuestionnaires(ctx context.Context, userID str return questionnaires, nil } +// GetQuestionnairesInfoForReminder 回答期限が7日以内のアンケートの詳細情報の取得 +func (*Questionnaire) GetQuestionnairesInfoForReminder(ctx context.Context) ([]Questionnaires, error) { + db, err := getTx(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tx: %w", err) + } + + questionnaires := []Questionnaires{} + err = db. + Where("res_time_limit > ? AND res_time_limit < ?", time.Now(), time.Now().AddDate(0, 0, 7)). + Find(&questionnaires).Error + if err != nil { + return nil, fmt.Errorf("failed to get the questionnaires: %w", err) + } + + return questionnaires, nil +} + // GetQuestionnaireLimit アンケートの回答期限の取得 func (*Questionnaire) GetQuestionnaireLimit(ctx context.Context, questionnaireID int) (null.Time, error) { db, err := getTx(ctx) diff --git a/model/targets.go b/model/targets.go index a1bb5506..e6d4cc17 100644 --- a/model/targets.go +++ b/model/targets.go @@ -10,4 +10,5 @@ type ITarget interface { DeleteTargets(ctx context.Context, questionnaireID int) error GetTargets(ctx context.Context, questionnaireIDs []int) ([]Targets, error) IsTargetingMe(ctx context.Context, quesionnairID int, userID string) (bool, error) + CancelTargets(ctx context.Context, questionnaireID int, targets []string) error } diff --git a/model/targets_impl.go b/model/targets_impl.go index 6e046217..65196014 100644 --- a/model/targets_impl.go +++ b/model/targets_impl.go @@ -17,6 +17,7 @@ func NewTarget() *Target { type Targets struct { QuestionnaireID int `gorm:"type:int(11) AUTO_INCREMENT;not null;primaryKey"` UserTraqid string `gorm:"type:varchar(32);size:32;not null;primaryKey"` + IsCanceled bool `gorm:"type:tinyint(1);not null;default:0"` } // InsertTargets アンケートの対象を追加 @@ -35,6 +36,7 @@ func (*Target) InsertTargets(ctx context.Context, questionnaireID int, targets [ dbTargets = append(dbTargets, Targets{ QuestionnaireID: questionnaireID, UserTraqid: target, + IsCanceled: false, }) } @@ -101,3 +103,21 @@ func (*Target) IsTargetingMe(ctx context.Context, questionnairID int, userID str } return false, nil } + +// CancelTargets アンケートの対象をキャンセル(削除しない) +func (*Target) CancelTargets(ctx context.Context, questionnaireID int, targets []string) error { + db, err := getTx(ctx) + if err != nil { + return fmt.Errorf("failed to get transaction: %w", err) + } + + err = db. + Model(&Targets{}). + Where("questionnaire_id = ? AND user_traqid IN (?)", questionnaireID, targets). + Update("is_canceled", true).Error + if err != nil { + return fmt.Errorf("failed to cancel targets: %w", err) + } + + return nil +} \ No newline at end of file diff --git a/model/targets_test.go b/model/targets_test.go index 79adf7c6..2cd8f51e 100644 --- a/model/targets_test.go +++ b/model/targets_test.go @@ -376,3 +376,119 @@ func TestIsTargetingMe(t *testing.T) { assertion.Equal(testCase.expect.isTargeted, isTargeted, testCase.description, "isTargeted") } } + +func TestCancelTargets(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type test struct { + description string + beforeValidTargets []string + beforeInvalidTargets []string + afterValidTargets []string + afterInvalidTargets []string + argCancelTargets []string + isErr bool + err error + } + + testCases := []test{ + { + description: "キャンセルするtargetが1人でエラーなし", + beforeValidTargets: []string{"a"}, + beforeInvalidTargets: []string{}, + afterValidTargets: []string{}, + afterInvalidTargets: []string{"a"}, + argCancelTargets: []string{"a"}, + }, + { + description: "キャンセルするtargetが複数でエラーなし", + beforeValidTargets: []string{"a", "b"}, + beforeInvalidTargets: []string{}, + afterValidTargets: []string{}, + afterInvalidTargets: []string{"a", "b"}, + argCancelTargets: []string{"a", "b"}, + }, + { + description: "キャンセルするtargetがないときエラーなし", + beforeValidTargets: []string{"a"}, + beforeInvalidTargets: []string{}, + afterValidTargets: []string{"a"}, + afterInvalidTargets: []string{}, + argCancelTargets: []string{}, + }, + { + description: "キャンセルするtargetが見つからないときエラー", + beforeValidTargets: []string{"a"}, + beforeInvalidTargets: []string{}, + afterValidTargets: []string{"a"}, + afterInvalidTargets: []string{}, + argCancelTargets: []string{"b"}, + isErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.description, func(t *testing.T) { + targets := make([]Targets, 0, len(testCase.beforeValidTargets)+len(testCase.beforeInvalidTargets)) + for _, target := range testCase.beforeValidTargets { + targets = append(targets, Targets{ + UserTraqid: target, + IsCanceled: false, + }) + } + for _, target := range testCase.beforeInvalidTargets { + targets = append(targets, Targets{ + UserTraqid: target, + IsCanceled: true, + }) + } + questionnaire := Questionnaires{ + Targets: targets, + } + err := db. + Session(&gorm.Session{}). + Create(&questionnaire).Error + if err != nil { + t.Errorf("failed to create questionnaire: %v", err) + } + + err = targetImpl.CancelTargets(ctx, questionnaire.ID, testCase.argCancelTargets) + if err != nil { + if !testCase.isErr { + t.Errorf("unexpected error: %v", err) + } else if !errors.Is(err, testCase.err) { + t.Errorf("invalid error: expected: %+v, actual: %+v", testCase.err, err) + } + return + } + + afterTargets := make([]Targets, 0, len(testCase.afterValidTargets)+len(testCase.afterInvalidTargets)) + for _, afterTarget := range testCase.afterInvalidTargets { + afterTargets = append(afterTargets, Targets{ + UserTraqid: afterTarget, + IsCanceled: false, + }) + } + for _, afterTarget := range testCase.afterValidTargets { + afterTargets = append(afterTargets, Targets{ + UserTraqid: afterTarget, + IsCanceled: true, + }) + } + + actualTargets := make([]Targets, 0, len(testCase.afterValidTargets)+len(testCase.afterInvalidTargets)) + err = db. + Session(&gorm.Session{}). + Model(&Targets{}). + Where("questionnaire_id = ?", questionnaire.ID). + Find(&actualTargets).Error + if err != nil { + t.Errorf("failed to get targets: %v", err) + } + + assert.ElementsMatchf(t, afterTargets, actualTargets, "targets") + }) + } +} diff --git a/model/v3.go b/model/v3.go index 1935453a..3421bab8 100644 --- a/model/v3.go +++ b/model/v3.go @@ -10,16 +10,29 @@ import ( func v3() *gormigrate.Migration { return &gormigrate.Migration{ - ID: "v3", + ID: "3", Migrate: func(tx *gorm.DB) error { - if err := tx.AutoMigrate(&Targets{}); err != nil { + if err := tx.AutoMigrate(&v3Targets{}); err != nil { return err } + if err := tx.AutoMigrate(&v3Questionnaires{}); err != nil { + return err + } return nil }, } } +type v3Targets struct { + QuestionnaireID int `gorm:"type:int(11) AUTO_INCREMENT;not null;primaryKey"` + UserTraqid string `gorm:"type:varchar(32);size:32;not null;primaryKey"` + IsCanceled bool `gorm:"type:tinyint(1);not null;default:0"` +} + +func (*v3Targets) TableName() string { + return "targets" +} + type v3Questionnaires struct { ID int `json:"questionnaireID" gorm:"type:int(11) AUTO_INCREMENT;not null;primaryKey"` Title string `json:"title" gorm:"type:char(50);size:50;not null"`