From c7648670c839539ef94a4af9d15546de34929aa4 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Thu, 8 Feb 2024 23:06:59 +0800 Subject: [PATCH 1/3] db: init memo and memo relation schema Signed-off-by: Wei Zhang --- ent/schema/memo.go | 34 ++++++++++++++++++++++++++++++++++ ent/schema/memorelation.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 ent/schema/memo.go create mode 100644 ent/schema/memorelation.go diff --git a/ent/schema/memo.go b/ent/schema/memo.go new file mode 100644 index 0000000000000..22523d4f63b3f --- /dev/null +++ b/ent/schema/memo.go @@ -0,0 +1,34 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" +) + +// Memo holds the schema definition for the Memo entity. +type Memo struct { + ent.Schema +} + +// Fields of the Memo. +func (Memo) Fields() []ent.Field { + return []ent.Field{ + field.Int("id").Positive(), + field.String("resource_name").MaxLen(256).NotEmpty().Unique(), + field.Int("creator_id").Positive(), + field.Time("created_ts"), + field.Time("updated_ts"), + field.String("row_status").MaxLen(256).NotEmpty(), + field.Text("content").Default(""), + field.String("visibility").MaxLen(256).NotEmpty(), + } +} + +// Edges of the Memo. +func (Memo) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("related_memo", Memo.Type). + Through("memo_relation", MemoRelation.Type), + } +} diff --git a/ent/schema/memorelation.go b/ent/schema/memorelation.go new file mode 100644 index 0000000000000..0dbb7fd7907a9 --- /dev/null +++ b/ent/schema/memorelation.go @@ -0,0 +1,35 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" +) + +// MemoRelation holds the schema definition for the MemoRelation entity. +type MemoRelation struct { + ent.Schema +} + +// Fields of the MemoRelation. +func (MemoRelation) Fields() []ent.Field { + return []ent.Field{ + field.String("type"), + field.Int("memo_id"), + field.Int("related_memo_id"), + } +} + +// Edges of the MemoRelation. +func (MemoRelation) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("memo", Memo.Type). + Required(). + Unique(). + Field("memo_id"), + edge.To("related_memo", Memo.Type). + Required(). + Unique(). + Field("related_memo_id"), + } +} From 2d4ae64791340180034b7d817a54a89e6484daeb Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Thu, 8 Feb 2024 23:08:10 +0800 Subject: [PATCH 2/3] gen/db: go generate ent for memo and memo relation Signed-off-by: Wei Zhang --- ent/client.go | 548 ++++++++++++ ent/ent.go | 610 +++++++++++++ ent/enttest/enttest.go | 84 ++ ent/generate.go | 3 + ent/hook/hook.go | 211 +++++ ent/memo.go | 214 +++++ ent/memo/memo.go | 172 ++++ ent/memo/where.go | 532 +++++++++++ ent/memo_create.go | 380 ++++++++ ent/memo_delete.go | 88 ++ ent/memo_query.go | 709 +++++++++++++++ ent/memo_update.go | 815 +++++++++++++++++ ent/memorelation.go | 176 ++++ ent/memorelation/memorelation.go | 110 +++ ent/memorelation/where.go | 235 +++++ ent/memorelation_create.go | 252 ++++++ ent/memorelation_delete.go | 88 ++ ent/memorelation_query.go | 679 ++++++++++++++ ent/memorelation_update.go | 454 ++++++++++ ent/migrate/migrate.go | 64 ++ ent/migrate/schema.go | 72 ++ ent/mutation.go | 1435 ++++++++++++++++++++++++++++++ ent/predicate/predicate.go | 13 + ent/runtime.go | 82 ++ ent/runtime/runtime.go | 10 + ent/tx.go | 213 +++++ 26 files changed, 8249 insertions(+) create mode 100644 ent/client.go create mode 100644 ent/ent.go create mode 100644 ent/enttest/enttest.go create mode 100644 ent/generate.go create mode 100644 ent/hook/hook.go create mode 100644 ent/memo.go create mode 100644 ent/memo/memo.go create mode 100644 ent/memo/where.go create mode 100644 ent/memo_create.go create mode 100644 ent/memo_delete.go create mode 100644 ent/memo_query.go create mode 100644 ent/memo_update.go create mode 100644 ent/memorelation.go create mode 100644 ent/memorelation/memorelation.go create mode 100644 ent/memorelation/where.go create mode 100644 ent/memorelation_create.go create mode 100644 ent/memorelation_delete.go create mode 100644 ent/memorelation_query.go create mode 100644 ent/memorelation_update.go create mode 100644 ent/migrate/migrate.go create mode 100644 ent/migrate/schema.go create mode 100644 ent/mutation.go create mode 100644 ent/predicate/predicate.go create mode 100644 ent/runtime.go create mode 100644 ent/runtime/runtime.go create mode 100644 ent/tx.go diff --git a/ent/client.go b/ent/client.go new file mode 100644 index 0000000000000..6f8b32499ef7a --- /dev/null +++ b/ent/client.go @@ -0,0 +1,548 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "log" + "reflect" + + "github.com/usememos/memos/ent/migrate" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" +) + +// Client is the client that holds all ent builders. +type Client struct { + config + // Schema is the client for creating, migrating and dropping schema. + Schema *migrate.Schema + // Memo is the client for interacting with the Memo builders. + Memo *MemoClient + // MemoRelation is the client for interacting with the MemoRelation builders. + MemoRelation *MemoRelationClient +} + +// NewClient creates a new client configured with the given options. +func NewClient(opts ...Option) *Client { + client := &Client{config: newConfig(opts...)} + client.init() + return client +} + +func (c *Client) init() { + c.Schema = migrate.NewSchema(c.driver) + c.Memo = NewMemoClient(c.config) + c.MemoRelation = NewMemoRelationClient(c.config) +} + +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. + Option func(*config) +) + +// newConfig creates a new config for the client. +func newConfig(opts ...Option) config { + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} + cfg.options(opts...) + return cfg +} + +// options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } +} + +// Open opens a database/sql.DB specified by the driver name and +// the data source name, and returns a new client attached to it. +// Optional parameters can be added for configuring the client. +func Open(driverName, dataSourceName string, options ...Option) (*Client, error) { + switch driverName { + case dialect.MySQL, dialect.Postgres, dialect.SQLite: + drv, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return NewClient(append(options, Driver(drv))...), nil + default: + return nil, fmt.Errorf("unsupported driver: %q", driverName) + } +} + +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + +// Tx returns a new transactional client. The provided context +// is used until the transaction is committed or rolled back. +func (c *Client) Tx(ctx context.Context) (*Tx, error) { + if _, ok := c.driver.(*txDriver); ok { + return nil, ErrTxStarted + } + tx, err := newTx(ctx, c.driver) + if err != nil { + return nil, fmt.Errorf("ent: starting a transaction: %w", err) + } + cfg := c.config + cfg.driver = tx + return &Tx{ + ctx: ctx, + config: cfg, + Memo: NewMemoClient(cfg), + MemoRelation: NewMemoRelationClient(cfg), + }, nil +} + +// BeginTx returns a transactional client with specified options. +func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + if _, ok := c.driver.(*txDriver); ok { + return nil, errors.New("ent: cannot start a transaction within a transaction") + } + tx, err := c.driver.(interface { + BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error) + }).BeginTx(ctx, opts) + if err != nil { + return nil, fmt.Errorf("ent: starting a transaction: %w", err) + } + cfg := c.config + cfg.driver = &txDriver{tx: tx, drv: c.driver} + return &Tx{ + ctx: ctx, + config: cfg, + Memo: NewMemoClient(cfg), + MemoRelation: NewMemoRelationClient(cfg), + }, nil +} + +// Debug returns a new debug-client. It's used to get verbose logging on specific operations. +// +// client.Debug(). +// Memo. +// Query(). +// Count(ctx) +func (c *Client) Debug() *Client { + if c.debug { + return c + } + cfg := c.config + cfg.driver = dialect.Debug(c.driver, c.log) + client := &Client{config: cfg} + client.init() + return client +} + +// Close closes the database connection and prevents new queries from starting. +func (c *Client) Close() error { + return c.driver.Close() +} + +// Use adds the mutation hooks to all the entity clients. +// In order to add hooks to a specific client, call: `client.Node.Use(...)`. +func (c *Client) Use(hooks ...Hook) { + c.Memo.Use(hooks...) + c.MemoRelation.Use(hooks...) +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + c.Memo.Intercept(interceptors...) + c.MemoRelation.Intercept(interceptors...) +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *MemoMutation: + return c.Memo.mutate(ctx, m) + case *MemoRelationMutation: + return c.MemoRelation.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } +} + +// MemoClient is a client for the Memo schema. +type MemoClient struct { + config +} + +// NewMemoClient returns a client for the Memo from the given config. +func NewMemoClient(c config) *MemoClient { + return &MemoClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `memo.Hooks(f(g(h())))`. +func (c *MemoClient) Use(hooks ...Hook) { + c.hooks.Memo = append(c.hooks.Memo, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `memo.Intercept(f(g(h())))`. +func (c *MemoClient) Intercept(interceptors ...Interceptor) { + c.inters.Memo = append(c.inters.Memo, interceptors...) +} + +// Create returns a builder for creating a Memo entity. +func (c *MemoClient) Create() *MemoCreate { + mutation := newMemoMutation(c.config, OpCreate) + return &MemoCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of Memo entities. +func (c *MemoClient) CreateBulk(builders ...*MemoCreate) *MemoCreateBulk { + return &MemoCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MemoClient) MapCreateBulk(slice any, setFunc func(*MemoCreate, int)) *MemoCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MemoCreateBulk{err: fmt.Errorf("calling to MemoClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MemoCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MemoCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for Memo. +func (c *MemoClient) Update() *MemoUpdate { + mutation := newMemoMutation(c.config, OpUpdate) + return &MemoUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *MemoClient) UpdateOne(m *Memo) *MemoUpdateOne { + mutation := newMemoMutation(c.config, OpUpdateOne, withMemo(m)) + return &MemoUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *MemoClient) UpdateOneID(id int) *MemoUpdateOne { + mutation := newMemoMutation(c.config, OpUpdateOne, withMemoID(id)) + return &MemoUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for Memo. +func (c *MemoClient) Delete() *MemoDelete { + mutation := newMemoMutation(c.config, OpDelete) + return &MemoDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *MemoClient) DeleteOne(m *Memo) *MemoDeleteOne { + return c.DeleteOneID(m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *MemoClient) DeleteOneID(id int) *MemoDeleteOne { + builder := c.Delete().Where(memo.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &MemoDeleteOne{builder} +} + +// Query returns a query builder for Memo. +func (c *MemoClient) Query() *MemoQuery { + return &MemoQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeMemo}, + inters: c.Interceptors(), + } +} + +// Get returns a Memo entity by its id. +func (c *MemoClient) Get(ctx context.Context, id int) (*Memo, error) { + return c.Query().Where(memo.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *MemoClient) GetX(ctx context.Context, id int) *Memo { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryRelatedMemo queries the related_memo edge of a Memo. +func (c *MemoClient) QueryRelatedMemo(m *Memo) *MemoQuery { + query := (&MemoClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := m.ID + step := sqlgraph.NewStep( + sqlgraph.From(memo.Table, memo.FieldID, id), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, memo.RelatedMemoTable, memo.RelatedMemoPrimaryKey...), + ) + fromV = sqlgraph.Neighbors(m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryMemoRelation queries the memo_relation edge of a Memo. +func (c *MemoClient) QueryMemoRelation(m *Memo) *MemoRelationQuery { + query := (&MemoRelationClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := m.ID + step := sqlgraph.NewStep( + sqlgraph.From(memo.Table, memo.FieldID, id), + sqlgraph.To(memorelation.Table, memorelation.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, memo.MemoRelationTable, memo.MemoRelationColumn), + ) + fromV = sqlgraph.Neighbors(m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *MemoClient) Hooks() []Hook { + return c.hooks.Memo +} + +// Interceptors returns the client interceptors. +func (c *MemoClient) Interceptors() []Interceptor { + return c.inters.Memo +} + +func (c *MemoClient) mutate(ctx context.Context, m *MemoMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MemoCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MemoUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MemoUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MemoDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Memo mutation op: %q", m.Op()) + } +} + +// MemoRelationClient is a client for the MemoRelation schema. +type MemoRelationClient struct { + config +} + +// NewMemoRelationClient returns a client for the MemoRelation from the given config. +func NewMemoRelationClient(c config) *MemoRelationClient { + return &MemoRelationClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `memorelation.Hooks(f(g(h())))`. +func (c *MemoRelationClient) Use(hooks ...Hook) { + c.hooks.MemoRelation = append(c.hooks.MemoRelation, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `memorelation.Intercept(f(g(h())))`. +func (c *MemoRelationClient) Intercept(interceptors ...Interceptor) { + c.inters.MemoRelation = append(c.inters.MemoRelation, interceptors...) +} + +// Create returns a builder for creating a MemoRelation entity. +func (c *MemoRelationClient) Create() *MemoRelationCreate { + mutation := newMemoRelationMutation(c.config, OpCreate) + return &MemoRelationCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of MemoRelation entities. +func (c *MemoRelationClient) CreateBulk(builders ...*MemoRelationCreate) *MemoRelationCreateBulk { + return &MemoRelationCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *MemoRelationClient) MapCreateBulk(slice any, setFunc func(*MemoRelationCreate, int)) *MemoRelationCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &MemoRelationCreateBulk{err: fmt.Errorf("calling to MemoRelationClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*MemoRelationCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &MemoRelationCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for MemoRelation. +func (c *MemoRelationClient) Update() *MemoRelationUpdate { + mutation := newMemoRelationMutation(c.config, OpUpdate) + return &MemoRelationUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *MemoRelationClient) UpdateOne(mr *MemoRelation) *MemoRelationUpdateOne { + mutation := newMemoRelationMutation(c.config, OpUpdateOne, withMemoRelation(mr)) + return &MemoRelationUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *MemoRelationClient) UpdateOneID(id int) *MemoRelationUpdateOne { + mutation := newMemoRelationMutation(c.config, OpUpdateOne, withMemoRelationID(id)) + return &MemoRelationUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for MemoRelation. +func (c *MemoRelationClient) Delete() *MemoRelationDelete { + mutation := newMemoRelationMutation(c.config, OpDelete) + return &MemoRelationDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *MemoRelationClient) DeleteOne(mr *MemoRelation) *MemoRelationDeleteOne { + return c.DeleteOneID(mr.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *MemoRelationClient) DeleteOneID(id int) *MemoRelationDeleteOne { + builder := c.Delete().Where(memorelation.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &MemoRelationDeleteOne{builder} +} + +// Query returns a query builder for MemoRelation. +func (c *MemoRelationClient) Query() *MemoRelationQuery { + return &MemoRelationQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeMemoRelation}, + inters: c.Interceptors(), + } +} + +// Get returns a MemoRelation entity by its id. +func (c *MemoRelationClient) Get(ctx context.Context, id int) (*MemoRelation, error) { + return c.Query().Where(memorelation.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *MemoRelationClient) GetX(ctx context.Context, id int) *MemoRelation { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryMemo queries the memo edge of a MemoRelation. +func (c *MemoRelationClient) QueryMemo(mr *MemoRelation) *MemoQuery { + query := (&MemoClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := mr.ID + step := sqlgraph.NewStep( + sqlgraph.From(memorelation.Table, memorelation.FieldID, id), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, memorelation.MemoTable, memorelation.MemoColumn), + ) + fromV = sqlgraph.Neighbors(mr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryRelatedMemo queries the related_memo edge of a MemoRelation. +func (c *MemoRelationClient) QueryRelatedMemo(mr *MemoRelation) *MemoQuery { + query := (&MemoClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := mr.ID + step := sqlgraph.NewStep( + sqlgraph.From(memorelation.Table, memorelation.FieldID, id), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, memorelation.RelatedMemoTable, memorelation.RelatedMemoColumn), + ) + fromV = sqlgraph.Neighbors(mr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *MemoRelationClient) Hooks() []Hook { + return c.hooks.MemoRelation +} + +// Interceptors returns the client interceptors. +func (c *MemoRelationClient) Interceptors() []Interceptor { + return c.inters.MemoRelation +} + +func (c *MemoRelationClient) mutate(ctx context.Context, m *MemoRelationMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&MemoRelationCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&MemoRelationUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&MemoRelationUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&MemoRelationDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown MemoRelation mutation op: %q", m.Op()) + } +} + +// hooks and interceptors per client, for fast access. +type ( + hooks struct { + Memo, MemoRelation []ent.Hook + } + inters struct { + Memo, MemoRelation []ent.Interceptor + } +) diff --git a/ent/ent.go b/ent/ent.go new file mode 100644 index 0000000000000..bbec5cb2f0ad2 --- /dev/null +++ b/ent/ent.go @@ -0,0 +1,610 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" +) + +// ent aliases to avoid import conflicts in user's code. +type ( + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc +) + +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + +// OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. +type OrderFunc func(*sql.Selector) + +var ( + initCheck sync.Once + columnCheck sql.ColumnCheck +) + +// columnChecker checks if the column exists in the given table. +func checkColumn(table, column string) error { + initCheck.Do(func() { + columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + memo.Table: memo.ValidColumn, + memorelation.Table: memorelation.ValidColumn, + }) + }) + return columnCheck(table, column) +} + +// Asc applies the given fields in ASC order. +func Asc(fields ...string) func(*sql.Selector) { + return func(s *sql.Selector) { + for _, f := range fields { + if err := checkColumn(s.TableName(), f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + } + s.OrderBy(sql.Asc(s.C(f))) + } + } +} + +// Desc applies the given fields in DESC order. +func Desc(fields ...string) func(*sql.Selector) { + return func(s *sql.Selector) { + for _, f := range fields { + if err := checkColumn(s.TableName(), f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) + } + s.OrderBy(sql.Desc(s.C(f))) + } + } +} + +// AggregateFunc applies an aggregation step on the group-by traversal/selector. +type AggregateFunc func(*sql.Selector) string + +// As is a pseudo aggregation function for renaming another other functions with custom names. For example: +// +// GroupBy(field1, field2). +// Aggregate(ent.As(ent.Sum(field1), "sum_field1"), (ent.As(ent.Sum(field2), "sum_field2")). +// Scan(ctx, &v) +func As(fn AggregateFunc, end string) AggregateFunc { + return func(s *sql.Selector) string { + return sql.As(fn(s), end) + } +} + +// Count applies the "count" aggregation function on each group. +func Count() AggregateFunc { + return func(s *sql.Selector) string { + return sql.Count("*") + } +} + +// Max applies the "max" aggregation function on the given field of each group. +func Max(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Max(s.C(field)) + } +} + +// Mean applies the "mean" aggregation function on the given field of each group. +func Mean(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Avg(s.C(field)) + } +} + +// Min applies the "min" aggregation function on the given field of each group. +func Min(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Min(s.C(field)) + } +} + +// Sum applies the "sum" aggregation function on the given field of each group. +func Sum(field string) AggregateFunc { + return func(s *sql.Selector) string { + if err := checkColumn(s.TableName(), field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) + return "" + } + return sql.Sum(s.C(field)) + } +} + +// ValidationError returns when validating a field or edge fails. +type ValidationError struct { + Name string // Field or edge name. + err error +} + +// Error implements the error interface. +func (e *ValidationError) Error() string { + return e.err.Error() +} + +// Unwrap implements the errors.Wrapper interface. +func (e *ValidationError) Unwrap() error { + return e.err +} + +// IsValidationError returns a boolean indicating whether the error is a validation error. +func IsValidationError(err error) bool { + if err == nil { + return false + } + var e *ValidationError + return errors.As(err, &e) +} + +// NotFoundError returns when trying to fetch a specific entity and it was not found in the database. +type NotFoundError struct { + label string +} + +// Error implements the error interface. +func (e *NotFoundError) Error() string { + return "ent: " + e.label + " not found" +} + +// IsNotFound returns a boolean indicating whether the error is a not found error. +func IsNotFound(err error) bool { + if err == nil { + return false + } + var e *NotFoundError + return errors.As(err, &e) +} + +// MaskNotFound masks not found error. +func MaskNotFound(err error) error { + if IsNotFound(err) { + return nil + } + return err +} + +// NotSingularError returns when trying to fetch a singular entity and more then one was found in the database. +type NotSingularError struct { + label string +} + +// Error implements the error interface. +func (e *NotSingularError) Error() string { + return "ent: " + e.label + " not singular" +} + +// IsNotSingular returns a boolean indicating whether the error is a not singular error. +func IsNotSingular(err error) bool { + if err == nil { + return false + } + var e *NotSingularError + return errors.As(err, &e) +} + +// NotLoadedError returns when trying to get a node that was not loaded by the query. +type NotLoadedError struct { + edge string +} + +// Error implements the error interface. +func (e *NotLoadedError) Error() string { + return "ent: " + e.edge + " edge was not loaded" +} + +// IsNotLoaded returns a boolean indicating whether the error is a not loaded error. +func IsNotLoaded(err error) bool { + if err == nil { + return false + } + var e *NotLoadedError + return errors.As(err, &e) +} + +// ConstraintError returns when trying to create/update one or more entities and +// one or more of their constraints failed. For example, violation of edge or +// field uniqueness. +type ConstraintError struct { + msg string + wrap error +} + +// Error implements the error interface. +func (e ConstraintError) Error() string { + return "ent: constraint failed: " + e.msg +} + +// Unwrap implements the errors.Wrapper interface. +func (e *ConstraintError) Unwrap() error { + return e.wrap +} + +// IsConstraintError returns a boolean indicating whether the error is a constraint failure. +func IsConstraintError(err error) bool { + if err == nil { + return false + } + var e *ConstraintError + return errors.As(err, &e) +} + +// selector embedded by the different Select/GroupBy builders. +type selector struct { + label string + flds *[]string + fns []AggregateFunc + scan func(context.Context, any) error +} + +// ScanX is like Scan, but panics if an error occurs. +func (s *selector) ScanX(ctx context.Context, v any) { + if err := s.scan(ctx, v); err != nil { + panic(err) + } +} + +// Strings returns list of strings from a selector. It is only allowed when selecting one field. +func (s *selector) Strings(ctx context.Context) ([]string, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Strings is not achievable when selecting more than 1 field") + } + var v []string + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// StringsX is like Strings, but panics if an error occurs. +func (s *selector) StringsX(ctx context.Context) []string { + v, err := s.Strings(ctx) + if err != nil { + panic(err) + } + return v +} + +// String returns a single string from a selector. It is only allowed when selecting one field. +func (s *selector) String(ctx context.Context) (_ string, err error) { + var v []string + if v, err = s.Strings(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Strings returned %d results when one was expected", len(v)) + } + return +} + +// StringX is like String, but panics if an error occurs. +func (s *selector) StringX(ctx context.Context) string { + v, err := s.String(ctx) + if err != nil { + panic(err) + } + return v +} + +// Ints returns list of ints from a selector. It is only allowed when selecting one field. +func (s *selector) Ints(ctx context.Context) ([]int, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Ints is not achievable when selecting more than 1 field") + } + var v []int + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// IntsX is like Ints, but panics if an error occurs. +func (s *selector) IntsX(ctx context.Context) []int { + v, err := s.Ints(ctx) + if err != nil { + panic(err) + } + return v +} + +// Int returns a single int from a selector. It is only allowed when selecting one field. +func (s *selector) Int(ctx context.Context) (_ int, err error) { + var v []int + if v, err = s.Ints(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Ints returned %d results when one was expected", len(v)) + } + return +} + +// IntX is like Int, but panics if an error occurs. +func (s *selector) IntX(ctx context.Context) int { + v, err := s.Int(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64s returns list of float64s from a selector. It is only allowed when selecting one field. +func (s *selector) Float64s(ctx context.Context) ([]float64, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Float64s is not achievable when selecting more than 1 field") + } + var v []float64 + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// Float64sX is like Float64s, but panics if an error occurs. +func (s *selector) Float64sX(ctx context.Context) []float64 { + v, err := s.Float64s(ctx) + if err != nil { + panic(err) + } + return v +} + +// Float64 returns a single float64 from a selector. It is only allowed when selecting one field. +func (s *selector) Float64(ctx context.Context) (_ float64, err error) { + var v []float64 + if v, err = s.Float64s(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Float64s returned %d results when one was expected", len(v)) + } + return +} + +// Float64X is like Float64, but panics if an error occurs. +func (s *selector) Float64X(ctx context.Context) float64 { + v, err := s.Float64(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bools returns list of bools from a selector. It is only allowed when selecting one field. +func (s *selector) Bools(ctx context.Context) ([]bool, error) { + if len(*s.flds) > 1 { + return nil, errors.New("ent: Bools is not achievable when selecting more than 1 field") + } + var v []bool + if err := s.scan(ctx, &v); err != nil { + return nil, err + } + return v, nil +} + +// BoolsX is like Bools, but panics if an error occurs. +func (s *selector) BoolsX(ctx context.Context) []bool { + v, err := s.Bools(ctx) + if err != nil { + panic(err) + } + return v +} + +// Bool returns a single bool from a selector. It is only allowed when selecting one field. +func (s *selector) Bool(ctx context.Context) (_ bool, err error) { + var v []bool + if v, err = s.Bools(ctx); err != nil { + return + } + switch len(v) { + case 1: + return v[0], nil + case 0: + err = &NotFoundError{s.label} + default: + err = fmt.Errorf("ent: Bools returned %d results when one was expected", len(v)) + } + return +} + +// BoolX is like Bool, but panics if an error occurs. +func (s *selector) BoolX(ctx context.Context) bool { + v, err := s.Bool(ctx) + if err != nil { + panic(err) + } + return v +} + +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + +// queryHook describes an internal hook for the different sqlAll methods. +type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/ent/enttest/enttest.go b/ent/enttest/enttest.go new file mode 100644 index 0000000000000..76d672fb6362c --- /dev/null +++ b/ent/enttest/enttest.go @@ -0,0 +1,84 @@ +// Code generated by ent, DO NOT EDIT. + +package enttest + +import ( + "context" + + "github.com/usememos/memos/ent" + // required by schema hooks. + _ "github.com/usememos/memos/ent/runtime" + + "entgo.io/ent/dialect/sql/schema" + "github.com/usememos/memos/ent/migrate" +) + +type ( + // TestingT is the interface that is shared between + // testing.T and testing.B and used by enttest. + TestingT interface { + FailNow() + Error(...any) + } + + // Option configures client creation. + Option func(*options) + + options struct { + opts []ent.Option + migrateOpts []schema.MigrateOption + } +) + +// WithOptions forwards options to client creation. +func WithOptions(opts ...ent.Option) Option { + return func(o *options) { + o.opts = append(o.opts, opts...) + } +} + +// WithMigrateOptions forwards options to auto migration. +func WithMigrateOptions(opts ...schema.MigrateOption) Option { + return func(o *options) { + o.migrateOpts = append(o.migrateOpts, opts...) + } +} + +func newOptions(opts []Option) *options { + o := &options{} + for _, opt := range opts { + opt(o) + } + return o +} + +// Open calls ent.Open and auto-run migration. +func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *ent.Client { + o := newOptions(opts) + c, err := ent.Open(driverName, dataSourceName, o.opts...) + if err != nil { + t.Error(err) + t.FailNow() + } + migrateSchema(t, c, o) + return c +} + +// NewClient calls ent.NewClient and auto-run migration. +func NewClient(t TestingT, opts ...Option) *ent.Client { + o := newOptions(opts) + c := ent.NewClient(o.opts...) + migrateSchema(t, c, o) + return c +} +func migrateSchema(t TestingT, c *ent.Client, o *options) { + tables, err := schema.CopyTables(migrate.Tables) + if err != nil { + t.Error(err) + t.FailNow() + } + if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil { + t.Error(err) + t.FailNow() + } +} diff --git a/ent/generate.go b/ent/generate.go new file mode 100644 index 0000000000000..8d3fdfdc1cd6b --- /dev/null +++ b/ent/generate.go @@ -0,0 +1,3 @@ +package ent + +//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema diff --git a/ent/hook/hook.go b/ent/hook/hook.go new file mode 100644 index 0000000000000..90a0e73781fe1 --- /dev/null +++ b/ent/hook/hook.go @@ -0,0 +1,211 @@ +// Code generated by ent, DO NOT EDIT. + +package hook + +import ( + "context" + "fmt" + + "github.com/usememos/memos/ent" +) + +// The MemoFunc type is an adapter to allow the use of ordinary +// function as Memo mutator. +type MemoFunc func(context.Context, *ent.MemoMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f MemoFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.MemoMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MemoMutation", m) +} + +// The MemoRelationFunc type is an adapter to allow the use of ordinary +// function as MemoRelation mutator. +type MemoRelationFunc func(context.Context, *ent.MemoRelationMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f MemoRelationFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.MemoRelationMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MemoRelationMutation", m) +} + +// Condition is a hook condition function. +type Condition func(context.Context, ent.Mutation) bool + +// And groups conditions with the AND operator. +func And(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + if !first(ctx, m) || !second(ctx, m) { + return false + } + for _, cond := range rest { + if !cond(ctx, m) { + return false + } + } + return true + } +} + +// Or groups conditions with the OR operator. +func Or(first, second Condition, rest ...Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + if first(ctx, m) || second(ctx, m) { + return true + } + for _, cond := range rest { + if cond(ctx, m) { + return true + } + } + return false + } +} + +// Not negates a given condition. +func Not(cond Condition) Condition { + return func(ctx context.Context, m ent.Mutation) bool { + return !cond(ctx, m) + } +} + +// HasOp is a condition testing mutation operation. +func HasOp(op ent.Op) Condition { + return func(_ context.Context, m ent.Mutation) bool { + return m.Op().Is(op) + } +} + +// HasAddedFields is a condition validating `.AddedField` on fields. +func HasAddedFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if _, exists := m.AddedField(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.AddedField(field); !exists { + return false + } + } + return true + } +} + +// HasClearedFields is a condition validating `.FieldCleared` on fields. +func HasClearedFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if exists := m.FieldCleared(field); !exists { + return false + } + for _, field := range fields { + if exists := m.FieldCleared(field); !exists { + return false + } + } + return true + } +} + +// HasFields is a condition validating `.Field` on fields. +func HasFields(field string, fields ...string) Condition { + return func(_ context.Context, m ent.Mutation) bool { + if _, exists := m.Field(field); !exists { + return false + } + for _, field := range fields { + if _, exists := m.Field(field); !exists { + return false + } + } + return true + } +} + +// If executes the given hook under condition. +// +// hook.If(ComputeAverage, And(HasFields(...), HasAddedFields(...))) +func If(hk ent.Hook, cond Condition) ent.Hook { + return func(next ent.Mutator) ent.Mutator { + return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if cond(ctx, m) { + return hk(next).Mutate(ctx, m) + } + return next.Mutate(ctx, m) + }) + } +} + +// On executes the given hook only for the given operation. +// +// hook.On(Log, ent.Delete|ent.Create) +func On(hk ent.Hook, op ent.Op) ent.Hook { + return If(hk, HasOp(op)) +} + +// Unless skips the given hook only for the given operation. +// +// hook.Unless(Log, ent.Update|ent.UpdateOne) +func Unless(hk ent.Hook, op ent.Op) ent.Hook { + return If(hk, Not(HasOp(op))) +} + +// FixedError is a hook returning a fixed error. +func FixedError(err error) ent.Hook { + return func(ent.Mutator) ent.Mutator { + return ent.MutateFunc(func(context.Context, ent.Mutation) (ent.Value, error) { + return nil, err + }) + } +} + +// Reject returns a hook that rejects all operations that match op. +// +// func (T) Hooks() []ent.Hook { +// return []ent.Hook{ +// Reject(ent.Delete|ent.Update), +// } +// } +func Reject(op ent.Op) ent.Hook { + hk := FixedError(fmt.Errorf("%s operation is not allowed", op)) + return On(hk, op) +} + +// Chain acts as a list of hooks and is effectively immutable. +// Once created, it will always hold the same set of hooks in the same order. +type Chain struct { + hooks []ent.Hook +} + +// NewChain creates a new chain of hooks. +func NewChain(hooks ...ent.Hook) Chain { + return Chain{append([]ent.Hook(nil), hooks...)} +} + +// Hook chains the list of hooks and returns the final hook. +func (c Chain) Hook() ent.Hook { + return func(mutator ent.Mutator) ent.Mutator { + for i := len(c.hooks) - 1; i >= 0; i-- { + mutator = c.hooks[i](mutator) + } + return mutator + } +} + +// Append extends a chain, adding the specified hook +// as the last ones in the mutation flow. +func (c Chain) Append(hooks ...ent.Hook) Chain { + newHooks := make([]ent.Hook, 0, len(c.hooks)+len(hooks)) + newHooks = append(newHooks, c.hooks...) + newHooks = append(newHooks, hooks...) + return Chain{newHooks} +} + +// Extend extends a chain, adding the specified chain +// as the last ones in the mutation flow. +func (c Chain) Extend(chain Chain) Chain { + return c.Append(chain.hooks...) +} diff --git a/ent/memo.go b/ent/memo.go new file mode 100644 index 0000000000000..03b35e0738f6f --- /dev/null +++ b/ent/memo.go @@ -0,0 +1,214 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/usememos/memos/ent/memo" +) + +// Memo is the model entity for the Memo schema. +type Memo struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // ResourceName holds the value of the "resource_name" field. + ResourceName string `json:"resource_name,omitempty"` + // CreatorID holds the value of the "creator_id" field. + CreatorID int `json:"creator_id,omitempty"` + // CreatedTs holds the value of the "created_ts" field. + CreatedTs time.Time `json:"created_ts,omitempty"` + // UpdatedTs holds the value of the "updated_ts" field. + UpdatedTs time.Time `json:"updated_ts,omitempty"` + // RowStatus holds the value of the "row_status" field. + RowStatus string `json:"row_status,omitempty"` + // Content holds the value of the "content" field. + Content string `json:"content,omitempty"` + // Visibility holds the value of the "visibility" field. + Visibility string `json:"visibility,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the MemoQuery when eager-loading is set. + Edges MemoEdges `json:"edges"` + selectValues sql.SelectValues +} + +// MemoEdges holds the relations/edges for other nodes in the graph. +type MemoEdges struct { + // RelatedMemo holds the value of the related_memo edge. + RelatedMemo []*Memo `json:"related_memo,omitempty"` + // MemoRelation holds the value of the memo_relation edge. + MemoRelation []*MemoRelation `json:"memo_relation,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// RelatedMemoOrErr returns the RelatedMemo value or an error if the edge +// was not loaded in eager-loading. +func (e MemoEdges) RelatedMemoOrErr() ([]*Memo, error) { + if e.loadedTypes[0] { + return e.RelatedMemo, nil + } + return nil, &NotLoadedError{edge: "related_memo"} +} + +// MemoRelationOrErr returns the MemoRelation value or an error if the edge +// was not loaded in eager-loading. +func (e MemoEdges) MemoRelationOrErr() ([]*MemoRelation, error) { + if e.loadedTypes[1] { + return e.MemoRelation, nil + } + return nil, &NotLoadedError{edge: "memo_relation"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*Memo) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case memo.FieldID, memo.FieldCreatorID: + values[i] = new(sql.NullInt64) + case memo.FieldResourceName, memo.FieldRowStatus, memo.FieldContent, memo.FieldVisibility: + values[i] = new(sql.NullString) + case memo.FieldCreatedTs, memo.FieldUpdatedTs: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the Memo fields. +func (m *Memo) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case memo.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + m.ID = int(value.Int64) + case memo.FieldResourceName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resource_name", values[i]) + } else if value.Valid { + m.ResourceName = value.String + } + case memo.FieldCreatorID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field creator_id", values[i]) + } else if value.Valid { + m.CreatorID = int(value.Int64) + } + case memo.FieldCreatedTs: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_ts", values[i]) + } else if value.Valid { + m.CreatedTs = value.Time + } + case memo.FieldUpdatedTs: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_ts", values[i]) + } else if value.Valid { + m.UpdatedTs = value.Time + } + case memo.FieldRowStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field row_status", values[i]) + } else if value.Valid { + m.RowStatus = value.String + } + case memo.FieldContent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field content", values[i]) + } else if value.Valid { + m.Content = value.String + } + case memo.FieldVisibility: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field visibility", values[i]) + } else if value.Valid { + m.Visibility = value.String + } + default: + m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the Memo. +// This includes values selected through modifiers, order, etc. +func (m *Memo) Value(name string) (ent.Value, error) { + return m.selectValues.Get(name) +} + +// QueryRelatedMemo queries the "related_memo" edge of the Memo entity. +func (m *Memo) QueryRelatedMemo() *MemoQuery { + return NewMemoClient(m.config).QueryRelatedMemo(m) +} + +// QueryMemoRelation queries the "memo_relation" edge of the Memo entity. +func (m *Memo) QueryMemoRelation() *MemoRelationQuery { + return NewMemoClient(m.config).QueryMemoRelation(m) +} + +// Update returns a builder for updating this Memo. +// Note that you need to call Memo.Unwrap() before calling this method if this Memo +// was returned from a transaction, and the transaction was committed or rolled back. +func (m *Memo) Update() *MemoUpdateOne { + return NewMemoClient(m.config).UpdateOne(m) +} + +// Unwrap unwraps the Memo entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (m *Memo) Unwrap() *Memo { + _tx, ok := m.config.driver.(*txDriver) + if !ok { + panic("ent: Memo is not a transactional entity") + } + m.config.driver = _tx.drv + return m +} + +// String implements the fmt.Stringer. +func (m *Memo) String() string { + var builder strings.Builder + builder.WriteString("Memo(") + builder.WriteString(fmt.Sprintf("id=%v, ", m.ID)) + builder.WriteString("resource_name=") + builder.WriteString(m.ResourceName) + builder.WriteString(", ") + builder.WriteString("creator_id=") + builder.WriteString(fmt.Sprintf("%v", m.CreatorID)) + builder.WriteString(", ") + builder.WriteString("created_ts=") + builder.WriteString(m.CreatedTs.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_ts=") + builder.WriteString(m.UpdatedTs.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("row_status=") + builder.WriteString(m.RowStatus) + builder.WriteString(", ") + builder.WriteString("content=") + builder.WriteString(m.Content) + builder.WriteString(", ") + builder.WriteString("visibility=") + builder.WriteString(m.Visibility) + builder.WriteByte(')') + return builder.String() +} + +// Memos is a parsable slice of Memo. +type Memos []*Memo diff --git a/ent/memo/memo.go b/ent/memo/memo.go new file mode 100644 index 0000000000000..52dac6c9beed1 --- /dev/null +++ b/ent/memo/memo.go @@ -0,0 +1,172 @@ +// Code generated by ent, DO NOT EDIT. + +package memo + +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the memo type in the database. + Label = "memo" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldResourceName holds the string denoting the resource_name field in the database. + FieldResourceName = "resource_name" + // FieldCreatorID holds the string denoting the creator_id field in the database. + FieldCreatorID = "creator_id" + // FieldCreatedTs holds the string denoting the created_ts field in the database. + FieldCreatedTs = "created_ts" + // FieldUpdatedTs holds the string denoting the updated_ts field in the database. + FieldUpdatedTs = "updated_ts" + // FieldRowStatus holds the string denoting the row_status field in the database. + FieldRowStatus = "row_status" + // FieldContent holds the string denoting the content field in the database. + FieldContent = "content" + // FieldVisibility holds the string denoting the visibility field in the database. + FieldVisibility = "visibility" + // EdgeRelatedMemo holds the string denoting the related_memo edge name in mutations. + EdgeRelatedMemo = "related_memo" + // EdgeMemoRelation holds the string denoting the memo_relation edge name in mutations. + EdgeMemoRelation = "memo_relation" + // Table holds the table name of the memo in the database. + Table = "memos" + // RelatedMemoTable is the table that holds the related_memo relation/edge. The primary key declared below. + RelatedMemoTable = "memo_relations" + // MemoRelationTable is the table that holds the memo_relation relation/edge. + MemoRelationTable = "memo_relations" + // MemoRelationInverseTable is the table name for the MemoRelation entity. + // It exists in this package in order to avoid circular dependency with the "memorelation" package. + MemoRelationInverseTable = "memo_relations" + // MemoRelationColumn is the table column denoting the memo_relation relation/edge. + MemoRelationColumn = "memo_id" +) + +// Columns holds all SQL columns for memo fields. +var Columns = []string{ + FieldID, + FieldResourceName, + FieldCreatorID, + FieldCreatedTs, + FieldUpdatedTs, + FieldRowStatus, + FieldContent, + FieldVisibility, +} + +var ( + // RelatedMemoPrimaryKey and RelatedMemoColumn2 are the table columns denoting the + // primary key for the related_memo relation (M2M). + RelatedMemoPrimaryKey = []string{"memo_id", "related_memo_id"} +) + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // ResourceNameValidator is a validator for the "resource_name" field. It is called by the builders before save. + ResourceNameValidator func(string) error + // CreatorIDValidator is a validator for the "creator_id" field. It is called by the builders before save. + CreatorIDValidator func(int) error + // RowStatusValidator is a validator for the "row_status" field. It is called by the builders before save. + RowStatusValidator func(string) error + // DefaultContent holds the default value on creation for the "content" field. + DefaultContent string + // VisibilityValidator is a validator for the "visibility" field. It is called by the builders before save. + VisibilityValidator func(string) error + // IDValidator is a validator for the "id" field. It is called by the builders before save. + IDValidator func(int) error +) + +// OrderOption defines the ordering options for the Memo queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByResourceName orders the results by the resource_name field. +func ByResourceName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResourceName, opts...).ToFunc() +} + +// ByCreatorID orders the results by the creator_id field. +func ByCreatorID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatorID, opts...).ToFunc() +} + +// ByCreatedTs orders the results by the created_ts field. +func ByCreatedTs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedTs, opts...).ToFunc() +} + +// ByUpdatedTs orders the results by the updated_ts field. +func ByUpdatedTs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedTs, opts...).ToFunc() +} + +// ByRowStatus orders the results by the row_status field. +func ByRowStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRowStatus, opts...).ToFunc() +} + +// ByContent orders the results by the content field. +func ByContent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldContent, opts...).ToFunc() +} + +// ByVisibility orders the results by the visibility field. +func ByVisibility(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVisibility, opts...).ToFunc() +} + +// ByRelatedMemoCount orders the results by related_memo count. +func ByRelatedMemoCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newRelatedMemoStep(), opts...) + } +} + +// ByRelatedMemo orders the results by related_memo terms. +func ByRelatedMemo(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newRelatedMemoStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByMemoRelationCount orders the results by memo_relation count. +func ByMemoRelationCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMemoRelationStep(), opts...) + } +} + +// ByMemoRelation orders the results by memo_relation terms. +func ByMemoRelation(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMemoRelationStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newRelatedMemoStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, RelatedMemoTable, RelatedMemoPrimaryKey...), + ) +} +func newMemoRelationStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MemoRelationInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, MemoRelationTable, MemoRelationColumn), + ) +} diff --git a/ent/memo/where.go b/ent/memo/where.go new file mode 100644 index 0000000000000..a1edabf43af95 --- /dev/null +++ b/ent/memo/where.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package memo + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/usememos/memos/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldID, id)) +} + +// ResourceName applies equality check predicate on the "resource_name" field. It's identical to ResourceNameEQ. +func ResourceName(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldResourceName, v)) +} + +// CreatorID applies equality check predicate on the "creator_id" field. It's identical to CreatorIDEQ. +func CreatorID(v int) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldCreatorID, v)) +} + +// CreatedTs applies equality check predicate on the "created_ts" field. It's identical to CreatedTsEQ. +func CreatedTs(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldCreatedTs, v)) +} + +// UpdatedTs applies equality check predicate on the "updated_ts" field. It's identical to UpdatedTsEQ. +func UpdatedTs(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldUpdatedTs, v)) +} + +// RowStatus applies equality check predicate on the "row_status" field. It's identical to RowStatusEQ. +func RowStatus(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldRowStatus, v)) +} + +// Content applies equality check predicate on the "content" field. It's identical to ContentEQ. +func Content(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldContent, v)) +} + +// Visibility applies equality check predicate on the "visibility" field. It's identical to VisibilityEQ. +func Visibility(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldVisibility, v)) +} + +// ResourceNameEQ applies the EQ predicate on the "resource_name" field. +func ResourceNameEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldResourceName, v)) +} + +// ResourceNameNEQ applies the NEQ predicate on the "resource_name" field. +func ResourceNameNEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldResourceName, v)) +} + +// ResourceNameIn applies the In predicate on the "resource_name" field. +func ResourceNameIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldResourceName, vs...)) +} + +// ResourceNameNotIn applies the NotIn predicate on the "resource_name" field. +func ResourceNameNotIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldResourceName, vs...)) +} + +// ResourceNameGT applies the GT predicate on the "resource_name" field. +func ResourceNameGT(v string) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldResourceName, v)) +} + +// ResourceNameGTE applies the GTE predicate on the "resource_name" field. +func ResourceNameGTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldResourceName, v)) +} + +// ResourceNameLT applies the LT predicate on the "resource_name" field. +func ResourceNameLT(v string) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldResourceName, v)) +} + +// ResourceNameLTE applies the LTE predicate on the "resource_name" field. +func ResourceNameLTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldResourceName, v)) +} + +// ResourceNameContains applies the Contains predicate on the "resource_name" field. +func ResourceNameContains(v string) predicate.Memo { + return predicate.Memo(sql.FieldContains(FieldResourceName, v)) +} + +// ResourceNameHasPrefix applies the HasPrefix predicate on the "resource_name" field. +func ResourceNameHasPrefix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasPrefix(FieldResourceName, v)) +} + +// ResourceNameHasSuffix applies the HasSuffix predicate on the "resource_name" field. +func ResourceNameHasSuffix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasSuffix(FieldResourceName, v)) +} + +// ResourceNameEqualFold applies the EqualFold predicate on the "resource_name" field. +func ResourceNameEqualFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldEqualFold(FieldResourceName, v)) +} + +// ResourceNameContainsFold applies the ContainsFold predicate on the "resource_name" field. +func ResourceNameContainsFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldContainsFold(FieldResourceName, v)) +} + +// CreatorIDEQ applies the EQ predicate on the "creator_id" field. +func CreatorIDEQ(v int) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldCreatorID, v)) +} + +// CreatorIDNEQ applies the NEQ predicate on the "creator_id" field. +func CreatorIDNEQ(v int) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldCreatorID, v)) +} + +// CreatorIDIn applies the In predicate on the "creator_id" field. +func CreatorIDIn(vs ...int) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldCreatorID, vs...)) +} + +// CreatorIDNotIn applies the NotIn predicate on the "creator_id" field. +func CreatorIDNotIn(vs ...int) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldCreatorID, vs...)) +} + +// CreatorIDGT applies the GT predicate on the "creator_id" field. +func CreatorIDGT(v int) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldCreatorID, v)) +} + +// CreatorIDGTE applies the GTE predicate on the "creator_id" field. +func CreatorIDGTE(v int) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldCreatorID, v)) +} + +// CreatorIDLT applies the LT predicate on the "creator_id" field. +func CreatorIDLT(v int) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldCreatorID, v)) +} + +// CreatorIDLTE applies the LTE predicate on the "creator_id" field. +func CreatorIDLTE(v int) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldCreatorID, v)) +} + +// CreatedTsEQ applies the EQ predicate on the "created_ts" field. +func CreatedTsEQ(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldCreatedTs, v)) +} + +// CreatedTsNEQ applies the NEQ predicate on the "created_ts" field. +func CreatedTsNEQ(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldCreatedTs, v)) +} + +// CreatedTsIn applies the In predicate on the "created_ts" field. +func CreatedTsIn(vs ...time.Time) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldCreatedTs, vs...)) +} + +// CreatedTsNotIn applies the NotIn predicate on the "created_ts" field. +func CreatedTsNotIn(vs ...time.Time) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldCreatedTs, vs...)) +} + +// CreatedTsGT applies the GT predicate on the "created_ts" field. +func CreatedTsGT(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldCreatedTs, v)) +} + +// CreatedTsGTE applies the GTE predicate on the "created_ts" field. +func CreatedTsGTE(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldCreatedTs, v)) +} + +// CreatedTsLT applies the LT predicate on the "created_ts" field. +func CreatedTsLT(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldCreatedTs, v)) +} + +// CreatedTsLTE applies the LTE predicate on the "created_ts" field. +func CreatedTsLTE(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldCreatedTs, v)) +} + +// UpdatedTsEQ applies the EQ predicate on the "updated_ts" field. +func UpdatedTsEQ(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldUpdatedTs, v)) +} + +// UpdatedTsNEQ applies the NEQ predicate on the "updated_ts" field. +func UpdatedTsNEQ(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldUpdatedTs, v)) +} + +// UpdatedTsIn applies the In predicate on the "updated_ts" field. +func UpdatedTsIn(vs ...time.Time) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldUpdatedTs, vs...)) +} + +// UpdatedTsNotIn applies the NotIn predicate on the "updated_ts" field. +func UpdatedTsNotIn(vs ...time.Time) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldUpdatedTs, vs...)) +} + +// UpdatedTsGT applies the GT predicate on the "updated_ts" field. +func UpdatedTsGT(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldUpdatedTs, v)) +} + +// UpdatedTsGTE applies the GTE predicate on the "updated_ts" field. +func UpdatedTsGTE(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldUpdatedTs, v)) +} + +// UpdatedTsLT applies the LT predicate on the "updated_ts" field. +func UpdatedTsLT(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldUpdatedTs, v)) +} + +// UpdatedTsLTE applies the LTE predicate on the "updated_ts" field. +func UpdatedTsLTE(v time.Time) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldUpdatedTs, v)) +} + +// RowStatusEQ applies the EQ predicate on the "row_status" field. +func RowStatusEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldRowStatus, v)) +} + +// RowStatusNEQ applies the NEQ predicate on the "row_status" field. +func RowStatusNEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldRowStatus, v)) +} + +// RowStatusIn applies the In predicate on the "row_status" field. +func RowStatusIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldRowStatus, vs...)) +} + +// RowStatusNotIn applies the NotIn predicate on the "row_status" field. +func RowStatusNotIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldRowStatus, vs...)) +} + +// RowStatusGT applies the GT predicate on the "row_status" field. +func RowStatusGT(v string) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldRowStatus, v)) +} + +// RowStatusGTE applies the GTE predicate on the "row_status" field. +func RowStatusGTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldRowStatus, v)) +} + +// RowStatusLT applies the LT predicate on the "row_status" field. +func RowStatusLT(v string) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldRowStatus, v)) +} + +// RowStatusLTE applies the LTE predicate on the "row_status" field. +func RowStatusLTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldRowStatus, v)) +} + +// RowStatusContains applies the Contains predicate on the "row_status" field. +func RowStatusContains(v string) predicate.Memo { + return predicate.Memo(sql.FieldContains(FieldRowStatus, v)) +} + +// RowStatusHasPrefix applies the HasPrefix predicate on the "row_status" field. +func RowStatusHasPrefix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasPrefix(FieldRowStatus, v)) +} + +// RowStatusHasSuffix applies the HasSuffix predicate on the "row_status" field. +func RowStatusHasSuffix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasSuffix(FieldRowStatus, v)) +} + +// RowStatusEqualFold applies the EqualFold predicate on the "row_status" field. +func RowStatusEqualFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldEqualFold(FieldRowStatus, v)) +} + +// RowStatusContainsFold applies the ContainsFold predicate on the "row_status" field. +func RowStatusContainsFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldContainsFold(FieldRowStatus, v)) +} + +// ContentEQ applies the EQ predicate on the "content" field. +func ContentEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldContent, v)) +} + +// ContentNEQ applies the NEQ predicate on the "content" field. +func ContentNEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldContent, v)) +} + +// ContentIn applies the In predicate on the "content" field. +func ContentIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldContent, vs...)) +} + +// ContentNotIn applies the NotIn predicate on the "content" field. +func ContentNotIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldContent, vs...)) +} + +// ContentGT applies the GT predicate on the "content" field. +func ContentGT(v string) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldContent, v)) +} + +// ContentGTE applies the GTE predicate on the "content" field. +func ContentGTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldContent, v)) +} + +// ContentLT applies the LT predicate on the "content" field. +func ContentLT(v string) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldContent, v)) +} + +// ContentLTE applies the LTE predicate on the "content" field. +func ContentLTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldContent, v)) +} + +// ContentContains applies the Contains predicate on the "content" field. +func ContentContains(v string) predicate.Memo { + return predicate.Memo(sql.FieldContains(FieldContent, v)) +} + +// ContentHasPrefix applies the HasPrefix predicate on the "content" field. +func ContentHasPrefix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasPrefix(FieldContent, v)) +} + +// ContentHasSuffix applies the HasSuffix predicate on the "content" field. +func ContentHasSuffix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasSuffix(FieldContent, v)) +} + +// ContentEqualFold applies the EqualFold predicate on the "content" field. +func ContentEqualFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldEqualFold(FieldContent, v)) +} + +// ContentContainsFold applies the ContainsFold predicate on the "content" field. +func ContentContainsFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldContainsFold(FieldContent, v)) +} + +// VisibilityEQ applies the EQ predicate on the "visibility" field. +func VisibilityEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldEQ(FieldVisibility, v)) +} + +// VisibilityNEQ applies the NEQ predicate on the "visibility" field. +func VisibilityNEQ(v string) predicate.Memo { + return predicate.Memo(sql.FieldNEQ(FieldVisibility, v)) +} + +// VisibilityIn applies the In predicate on the "visibility" field. +func VisibilityIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldIn(FieldVisibility, vs...)) +} + +// VisibilityNotIn applies the NotIn predicate on the "visibility" field. +func VisibilityNotIn(vs ...string) predicate.Memo { + return predicate.Memo(sql.FieldNotIn(FieldVisibility, vs...)) +} + +// VisibilityGT applies the GT predicate on the "visibility" field. +func VisibilityGT(v string) predicate.Memo { + return predicate.Memo(sql.FieldGT(FieldVisibility, v)) +} + +// VisibilityGTE applies the GTE predicate on the "visibility" field. +func VisibilityGTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldGTE(FieldVisibility, v)) +} + +// VisibilityLT applies the LT predicate on the "visibility" field. +func VisibilityLT(v string) predicate.Memo { + return predicate.Memo(sql.FieldLT(FieldVisibility, v)) +} + +// VisibilityLTE applies the LTE predicate on the "visibility" field. +func VisibilityLTE(v string) predicate.Memo { + return predicate.Memo(sql.FieldLTE(FieldVisibility, v)) +} + +// VisibilityContains applies the Contains predicate on the "visibility" field. +func VisibilityContains(v string) predicate.Memo { + return predicate.Memo(sql.FieldContains(FieldVisibility, v)) +} + +// VisibilityHasPrefix applies the HasPrefix predicate on the "visibility" field. +func VisibilityHasPrefix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasPrefix(FieldVisibility, v)) +} + +// VisibilityHasSuffix applies the HasSuffix predicate on the "visibility" field. +func VisibilityHasSuffix(v string) predicate.Memo { + return predicate.Memo(sql.FieldHasSuffix(FieldVisibility, v)) +} + +// VisibilityEqualFold applies the EqualFold predicate on the "visibility" field. +func VisibilityEqualFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldEqualFold(FieldVisibility, v)) +} + +// VisibilityContainsFold applies the ContainsFold predicate on the "visibility" field. +func VisibilityContainsFold(v string) predicate.Memo { + return predicate.Memo(sql.FieldContainsFold(FieldVisibility, v)) +} + +// HasRelatedMemo applies the HasEdge predicate on the "related_memo" edge. +func HasRelatedMemo() predicate.Memo { + return predicate.Memo(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, RelatedMemoTable, RelatedMemoPrimaryKey...), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasRelatedMemoWith applies the HasEdge predicate on the "related_memo" edge with a given conditions (other predicates). +func HasRelatedMemoWith(preds ...predicate.Memo) predicate.Memo { + return predicate.Memo(func(s *sql.Selector) { + step := newRelatedMemoStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasMemoRelation applies the HasEdge predicate on the "memo_relation" edge. +func HasMemoRelation() predicate.Memo { + return predicate.Memo(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, MemoRelationTable, MemoRelationColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMemoRelationWith applies the HasEdge predicate on the "memo_relation" edge with a given conditions (other predicates). +func HasMemoRelationWith(preds ...predicate.MemoRelation) predicate.Memo { + return predicate.Memo(func(s *sql.Selector) { + step := newMemoRelationStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.Memo) predicate.Memo { + return predicate.Memo(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.Memo) predicate.Memo { + return predicate.Memo(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.Memo) predicate.Memo { + return predicate.Memo(sql.NotPredicates(p)) +} diff --git a/ent/memo_create.go b/ent/memo_create.go new file mode 100644 index 0000000000000..9b5b7a84ac3eb --- /dev/null +++ b/ent/memo_create.go @@ -0,0 +1,380 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" +) + +// MemoCreate is the builder for creating a Memo entity. +type MemoCreate struct { + config + mutation *MemoMutation + hooks []Hook +} + +// SetResourceName sets the "resource_name" field. +func (mc *MemoCreate) SetResourceName(s string) *MemoCreate { + mc.mutation.SetResourceName(s) + return mc +} + +// SetCreatorID sets the "creator_id" field. +func (mc *MemoCreate) SetCreatorID(i int) *MemoCreate { + mc.mutation.SetCreatorID(i) + return mc +} + +// SetCreatedTs sets the "created_ts" field. +func (mc *MemoCreate) SetCreatedTs(t time.Time) *MemoCreate { + mc.mutation.SetCreatedTs(t) + return mc +} + +// SetUpdatedTs sets the "updated_ts" field. +func (mc *MemoCreate) SetUpdatedTs(t time.Time) *MemoCreate { + mc.mutation.SetUpdatedTs(t) + return mc +} + +// SetRowStatus sets the "row_status" field. +func (mc *MemoCreate) SetRowStatus(s string) *MemoCreate { + mc.mutation.SetRowStatus(s) + return mc +} + +// SetContent sets the "content" field. +func (mc *MemoCreate) SetContent(s string) *MemoCreate { + mc.mutation.SetContent(s) + return mc +} + +// SetNillableContent sets the "content" field if the given value is not nil. +func (mc *MemoCreate) SetNillableContent(s *string) *MemoCreate { + if s != nil { + mc.SetContent(*s) + } + return mc +} + +// SetVisibility sets the "visibility" field. +func (mc *MemoCreate) SetVisibility(s string) *MemoCreate { + mc.mutation.SetVisibility(s) + return mc +} + +// SetID sets the "id" field. +func (mc *MemoCreate) SetID(i int) *MemoCreate { + mc.mutation.SetID(i) + return mc +} + +// AddRelatedMemoIDs adds the "related_memo" edge to the Memo entity by IDs. +func (mc *MemoCreate) AddRelatedMemoIDs(ids ...int) *MemoCreate { + mc.mutation.AddRelatedMemoIDs(ids...) + return mc +} + +// AddRelatedMemo adds the "related_memo" edges to the Memo entity. +func (mc *MemoCreate) AddRelatedMemo(m ...*Memo) *MemoCreate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mc.AddRelatedMemoIDs(ids...) +} + +// AddMemoRelationIDs adds the "memo_relation" edge to the MemoRelation entity by IDs. +func (mc *MemoCreate) AddMemoRelationIDs(ids ...int) *MemoCreate { + mc.mutation.AddMemoRelationIDs(ids...) + return mc +} + +// AddMemoRelation adds the "memo_relation" edges to the MemoRelation entity. +func (mc *MemoCreate) AddMemoRelation(m ...*MemoRelation) *MemoCreate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mc.AddMemoRelationIDs(ids...) +} + +// Mutation returns the MemoMutation object of the builder. +func (mc *MemoCreate) Mutation() *MemoMutation { + return mc.mutation +} + +// Save creates the Memo in the database. +func (mc *MemoCreate) Save(ctx context.Context) (*Memo, error) { + mc.defaults() + return withHooks(ctx, mc.sqlSave, mc.mutation, mc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (mc *MemoCreate) SaveX(ctx context.Context) *Memo { + v, err := mc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mc *MemoCreate) Exec(ctx context.Context) error { + _, err := mc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mc *MemoCreate) ExecX(ctx context.Context) { + if err := mc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (mc *MemoCreate) defaults() { + if _, ok := mc.mutation.Content(); !ok { + v := memo.DefaultContent + mc.mutation.SetContent(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mc *MemoCreate) check() error { + if _, ok := mc.mutation.ResourceName(); !ok { + return &ValidationError{Name: "resource_name", err: errors.New(`ent: missing required field "Memo.resource_name"`)} + } + if v, ok := mc.mutation.ResourceName(); ok { + if err := memo.ResourceNameValidator(v); err != nil { + return &ValidationError{Name: "resource_name", err: fmt.Errorf(`ent: validator failed for field "Memo.resource_name": %w`, err)} + } + } + if _, ok := mc.mutation.CreatorID(); !ok { + return &ValidationError{Name: "creator_id", err: errors.New(`ent: missing required field "Memo.creator_id"`)} + } + if v, ok := mc.mutation.CreatorID(); ok { + if err := memo.CreatorIDValidator(v); err != nil { + return &ValidationError{Name: "creator_id", err: fmt.Errorf(`ent: validator failed for field "Memo.creator_id": %w`, err)} + } + } + if _, ok := mc.mutation.CreatedTs(); !ok { + return &ValidationError{Name: "created_ts", err: errors.New(`ent: missing required field "Memo.created_ts"`)} + } + if _, ok := mc.mutation.UpdatedTs(); !ok { + return &ValidationError{Name: "updated_ts", err: errors.New(`ent: missing required field "Memo.updated_ts"`)} + } + if _, ok := mc.mutation.RowStatus(); !ok { + return &ValidationError{Name: "row_status", err: errors.New(`ent: missing required field "Memo.row_status"`)} + } + if v, ok := mc.mutation.RowStatus(); ok { + if err := memo.RowStatusValidator(v); err != nil { + return &ValidationError{Name: "row_status", err: fmt.Errorf(`ent: validator failed for field "Memo.row_status": %w`, err)} + } + } + if _, ok := mc.mutation.Content(); !ok { + return &ValidationError{Name: "content", err: errors.New(`ent: missing required field "Memo.content"`)} + } + if _, ok := mc.mutation.Visibility(); !ok { + return &ValidationError{Name: "visibility", err: errors.New(`ent: missing required field "Memo.visibility"`)} + } + if v, ok := mc.mutation.Visibility(); ok { + if err := memo.VisibilityValidator(v); err != nil { + return &ValidationError{Name: "visibility", err: fmt.Errorf(`ent: validator failed for field "Memo.visibility": %w`, err)} + } + } + if v, ok := mc.mutation.ID(); ok { + if err := memo.IDValidator(v); err != nil { + return &ValidationError{Name: "id", err: fmt.Errorf(`ent: validator failed for field "Memo.id": %w`, err)} + } + } + return nil +} + +func (mc *MemoCreate) sqlSave(ctx context.Context) (*Memo, error) { + if err := mc.check(); err != nil { + return nil, err + } + _node, _spec := mc.createSpec() + if err := sqlgraph.CreateNode(ctx, mc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != _node.ID { + id := _spec.ID.Value.(int64) + _node.ID = int(id) + } + mc.mutation.id = &_node.ID + mc.mutation.done = true + return _node, nil +} + +func (mc *MemoCreate) createSpec() (*Memo, *sqlgraph.CreateSpec) { + var ( + _node = &Memo{config: mc.config} + _spec = sqlgraph.NewCreateSpec(memo.Table, sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt)) + ) + if id, ok := mc.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if value, ok := mc.mutation.ResourceName(); ok { + _spec.SetField(memo.FieldResourceName, field.TypeString, value) + _node.ResourceName = value + } + if value, ok := mc.mutation.CreatorID(); ok { + _spec.SetField(memo.FieldCreatorID, field.TypeInt, value) + _node.CreatorID = value + } + if value, ok := mc.mutation.CreatedTs(); ok { + _spec.SetField(memo.FieldCreatedTs, field.TypeTime, value) + _node.CreatedTs = value + } + if value, ok := mc.mutation.UpdatedTs(); ok { + _spec.SetField(memo.FieldUpdatedTs, field.TypeTime, value) + _node.UpdatedTs = value + } + if value, ok := mc.mutation.RowStatus(); ok { + _spec.SetField(memo.FieldRowStatus, field.TypeString, value) + _node.RowStatus = value + } + if value, ok := mc.mutation.Content(); ok { + _spec.SetField(memo.FieldContent, field.TypeString, value) + _node.Content = value + } + if value, ok := mc.mutation.Visibility(); ok { + _spec.SetField(memo.FieldVisibility, field.TypeString, value) + _node.Visibility = value + } + if nodes := mc.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := mc.mutation.MemoRelationIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// MemoCreateBulk is the builder for creating many Memo entities in bulk. +type MemoCreateBulk struct { + config + err error + builders []*MemoCreate +} + +// Save creates the Memo entities in the database. +func (mcb *MemoCreateBulk) Save(ctx context.Context) ([]*Memo, error) { + if mcb.err != nil { + return nil, mcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(mcb.builders)) + nodes := make([]*Memo, len(mcb.builders)) + mutators := make([]Mutator, len(mcb.builders)) + for i := range mcb.builders { + func(i int, root context.Context) { + builder := mcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*MemoMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, mcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, mcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil && nodes[i].ID == 0 { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, mcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (mcb *MemoCreateBulk) SaveX(ctx context.Context) []*Memo { + v, err := mcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mcb *MemoCreateBulk) Exec(ctx context.Context) error { + _, err := mcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mcb *MemoCreateBulk) ExecX(ctx context.Context) { + if err := mcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/memo_delete.go b/ent/memo_delete.go new file mode 100644 index 0000000000000..5a9ed562c0860 --- /dev/null +++ b/ent/memo_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/predicate" +) + +// MemoDelete is the builder for deleting a Memo entity. +type MemoDelete struct { + config + hooks []Hook + mutation *MemoMutation +} + +// Where appends a list predicates to the MemoDelete builder. +func (md *MemoDelete) Where(ps ...predicate.Memo) *MemoDelete { + md.mutation.Where(ps...) + return md +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (md *MemoDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, md.sqlExec, md.mutation, md.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (md *MemoDelete) ExecX(ctx context.Context) int { + n, err := md.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (md *MemoDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(memo.Table, sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt)) + if ps := md.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, md.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + md.mutation.done = true + return affected, err +} + +// MemoDeleteOne is the builder for deleting a single Memo entity. +type MemoDeleteOne struct { + md *MemoDelete +} + +// Where appends a list predicates to the MemoDelete builder. +func (mdo *MemoDeleteOne) Where(ps ...predicate.Memo) *MemoDeleteOne { + mdo.md.mutation.Where(ps...) + return mdo +} + +// Exec executes the deletion query. +func (mdo *MemoDeleteOne) Exec(ctx context.Context) error { + n, err := mdo.md.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{memo.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (mdo *MemoDeleteOne) ExecX(ctx context.Context) { + if err := mdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/memo_query.go b/ent/memo_query.go new file mode 100644 index 0000000000000..bd7bc3feaa1fc --- /dev/null +++ b/ent/memo_query.go @@ -0,0 +1,709 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +// MemoQuery is the builder for querying Memo entities. +type MemoQuery struct { + config + ctx *QueryContext + order []memo.OrderOption + inters []Interceptor + predicates []predicate.Memo + withRelatedMemo *MemoQuery + withMemoRelation *MemoRelationQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the MemoQuery builder. +func (mq *MemoQuery) Where(ps ...predicate.Memo) *MemoQuery { + mq.predicates = append(mq.predicates, ps...) + return mq +} + +// Limit the number of records to be returned by this query. +func (mq *MemoQuery) Limit(limit int) *MemoQuery { + mq.ctx.Limit = &limit + return mq +} + +// Offset to start from. +func (mq *MemoQuery) Offset(offset int) *MemoQuery { + mq.ctx.Offset = &offset + return mq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (mq *MemoQuery) Unique(unique bool) *MemoQuery { + mq.ctx.Unique = &unique + return mq +} + +// Order specifies how the records should be ordered. +func (mq *MemoQuery) Order(o ...memo.OrderOption) *MemoQuery { + mq.order = append(mq.order, o...) + return mq +} + +// QueryRelatedMemo chains the current query on the "related_memo" edge. +func (mq *MemoQuery) QueryRelatedMemo() *MemoQuery { + query := (&MemoClient{config: mq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := mq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(memo.Table, memo.FieldID, selector), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2M, false, memo.RelatedMemoTable, memo.RelatedMemoPrimaryKey...), + ) + fromU = sqlgraph.SetNeighbors(mq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryMemoRelation chains the current query on the "memo_relation" edge. +func (mq *MemoQuery) QueryMemoRelation() *MemoRelationQuery { + query := (&MemoRelationClient{config: mq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := mq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(memo.Table, memo.FieldID, selector), + sqlgraph.To(memorelation.Table, memorelation.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, memo.MemoRelationTable, memo.MemoRelationColumn), + ) + fromU = sqlgraph.SetNeighbors(mq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first Memo entity from the query. +// Returns a *NotFoundError when no Memo was found. +func (mq *MemoQuery) First(ctx context.Context) (*Memo, error) { + nodes, err := mq.Limit(1).All(setContextOp(ctx, mq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{memo.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (mq *MemoQuery) FirstX(ctx context.Context) *Memo { + node, err := mq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first Memo ID from the query. +// Returns a *NotFoundError when no Memo ID was found. +func (mq *MemoQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(1).IDs(setContextOp(ctx, mq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{memo.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (mq *MemoQuery) FirstIDX(ctx context.Context) int { + id, err := mq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single Memo entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one Memo entity is found. +// Returns a *NotFoundError when no Memo entities are found. +func (mq *MemoQuery) Only(ctx context.Context) (*Memo, error) { + nodes, err := mq.Limit(2).All(setContextOp(ctx, mq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{memo.Label} + default: + return nil, &NotSingularError{memo.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (mq *MemoQuery) OnlyX(ctx context.Context) *Memo { + node, err := mq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only Memo ID in the query. +// Returns a *NotSingularError when more than one Memo ID is found. +// Returns a *NotFoundError when no entities are found. +func (mq *MemoQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mq.Limit(2).IDs(setContextOp(ctx, mq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{memo.Label} + default: + err = &NotSingularError{memo.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (mq *MemoQuery) OnlyIDX(ctx context.Context) int { + id, err := mq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of Memos. +func (mq *MemoQuery) All(ctx context.Context) ([]*Memo, error) { + ctx = setContextOp(ctx, mq.ctx, "All") + if err := mq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*Memo, *MemoQuery]() + return withInterceptors[[]*Memo](ctx, mq, qr, mq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (mq *MemoQuery) AllX(ctx context.Context) []*Memo { + nodes, err := mq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of Memo IDs. +func (mq *MemoQuery) IDs(ctx context.Context) (ids []int, err error) { + if mq.ctx.Unique == nil && mq.path != nil { + mq.Unique(true) + } + ctx = setContextOp(ctx, mq.ctx, "IDs") + if err = mq.Select(memo.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (mq *MemoQuery) IDsX(ctx context.Context) []int { + ids, err := mq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (mq *MemoQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mq.ctx, "Count") + if err := mq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, mq, querierCount[*MemoQuery](), mq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (mq *MemoQuery) CountX(ctx context.Context) int { + count, err := mq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (mq *MemoQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, mq.ctx, "Exist") + switch _, err := mq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (mq *MemoQuery) ExistX(ctx context.Context) bool { + exist, err := mq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the MemoQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (mq *MemoQuery) Clone() *MemoQuery { + if mq == nil { + return nil + } + return &MemoQuery{ + config: mq.config, + ctx: mq.ctx.Clone(), + order: append([]memo.OrderOption{}, mq.order...), + inters: append([]Interceptor{}, mq.inters...), + predicates: append([]predicate.Memo{}, mq.predicates...), + withRelatedMemo: mq.withRelatedMemo.Clone(), + withMemoRelation: mq.withMemoRelation.Clone(), + // clone intermediate query. + sql: mq.sql.Clone(), + path: mq.path, + } +} + +// WithRelatedMemo tells the query-builder to eager-load the nodes that are connected to +// the "related_memo" edge. The optional arguments are used to configure the query builder of the edge. +func (mq *MemoQuery) WithRelatedMemo(opts ...func(*MemoQuery)) *MemoQuery { + query := (&MemoClient{config: mq.config}).Query() + for _, opt := range opts { + opt(query) + } + mq.withRelatedMemo = query + return mq +} + +// WithMemoRelation tells the query-builder to eager-load the nodes that are connected to +// the "memo_relation" edge. The optional arguments are used to configure the query builder of the edge. +func (mq *MemoQuery) WithMemoRelation(opts ...func(*MemoRelationQuery)) *MemoQuery { + query := (&MemoRelationClient{config: mq.config}).Query() + for _, opt := range opts { + opt(query) + } + mq.withMemoRelation = query + return mq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// ResourceName string `json:"resource_name,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.Memo.Query(). +// GroupBy(memo.FieldResourceName). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (mq *MemoQuery) GroupBy(field string, fields ...string) *MemoGroupBy { + mq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MemoGroupBy{build: mq} + grbuild.flds = &mq.ctx.Fields + grbuild.label = memo.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// ResourceName string `json:"resource_name,omitempty"` +// } +// +// client.Memo.Query(). +// Select(memo.FieldResourceName). +// Scan(ctx, &v) +func (mq *MemoQuery) Select(fields ...string) *MemoSelect { + mq.ctx.Fields = append(mq.ctx.Fields, fields...) + sbuild := &MemoSelect{MemoQuery: mq} + sbuild.label = memo.Label + sbuild.flds, sbuild.scan = &mq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MemoSelect configured with the given aggregations. +func (mq *MemoQuery) Aggregate(fns ...AggregateFunc) *MemoSelect { + return mq.Select().Aggregate(fns...) +} + +func (mq *MemoQuery) prepareQuery(ctx context.Context) error { + for _, inter := range mq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mq); err != nil { + return err + } + } + } + for _, f := range mq.ctx.Fields { + if !memo.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if mq.path != nil { + prev, err := mq.path(ctx) + if err != nil { + return err + } + mq.sql = prev + } + return nil +} + +func (mq *MemoQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Memo, error) { + var ( + nodes = []*Memo{} + _spec = mq.querySpec() + loadedTypes = [2]bool{ + mq.withRelatedMemo != nil, + mq.withMemoRelation != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*Memo).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &Memo{config: mq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := mq.withRelatedMemo; query != nil { + if err := mq.loadRelatedMemo(ctx, query, nodes, + func(n *Memo) { n.Edges.RelatedMemo = []*Memo{} }, + func(n *Memo, e *Memo) { n.Edges.RelatedMemo = append(n.Edges.RelatedMemo, e) }); err != nil { + return nil, err + } + } + if query := mq.withMemoRelation; query != nil { + if err := mq.loadMemoRelation(ctx, query, nodes, + func(n *Memo) { n.Edges.MemoRelation = []*MemoRelation{} }, + func(n *Memo, e *MemoRelation) { n.Edges.MemoRelation = append(n.Edges.MemoRelation, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (mq *MemoQuery) loadRelatedMemo(ctx context.Context, query *MemoQuery, nodes []*Memo, init func(*Memo), assign func(*Memo, *Memo)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Memo) + nids := make(map[int]map[*Memo]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(memo.RelatedMemoTable) + s.Join(joinT).On(s.C(memo.FieldID), joinT.C(memo.RelatedMemoPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(memo.RelatedMemoPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(memo.RelatedMemoPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + if err := query.prepareQuery(ctx); err != nil { + return err + } + qr := QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + return query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]any, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]any{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []any) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Memo]struct{}{byID[outValue]: {}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + }) + neighbors, err := withInterceptors[[]*Memo](ctx, query, qr, query.inters) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "related_memo" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (mq *MemoQuery) loadMemoRelation(ctx context.Context, query *MemoRelationQuery, nodes []*Memo, init func(*Memo), assign func(*Memo, *MemoRelation)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Memo) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(memorelation.FieldMemoID) + } + query.Where(predicate.MemoRelation(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(memo.MemoRelationColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.MemoID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "memo_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (mq *MemoQuery) sqlCount(ctx context.Context) (int, error) { + _spec := mq.querySpec() + _spec.Node.Columns = mq.ctx.Fields + if len(mq.ctx.Fields) > 0 { + _spec.Unique = mq.ctx.Unique != nil && *mq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, mq.driver, _spec) +} + +func (mq *MemoQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(memo.Table, memo.Columns, sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt)) + _spec.From = mq.sql + if unique := mq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if mq.path != nil { + _spec.Unique = true + } + if fields := mq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, memo.FieldID) + for i := range fields { + if fields[i] != memo.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := mq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := mq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := mq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := mq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (mq *MemoQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(mq.driver.Dialect()) + t1 := builder.Table(memo.Table) + columns := mq.ctx.Fields + if len(columns) == 0 { + columns = memo.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if mq.sql != nil { + selector = mq.sql + selector.Select(selector.Columns(columns...)...) + } + if mq.ctx.Unique != nil && *mq.ctx.Unique { + selector.Distinct() + } + for _, p := range mq.predicates { + p(selector) + } + for _, p := range mq.order { + p(selector) + } + if offset := mq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := mq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// MemoGroupBy is the group-by builder for Memo entities. +type MemoGroupBy struct { + selector + build *MemoQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (mgb *MemoGroupBy) Aggregate(fns ...AggregateFunc) *MemoGroupBy { + mgb.fns = append(mgb.fns, fns...) + return mgb +} + +// Scan applies the selector query and scans the result into the given value. +func (mgb *MemoGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mgb.build.ctx, "GroupBy") + if err := mgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MemoQuery, *MemoGroupBy](ctx, mgb.build, mgb, mgb.build.inters, v) +} + +func (mgb *MemoGroupBy) sqlScan(ctx context.Context, root *MemoQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mgb.fns)) + for _, fn := range mgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mgb.flds)+len(mgb.fns)) + for _, f := range *mgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*mgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// MemoSelect is the builder for selecting fields of Memo entities. +type MemoSelect struct { + *MemoQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ms *MemoSelect) Aggregate(fns ...AggregateFunc) *MemoSelect { + ms.fns = append(ms.fns, fns...) + return ms +} + +// Scan applies the selector query and scans the result into the given value. +func (ms *MemoSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ms.ctx, "Select") + if err := ms.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MemoQuery, *MemoSelect](ctx, ms.MemoQuery, ms, ms.inters, v) +} + +func (ms *MemoSelect) sqlScan(ctx context.Context, root *MemoQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ms.fns)) + for _, fn := range ms.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ms.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := ms.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/memo_update.go b/ent/memo_update.go new file mode 100644 index 0000000000000..5183440a4d71d --- /dev/null +++ b/ent/memo_update.go @@ -0,0 +1,815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +// MemoUpdate is the builder for updating Memo entities. +type MemoUpdate struct { + config + hooks []Hook + mutation *MemoMutation +} + +// Where appends a list predicates to the MemoUpdate builder. +func (mu *MemoUpdate) Where(ps ...predicate.Memo) *MemoUpdate { + mu.mutation.Where(ps...) + return mu +} + +// SetResourceName sets the "resource_name" field. +func (mu *MemoUpdate) SetResourceName(s string) *MemoUpdate { + mu.mutation.SetResourceName(s) + return mu +} + +// SetNillableResourceName sets the "resource_name" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableResourceName(s *string) *MemoUpdate { + if s != nil { + mu.SetResourceName(*s) + } + return mu +} + +// SetCreatorID sets the "creator_id" field. +func (mu *MemoUpdate) SetCreatorID(i int) *MemoUpdate { + mu.mutation.ResetCreatorID() + mu.mutation.SetCreatorID(i) + return mu +} + +// SetNillableCreatorID sets the "creator_id" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableCreatorID(i *int) *MemoUpdate { + if i != nil { + mu.SetCreatorID(*i) + } + return mu +} + +// AddCreatorID adds i to the "creator_id" field. +func (mu *MemoUpdate) AddCreatorID(i int) *MemoUpdate { + mu.mutation.AddCreatorID(i) + return mu +} + +// SetCreatedTs sets the "created_ts" field. +func (mu *MemoUpdate) SetCreatedTs(t time.Time) *MemoUpdate { + mu.mutation.SetCreatedTs(t) + return mu +} + +// SetNillableCreatedTs sets the "created_ts" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableCreatedTs(t *time.Time) *MemoUpdate { + if t != nil { + mu.SetCreatedTs(*t) + } + return mu +} + +// SetUpdatedTs sets the "updated_ts" field. +func (mu *MemoUpdate) SetUpdatedTs(t time.Time) *MemoUpdate { + mu.mutation.SetUpdatedTs(t) + return mu +} + +// SetNillableUpdatedTs sets the "updated_ts" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableUpdatedTs(t *time.Time) *MemoUpdate { + if t != nil { + mu.SetUpdatedTs(*t) + } + return mu +} + +// SetRowStatus sets the "row_status" field. +func (mu *MemoUpdate) SetRowStatus(s string) *MemoUpdate { + mu.mutation.SetRowStatus(s) + return mu +} + +// SetNillableRowStatus sets the "row_status" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableRowStatus(s *string) *MemoUpdate { + if s != nil { + mu.SetRowStatus(*s) + } + return mu +} + +// SetContent sets the "content" field. +func (mu *MemoUpdate) SetContent(s string) *MemoUpdate { + mu.mutation.SetContent(s) + return mu +} + +// SetNillableContent sets the "content" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableContent(s *string) *MemoUpdate { + if s != nil { + mu.SetContent(*s) + } + return mu +} + +// SetVisibility sets the "visibility" field. +func (mu *MemoUpdate) SetVisibility(s string) *MemoUpdate { + mu.mutation.SetVisibility(s) + return mu +} + +// SetNillableVisibility sets the "visibility" field if the given value is not nil. +func (mu *MemoUpdate) SetNillableVisibility(s *string) *MemoUpdate { + if s != nil { + mu.SetVisibility(*s) + } + return mu +} + +// AddRelatedMemoIDs adds the "related_memo" edge to the Memo entity by IDs. +func (mu *MemoUpdate) AddRelatedMemoIDs(ids ...int) *MemoUpdate { + mu.mutation.AddRelatedMemoIDs(ids...) + return mu +} + +// AddRelatedMemo adds the "related_memo" edges to the Memo entity. +func (mu *MemoUpdate) AddRelatedMemo(m ...*Memo) *MemoUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mu.AddRelatedMemoIDs(ids...) +} + +// AddMemoRelationIDs adds the "memo_relation" edge to the MemoRelation entity by IDs. +func (mu *MemoUpdate) AddMemoRelationIDs(ids ...int) *MemoUpdate { + mu.mutation.AddMemoRelationIDs(ids...) + return mu +} + +// AddMemoRelation adds the "memo_relation" edges to the MemoRelation entity. +func (mu *MemoUpdate) AddMemoRelation(m ...*MemoRelation) *MemoUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mu.AddMemoRelationIDs(ids...) +} + +// Mutation returns the MemoMutation object of the builder. +func (mu *MemoUpdate) Mutation() *MemoMutation { + return mu.mutation +} + +// ClearRelatedMemo clears all "related_memo" edges to the Memo entity. +func (mu *MemoUpdate) ClearRelatedMemo() *MemoUpdate { + mu.mutation.ClearRelatedMemo() + return mu +} + +// RemoveRelatedMemoIDs removes the "related_memo" edge to Memo entities by IDs. +func (mu *MemoUpdate) RemoveRelatedMemoIDs(ids ...int) *MemoUpdate { + mu.mutation.RemoveRelatedMemoIDs(ids...) + return mu +} + +// RemoveRelatedMemo removes "related_memo" edges to Memo entities. +func (mu *MemoUpdate) RemoveRelatedMemo(m ...*Memo) *MemoUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mu.RemoveRelatedMemoIDs(ids...) +} + +// ClearMemoRelation clears all "memo_relation" edges to the MemoRelation entity. +func (mu *MemoUpdate) ClearMemoRelation() *MemoUpdate { + mu.mutation.ClearMemoRelation() + return mu +} + +// RemoveMemoRelationIDs removes the "memo_relation" edge to MemoRelation entities by IDs. +func (mu *MemoUpdate) RemoveMemoRelationIDs(ids ...int) *MemoUpdate { + mu.mutation.RemoveMemoRelationIDs(ids...) + return mu +} + +// RemoveMemoRelation removes "memo_relation" edges to MemoRelation entities. +func (mu *MemoUpdate) RemoveMemoRelation(m ...*MemoRelation) *MemoUpdate { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return mu.RemoveMemoRelationIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (mu *MemoUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, mu.sqlSave, mu.mutation, mu.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mu *MemoUpdate) SaveX(ctx context.Context) int { + affected, err := mu.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (mu *MemoUpdate) Exec(ctx context.Context) error { + _, err := mu.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mu *MemoUpdate) ExecX(ctx context.Context) { + if err := mu.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mu *MemoUpdate) check() error { + if v, ok := mu.mutation.ResourceName(); ok { + if err := memo.ResourceNameValidator(v); err != nil { + return &ValidationError{Name: "resource_name", err: fmt.Errorf(`ent: validator failed for field "Memo.resource_name": %w`, err)} + } + } + if v, ok := mu.mutation.CreatorID(); ok { + if err := memo.CreatorIDValidator(v); err != nil { + return &ValidationError{Name: "creator_id", err: fmt.Errorf(`ent: validator failed for field "Memo.creator_id": %w`, err)} + } + } + if v, ok := mu.mutation.RowStatus(); ok { + if err := memo.RowStatusValidator(v); err != nil { + return &ValidationError{Name: "row_status", err: fmt.Errorf(`ent: validator failed for field "Memo.row_status": %w`, err)} + } + } + if v, ok := mu.mutation.Visibility(); ok { + if err := memo.VisibilityValidator(v); err != nil { + return &ValidationError{Name: "visibility", err: fmt.Errorf(`ent: validator failed for field "Memo.visibility": %w`, err)} + } + } + return nil +} + +func (mu *MemoUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := mu.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(memo.Table, memo.Columns, sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt)) + if ps := mu.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mu.mutation.ResourceName(); ok { + _spec.SetField(memo.FieldResourceName, field.TypeString, value) + } + if value, ok := mu.mutation.CreatorID(); ok { + _spec.SetField(memo.FieldCreatorID, field.TypeInt, value) + } + if value, ok := mu.mutation.AddedCreatorID(); ok { + _spec.AddField(memo.FieldCreatorID, field.TypeInt, value) + } + if value, ok := mu.mutation.CreatedTs(); ok { + _spec.SetField(memo.FieldCreatedTs, field.TypeTime, value) + } + if value, ok := mu.mutation.UpdatedTs(); ok { + _spec.SetField(memo.FieldUpdatedTs, field.TypeTime, value) + } + if value, ok := mu.mutation.RowStatus(); ok { + _spec.SetField(memo.FieldRowStatus, field.TypeString, value) + } + if value, ok := mu.mutation.Content(); ok { + _spec.SetField(memo.FieldContent, field.TypeString, value) + } + if value, ok := mu.mutation.Visibility(); ok { + _spec.SetField(memo.FieldVisibility, field.TypeString, value) + } + if mu.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mu.mutation.RemovedRelatedMemoIDs(); len(nodes) > 0 && !mu.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mu.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if mu.mutation.MemoRelationCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mu.mutation.RemovedMemoRelationIDs(); len(nodes) > 0 && !mu.mutation.MemoRelationCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mu.mutation.MemoRelationIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, mu.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{memo.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + mu.mutation.done = true + return n, nil +} + +// MemoUpdateOne is the builder for updating a single Memo entity. +type MemoUpdateOne struct { + config + fields []string + hooks []Hook + mutation *MemoMutation +} + +// SetResourceName sets the "resource_name" field. +func (muo *MemoUpdateOne) SetResourceName(s string) *MemoUpdateOne { + muo.mutation.SetResourceName(s) + return muo +} + +// SetNillableResourceName sets the "resource_name" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableResourceName(s *string) *MemoUpdateOne { + if s != nil { + muo.SetResourceName(*s) + } + return muo +} + +// SetCreatorID sets the "creator_id" field. +func (muo *MemoUpdateOne) SetCreatorID(i int) *MemoUpdateOne { + muo.mutation.ResetCreatorID() + muo.mutation.SetCreatorID(i) + return muo +} + +// SetNillableCreatorID sets the "creator_id" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableCreatorID(i *int) *MemoUpdateOne { + if i != nil { + muo.SetCreatorID(*i) + } + return muo +} + +// AddCreatorID adds i to the "creator_id" field. +func (muo *MemoUpdateOne) AddCreatorID(i int) *MemoUpdateOne { + muo.mutation.AddCreatorID(i) + return muo +} + +// SetCreatedTs sets the "created_ts" field. +func (muo *MemoUpdateOne) SetCreatedTs(t time.Time) *MemoUpdateOne { + muo.mutation.SetCreatedTs(t) + return muo +} + +// SetNillableCreatedTs sets the "created_ts" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableCreatedTs(t *time.Time) *MemoUpdateOne { + if t != nil { + muo.SetCreatedTs(*t) + } + return muo +} + +// SetUpdatedTs sets the "updated_ts" field. +func (muo *MemoUpdateOne) SetUpdatedTs(t time.Time) *MemoUpdateOne { + muo.mutation.SetUpdatedTs(t) + return muo +} + +// SetNillableUpdatedTs sets the "updated_ts" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableUpdatedTs(t *time.Time) *MemoUpdateOne { + if t != nil { + muo.SetUpdatedTs(*t) + } + return muo +} + +// SetRowStatus sets the "row_status" field. +func (muo *MemoUpdateOne) SetRowStatus(s string) *MemoUpdateOne { + muo.mutation.SetRowStatus(s) + return muo +} + +// SetNillableRowStatus sets the "row_status" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableRowStatus(s *string) *MemoUpdateOne { + if s != nil { + muo.SetRowStatus(*s) + } + return muo +} + +// SetContent sets the "content" field. +func (muo *MemoUpdateOne) SetContent(s string) *MemoUpdateOne { + muo.mutation.SetContent(s) + return muo +} + +// SetNillableContent sets the "content" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableContent(s *string) *MemoUpdateOne { + if s != nil { + muo.SetContent(*s) + } + return muo +} + +// SetVisibility sets the "visibility" field. +func (muo *MemoUpdateOne) SetVisibility(s string) *MemoUpdateOne { + muo.mutation.SetVisibility(s) + return muo +} + +// SetNillableVisibility sets the "visibility" field if the given value is not nil. +func (muo *MemoUpdateOne) SetNillableVisibility(s *string) *MemoUpdateOne { + if s != nil { + muo.SetVisibility(*s) + } + return muo +} + +// AddRelatedMemoIDs adds the "related_memo" edge to the Memo entity by IDs. +func (muo *MemoUpdateOne) AddRelatedMemoIDs(ids ...int) *MemoUpdateOne { + muo.mutation.AddRelatedMemoIDs(ids...) + return muo +} + +// AddRelatedMemo adds the "related_memo" edges to the Memo entity. +func (muo *MemoUpdateOne) AddRelatedMemo(m ...*Memo) *MemoUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return muo.AddRelatedMemoIDs(ids...) +} + +// AddMemoRelationIDs adds the "memo_relation" edge to the MemoRelation entity by IDs. +func (muo *MemoUpdateOne) AddMemoRelationIDs(ids ...int) *MemoUpdateOne { + muo.mutation.AddMemoRelationIDs(ids...) + return muo +} + +// AddMemoRelation adds the "memo_relation" edges to the MemoRelation entity. +func (muo *MemoUpdateOne) AddMemoRelation(m ...*MemoRelation) *MemoUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return muo.AddMemoRelationIDs(ids...) +} + +// Mutation returns the MemoMutation object of the builder. +func (muo *MemoUpdateOne) Mutation() *MemoMutation { + return muo.mutation +} + +// ClearRelatedMemo clears all "related_memo" edges to the Memo entity. +func (muo *MemoUpdateOne) ClearRelatedMemo() *MemoUpdateOne { + muo.mutation.ClearRelatedMemo() + return muo +} + +// RemoveRelatedMemoIDs removes the "related_memo" edge to Memo entities by IDs. +func (muo *MemoUpdateOne) RemoveRelatedMemoIDs(ids ...int) *MemoUpdateOne { + muo.mutation.RemoveRelatedMemoIDs(ids...) + return muo +} + +// RemoveRelatedMemo removes "related_memo" edges to Memo entities. +func (muo *MemoUpdateOne) RemoveRelatedMemo(m ...*Memo) *MemoUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return muo.RemoveRelatedMemoIDs(ids...) +} + +// ClearMemoRelation clears all "memo_relation" edges to the MemoRelation entity. +func (muo *MemoUpdateOne) ClearMemoRelation() *MemoUpdateOne { + muo.mutation.ClearMemoRelation() + return muo +} + +// RemoveMemoRelationIDs removes the "memo_relation" edge to MemoRelation entities by IDs. +func (muo *MemoUpdateOne) RemoveMemoRelationIDs(ids ...int) *MemoUpdateOne { + muo.mutation.RemoveMemoRelationIDs(ids...) + return muo +} + +// RemoveMemoRelation removes "memo_relation" edges to MemoRelation entities. +func (muo *MemoUpdateOne) RemoveMemoRelation(m ...*MemoRelation) *MemoUpdateOne { + ids := make([]int, len(m)) + for i := range m { + ids[i] = m[i].ID + } + return muo.RemoveMemoRelationIDs(ids...) +} + +// Where appends a list predicates to the MemoUpdate builder. +func (muo *MemoUpdateOne) Where(ps ...predicate.Memo) *MemoUpdateOne { + muo.mutation.Where(ps...) + return muo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (muo *MemoUpdateOne) Select(field string, fields ...string) *MemoUpdateOne { + muo.fields = append([]string{field}, fields...) + return muo +} + +// Save executes the query and returns the updated Memo entity. +func (muo *MemoUpdateOne) Save(ctx context.Context) (*Memo, error) { + return withHooks(ctx, muo.sqlSave, muo.mutation, muo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (muo *MemoUpdateOne) SaveX(ctx context.Context) *Memo { + node, err := muo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (muo *MemoUpdateOne) Exec(ctx context.Context) error { + _, err := muo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (muo *MemoUpdateOne) ExecX(ctx context.Context) { + if err := muo.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (muo *MemoUpdateOne) check() error { + if v, ok := muo.mutation.ResourceName(); ok { + if err := memo.ResourceNameValidator(v); err != nil { + return &ValidationError{Name: "resource_name", err: fmt.Errorf(`ent: validator failed for field "Memo.resource_name": %w`, err)} + } + } + if v, ok := muo.mutation.CreatorID(); ok { + if err := memo.CreatorIDValidator(v); err != nil { + return &ValidationError{Name: "creator_id", err: fmt.Errorf(`ent: validator failed for field "Memo.creator_id": %w`, err)} + } + } + if v, ok := muo.mutation.RowStatus(); ok { + if err := memo.RowStatusValidator(v); err != nil { + return &ValidationError{Name: "row_status", err: fmt.Errorf(`ent: validator failed for field "Memo.row_status": %w`, err)} + } + } + if v, ok := muo.mutation.Visibility(); ok { + if err := memo.VisibilityValidator(v); err != nil { + return &ValidationError{Name: "visibility", err: fmt.Errorf(`ent: validator failed for field "Memo.visibility": %w`, err)} + } + } + return nil +} + +func (muo *MemoUpdateOne) sqlSave(ctx context.Context) (_node *Memo, err error) { + if err := muo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(memo.Table, memo.Columns, sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt)) + id, ok := muo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Memo.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := muo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, memo.FieldID) + for _, f := range fields { + if !memo.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != memo.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := muo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := muo.mutation.ResourceName(); ok { + _spec.SetField(memo.FieldResourceName, field.TypeString, value) + } + if value, ok := muo.mutation.CreatorID(); ok { + _spec.SetField(memo.FieldCreatorID, field.TypeInt, value) + } + if value, ok := muo.mutation.AddedCreatorID(); ok { + _spec.AddField(memo.FieldCreatorID, field.TypeInt, value) + } + if value, ok := muo.mutation.CreatedTs(); ok { + _spec.SetField(memo.FieldCreatedTs, field.TypeTime, value) + } + if value, ok := muo.mutation.UpdatedTs(); ok { + _spec.SetField(memo.FieldUpdatedTs, field.TypeTime, value) + } + if value, ok := muo.mutation.RowStatus(); ok { + _spec.SetField(memo.FieldRowStatus, field.TypeString, value) + } + if value, ok := muo.mutation.Content(); ok { + _spec.SetField(memo.FieldContent, field.TypeString, value) + } + if value, ok := muo.mutation.Visibility(); ok { + _spec.SetField(memo.FieldVisibility, field.TypeString, value) + } + if muo.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := muo.mutation.RemovedRelatedMemoIDs(); len(nodes) > 0 && !muo.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := muo.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2M, + Inverse: false, + Table: memo.RelatedMemoTable, + Columns: memo.RelatedMemoPrimaryKey, + Bidi: true, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if muo.mutation.MemoRelationCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := muo.mutation.RemovedMemoRelationIDs(); len(nodes) > 0 && !muo.mutation.MemoRelationCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := muo.mutation.MemoRelationIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: memo.MemoRelationTable, + Columns: []string{memo.MemoRelationColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &Memo{config: muo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, muo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{memo.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + muo.mutation.done = true + return _node, nil +} diff --git a/ent/memorelation.go b/ent/memorelation.go new file mode 100644 index 0000000000000..8b032f2841e9e --- /dev/null +++ b/ent/memorelation.go @@ -0,0 +1,176 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" +) + +// MemoRelation is the model entity for the MemoRelation schema. +type MemoRelation struct { + config `json:"-"` + // ID of the ent. + ID int `json:"id,omitempty"` + // Type holds the value of the "type" field. + Type string `json:"type,omitempty"` + // MemoID holds the value of the "memo_id" field. + MemoID int `json:"memo_id,omitempty"` + // RelatedMemoID holds the value of the "related_memo_id" field. + RelatedMemoID int `json:"related_memo_id,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the MemoRelationQuery when eager-loading is set. + Edges MemoRelationEdges `json:"edges"` + selectValues sql.SelectValues +} + +// MemoRelationEdges holds the relations/edges for other nodes in the graph. +type MemoRelationEdges struct { + // Memo holds the value of the memo edge. + Memo *Memo `json:"memo,omitempty"` + // RelatedMemo holds the value of the related_memo edge. + RelatedMemo *Memo `json:"related_memo,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// MemoOrErr returns the Memo value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e MemoRelationEdges) MemoOrErr() (*Memo, error) { + if e.loadedTypes[0] { + if e.Memo == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: memo.Label} + } + return e.Memo, nil + } + return nil, &NotLoadedError{edge: "memo"} +} + +// RelatedMemoOrErr returns the RelatedMemo value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e MemoRelationEdges) RelatedMemoOrErr() (*Memo, error) { + if e.loadedTypes[1] { + if e.RelatedMemo == nil { + // Edge was loaded but was not found. + return nil, &NotFoundError{label: memo.Label} + } + return e.RelatedMemo, nil + } + return nil, &NotLoadedError{edge: "related_memo"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*MemoRelation) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case memorelation.FieldID, memorelation.FieldMemoID, memorelation.FieldRelatedMemoID: + values[i] = new(sql.NullInt64) + case memorelation.FieldType: + values[i] = new(sql.NullString) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the MemoRelation fields. +func (mr *MemoRelation) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case memorelation.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + mr.ID = int(value.Int64) + case memorelation.FieldType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field type", values[i]) + } else if value.Valid { + mr.Type = value.String + } + case memorelation.FieldMemoID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field memo_id", values[i]) + } else if value.Valid { + mr.MemoID = int(value.Int64) + } + case memorelation.FieldRelatedMemoID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field related_memo_id", values[i]) + } else if value.Valid { + mr.RelatedMemoID = int(value.Int64) + } + default: + mr.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the MemoRelation. +// This includes values selected through modifiers, order, etc. +func (mr *MemoRelation) Value(name string) (ent.Value, error) { + return mr.selectValues.Get(name) +} + +// QueryMemo queries the "memo" edge of the MemoRelation entity. +func (mr *MemoRelation) QueryMemo() *MemoQuery { + return NewMemoRelationClient(mr.config).QueryMemo(mr) +} + +// QueryRelatedMemo queries the "related_memo" edge of the MemoRelation entity. +func (mr *MemoRelation) QueryRelatedMemo() *MemoQuery { + return NewMemoRelationClient(mr.config).QueryRelatedMemo(mr) +} + +// Update returns a builder for updating this MemoRelation. +// Note that you need to call MemoRelation.Unwrap() before calling this method if this MemoRelation +// was returned from a transaction, and the transaction was committed or rolled back. +func (mr *MemoRelation) Update() *MemoRelationUpdateOne { + return NewMemoRelationClient(mr.config).UpdateOne(mr) +} + +// Unwrap unwraps the MemoRelation entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (mr *MemoRelation) Unwrap() *MemoRelation { + _tx, ok := mr.config.driver.(*txDriver) + if !ok { + panic("ent: MemoRelation is not a transactional entity") + } + mr.config.driver = _tx.drv + return mr +} + +// String implements the fmt.Stringer. +func (mr *MemoRelation) String() string { + var builder strings.Builder + builder.WriteString("MemoRelation(") + builder.WriteString(fmt.Sprintf("id=%v, ", mr.ID)) + builder.WriteString("type=") + builder.WriteString(mr.Type) + builder.WriteString(", ") + builder.WriteString("memo_id=") + builder.WriteString(fmt.Sprintf("%v", mr.MemoID)) + builder.WriteString(", ") + builder.WriteString("related_memo_id=") + builder.WriteString(fmt.Sprintf("%v", mr.RelatedMemoID)) + builder.WriteByte(')') + return builder.String() +} + +// MemoRelations is a parsable slice of MemoRelation. +type MemoRelations []*MemoRelation diff --git a/ent/memorelation/memorelation.go b/ent/memorelation/memorelation.go new file mode 100644 index 0000000000000..924cf0d499356 --- /dev/null +++ b/ent/memorelation/memorelation.go @@ -0,0 +1,110 @@ +// Code generated by ent, DO NOT EDIT. + +package memorelation + +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the memorelation type in the database. + Label = "memo_relation" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldType holds the string denoting the type field in the database. + FieldType = "type" + // FieldMemoID holds the string denoting the memo_id field in the database. + FieldMemoID = "memo_id" + // FieldRelatedMemoID holds the string denoting the related_memo_id field in the database. + FieldRelatedMemoID = "related_memo_id" + // EdgeMemo holds the string denoting the memo edge name in mutations. + EdgeMemo = "memo" + // EdgeRelatedMemo holds the string denoting the related_memo edge name in mutations. + EdgeRelatedMemo = "related_memo" + // Table holds the table name of the memorelation in the database. + Table = "memo_relations" + // MemoTable is the table that holds the memo relation/edge. + MemoTable = "memo_relations" + // MemoInverseTable is the table name for the Memo entity. + // It exists in this package in order to avoid circular dependency with the "memo" package. + MemoInverseTable = "memos" + // MemoColumn is the table column denoting the memo relation/edge. + MemoColumn = "memo_id" + // RelatedMemoTable is the table that holds the related_memo relation/edge. + RelatedMemoTable = "memo_relations" + // RelatedMemoInverseTable is the table name for the Memo entity. + // It exists in this package in order to avoid circular dependency with the "memo" package. + RelatedMemoInverseTable = "memos" + // RelatedMemoColumn is the table column denoting the related_memo relation/edge. + RelatedMemoColumn = "related_memo_id" +) + +// Columns holds all SQL columns for memorelation fields. +var Columns = []string{ + FieldID, + FieldType, + FieldMemoID, + FieldRelatedMemoID, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// OrderOption defines the ordering options for the MemoRelation queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByMemoID orders the results by the memo_id field. +func ByMemoID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMemoID, opts...).ToFunc() +} + +// ByRelatedMemoID orders the results by the related_memo_id field. +func ByRelatedMemoID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRelatedMemoID, opts...).ToFunc() +} + +// ByMemoField orders the results by memo field. +func ByMemoField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMemoStep(), sql.OrderByField(field, opts...)) + } +} + +// ByRelatedMemoField orders the results by related_memo field. +func ByRelatedMemoField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newRelatedMemoStep(), sql.OrderByField(field, opts...)) + } +} +func newMemoStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MemoInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, MemoTable, MemoColumn), + ) +} +func newRelatedMemoStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(RelatedMemoInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, RelatedMemoTable, RelatedMemoColumn), + ) +} diff --git a/ent/memorelation/where.go b/ent/memorelation/where.go new file mode 100644 index 0000000000000..ec63d2ff8e072 --- /dev/null +++ b/ent/memorelation/where.go @@ -0,0 +1,235 @@ +// Code generated by ent, DO NOT EDIT. + +package memorelation + +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/usememos/memos/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldLTE(FieldID, id)) +} + +// Type applies equality check predicate on the "type" field. It's identical to TypeEQ. +func Type(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldType, v)) +} + +// MemoID applies equality check predicate on the "memo_id" field. It's identical to MemoIDEQ. +func MemoID(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldMemoID, v)) +} + +// RelatedMemoID applies equality check predicate on the "related_memo_id" field. It's identical to RelatedMemoIDEQ. +func RelatedMemoID(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldRelatedMemoID, v)) +} + +// TypeEQ applies the EQ predicate on the "type" field. +func TypeEQ(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldType, v)) +} + +// TypeNEQ applies the NEQ predicate on the "type" field. +func TypeNEQ(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNEQ(FieldType, v)) +} + +// TypeIn applies the In predicate on the "type" field. +func TypeIn(vs ...string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldIn(FieldType, vs...)) +} + +// TypeNotIn applies the NotIn predicate on the "type" field. +func TypeNotIn(vs ...string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNotIn(FieldType, vs...)) +} + +// TypeGT applies the GT predicate on the "type" field. +func TypeGT(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldGT(FieldType, v)) +} + +// TypeGTE applies the GTE predicate on the "type" field. +func TypeGTE(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldGTE(FieldType, v)) +} + +// TypeLT applies the LT predicate on the "type" field. +func TypeLT(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldLT(FieldType, v)) +} + +// TypeLTE applies the LTE predicate on the "type" field. +func TypeLTE(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldLTE(FieldType, v)) +} + +// TypeContains applies the Contains predicate on the "type" field. +func TypeContains(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldContains(FieldType, v)) +} + +// TypeHasPrefix applies the HasPrefix predicate on the "type" field. +func TypeHasPrefix(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldHasPrefix(FieldType, v)) +} + +// TypeHasSuffix applies the HasSuffix predicate on the "type" field. +func TypeHasSuffix(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldHasSuffix(FieldType, v)) +} + +// TypeEqualFold applies the EqualFold predicate on the "type" field. +func TypeEqualFold(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEqualFold(FieldType, v)) +} + +// TypeContainsFold applies the ContainsFold predicate on the "type" field. +func TypeContainsFold(v string) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldContainsFold(FieldType, v)) +} + +// MemoIDEQ applies the EQ predicate on the "memo_id" field. +func MemoIDEQ(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldMemoID, v)) +} + +// MemoIDNEQ applies the NEQ predicate on the "memo_id" field. +func MemoIDNEQ(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNEQ(FieldMemoID, v)) +} + +// MemoIDIn applies the In predicate on the "memo_id" field. +func MemoIDIn(vs ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldIn(FieldMemoID, vs...)) +} + +// MemoIDNotIn applies the NotIn predicate on the "memo_id" field. +func MemoIDNotIn(vs ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNotIn(FieldMemoID, vs...)) +} + +// RelatedMemoIDEQ applies the EQ predicate on the "related_memo_id" field. +func RelatedMemoIDEQ(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldEQ(FieldRelatedMemoID, v)) +} + +// RelatedMemoIDNEQ applies the NEQ predicate on the "related_memo_id" field. +func RelatedMemoIDNEQ(v int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNEQ(FieldRelatedMemoID, v)) +} + +// RelatedMemoIDIn applies the In predicate on the "related_memo_id" field. +func RelatedMemoIDIn(vs ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldIn(FieldRelatedMemoID, vs...)) +} + +// RelatedMemoIDNotIn applies the NotIn predicate on the "related_memo_id" field. +func RelatedMemoIDNotIn(vs ...int) predicate.MemoRelation { + return predicate.MemoRelation(sql.FieldNotIn(FieldRelatedMemoID, vs...)) +} + +// HasMemo applies the HasEdge predicate on the "memo" edge. +func HasMemo() predicate.MemoRelation { + return predicate.MemoRelation(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, MemoTable, MemoColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMemoWith applies the HasEdge predicate on the "memo" edge with a given conditions (other predicates). +func HasMemoWith(preds ...predicate.Memo) predicate.MemoRelation { + return predicate.MemoRelation(func(s *sql.Selector) { + step := newMemoStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasRelatedMemo applies the HasEdge predicate on the "related_memo" edge. +func HasRelatedMemo() predicate.MemoRelation { + return predicate.MemoRelation(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, RelatedMemoTable, RelatedMemoColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasRelatedMemoWith applies the HasEdge predicate on the "related_memo" edge with a given conditions (other predicates). +func HasRelatedMemoWith(preds ...predicate.Memo) predicate.MemoRelation { + return predicate.MemoRelation(func(s *sql.Selector) { + step := newRelatedMemoStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.MemoRelation) predicate.MemoRelation { + return predicate.MemoRelation(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.MemoRelation) predicate.MemoRelation { + return predicate.MemoRelation(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.MemoRelation) predicate.MemoRelation { + return predicate.MemoRelation(sql.NotPredicates(p)) +} diff --git a/ent/memorelation_create.go b/ent/memorelation_create.go new file mode 100644 index 0000000000000..b262da89e9b4e --- /dev/null +++ b/ent/memorelation_create.go @@ -0,0 +1,252 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" +) + +// MemoRelationCreate is the builder for creating a MemoRelation entity. +type MemoRelationCreate struct { + config + mutation *MemoRelationMutation + hooks []Hook +} + +// SetType sets the "type" field. +func (mrc *MemoRelationCreate) SetType(s string) *MemoRelationCreate { + mrc.mutation.SetType(s) + return mrc +} + +// SetMemoID sets the "memo_id" field. +func (mrc *MemoRelationCreate) SetMemoID(i int) *MemoRelationCreate { + mrc.mutation.SetMemoID(i) + return mrc +} + +// SetRelatedMemoID sets the "related_memo_id" field. +func (mrc *MemoRelationCreate) SetRelatedMemoID(i int) *MemoRelationCreate { + mrc.mutation.SetRelatedMemoID(i) + return mrc +} + +// SetMemo sets the "memo" edge to the Memo entity. +func (mrc *MemoRelationCreate) SetMemo(m *Memo) *MemoRelationCreate { + return mrc.SetMemoID(m.ID) +} + +// SetRelatedMemo sets the "related_memo" edge to the Memo entity. +func (mrc *MemoRelationCreate) SetRelatedMemo(m *Memo) *MemoRelationCreate { + return mrc.SetRelatedMemoID(m.ID) +} + +// Mutation returns the MemoRelationMutation object of the builder. +func (mrc *MemoRelationCreate) Mutation() *MemoRelationMutation { + return mrc.mutation +} + +// Save creates the MemoRelation in the database. +func (mrc *MemoRelationCreate) Save(ctx context.Context) (*MemoRelation, error) { + return withHooks(ctx, mrc.sqlSave, mrc.mutation, mrc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (mrc *MemoRelationCreate) SaveX(ctx context.Context) *MemoRelation { + v, err := mrc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mrc *MemoRelationCreate) Exec(ctx context.Context) error { + _, err := mrc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mrc *MemoRelationCreate) ExecX(ctx context.Context) { + if err := mrc.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mrc *MemoRelationCreate) check() error { + if _, ok := mrc.mutation.GetType(); !ok { + return &ValidationError{Name: "type", err: errors.New(`ent: missing required field "MemoRelation.type"`)} + } + if _, ok := mrc.mutation.MemoID(); !ok { + return &ValidationError{Name: "memo_id", err: errors.New(`ent: missing required field "MemoRelation.memo_id"`)} + } + if _, ok := mrc.mutation.RelatedMemoID(); !ok { + return &ValidationError{Name: "related_memo_id", err: errors.New(`ent: missing required field "MemoRelation.related_memo_id"`)} + } + if _, ok := mrc.mutation.MemoID(); !ok { + return &ValidationError{Name: "memo", err: errors.New(`ent: missing required edge "MemoRelation.memo"`)} + } + if _, ok := mrc.mutation.RelatedMemoID(); !ok { + return &ValidationError{Name: "related_memo", err: errors.New(`ent: missing required edge "MemoRelation.related_memo"`)} + } + return nil +} + +func (mrc *MemoRelationCreate) sqlSave(ctx context.Context) (*MemoRelation, error) { + if err := mrc.check(); err != nil { + return nil, err + } + _node, _spec := mrc.createSpec() + if err := sqlgraph.CreateNode(ctx, mrc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int(id) + mrc.mutation.id = &_node.ID + mrc.mutation.done = true + return _node, nil +} + +func (mrc *MemoRelationCreate) createSpec() (*MemoRelation, *sqlgraph.CreateSpec) { + var ( + _node = &MemoRelation{config: mrc.config} + _spec = sqlgraph.NewCreateSpec(memorelation.Table, sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt)) + ) + if value, ok := mrc.mutation.GetType(); ok { + _spec.SetField(memorelation.FieldType, field.TypeString, value) + _node.Type = value + } + if nodes := mrc.mutation.MemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.MemoTable, + Columns: []string{memorelation.MemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.MemoID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := mrc.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.RelatedMemoTable, + Columns: []string{memorelation.RelatedMemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.RelatedMemoID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// MemoRelationCreateBulk is the builder for creating many MemoRelation entities in bulk. +type MemoRelationCreateBulk struct { + config + err error + builders []*MemoRelationCreate +} + +// Save creates the MemoRelation entities in the database. +func (mrcb *MemoRelationCreateBulk) Save(ctx context.Context) ([]*MemoRelation, error) { + if mrcb.err != nil { + return nil, mrcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(mrcb.builders)) + nodes := make([]*MemoRelation, len(mrcb.builders)) + mutators := make([]Mutator, len(mrcb.builders)) + for i := range mrcb.builders { + func(i int, root context.Context) { + builder := mrcb.builders[i] + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*MemoRelationMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, mrcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, mrcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, mrcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (mrcb *MemoRelationCreateBulk) SaveX(ctx context.Context) []*MemoRelation { + v, err := mrcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (mrcb *MemoRelationCreateBulk) Exec(ctx context.Context) error { + _, err := mrcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mrcb *MemoRelationCreateBulk) ExecX(ctx context.Context) { + if err := mrcb.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/memorelation_delete.go b/ent/memorelation_delete.go new file mode 100644 index 0000000000000..31e8c97c5acb5 --- /dev/null +++ b/ent/memorelation_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +// MemoRelationDelete is the builder for deleting a MemoRelation entity. +type MemoRelationDelete struct { + config + hooks []Hook + mutation *MemoRelationMutation +} + +// Where appends a list predicates to the MemoRelationDelete builder. +func (mrd *MemoRelationDelete) Where(ps ...predicate.MemoRelation) *MemoRelationDelete { + mrd.mutation.Where(ps...) + return mrd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (mrd *MemoRelationDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, mrd.sqlExec, mrd.mutation, mrd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (mrd *MemoRelationDelete) ExecX(ctx context.Context) int { + n, err := mrd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (mrd *MemoRelationDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(memorelation.Table, sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt)) + if ps := mrd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, mrd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + mrd.mutation.done = true + return affected, err +} + +// MemoRelationDeleteOne is the builder for deleting a single MemoRelation entity. +type MemoRelationDeleteOne struct { + mrd *MemoRelationDelete +} + +// Where appends a list predicates to the MemoRelationDelete builder. +func (mrdo *MemoRelationDeleteOne) Where(ps ...predicate.MemoRelation) *MemoRelationDeleteOne { + mrdo.mrd.mutation.Where(ps...) + return mrdo +} + +// Exec executes the deletion query. +func (mrdo *MemoRelationDeleteOne) Exec(ctx context.Context) error { + n, err := mrdo.mrd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{memorelation.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (mrdo *MemoRelationDeleteOne) ExecX(ctx context.Context) { + if err := mrdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/memorelation_query.go b/ent/memorelation_query.go new file mode 100644 index 0000000000000..bdd334f985aa8 --- /dev/null +++ b/ent/memorelation_query.go @@ -0,0 +1,679 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +// MemoRelationQuery is the builder for querying MemoRelation entities. +type MemoRelationQuery struct { + config + ctx *QueryContext + order []memorelation.OrderOption + inters []Interceptor + predicates []predicate.MemoRelation + withMemo *MemoQuery + withRelatedMemo *MemoQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the MemoRelationQuery builder. +func (mrq *MemoRelationQuery) Where(ps ...predicate.MemoRelation) *MemoRelationQuery { + mrq.predicates = append(mrq.predicates, ps...) + return mrq +} + +// Limit the number of records to be returned by this query. +func (mrq *MemoRelationQuery) Limit(limit int) *MemoRelationQuery { + mrq.ctx.Limit = &limit + return mrq +} + +// Offset to start from. +func (mrq *MemoRelationQuery) Offset(offset int) *MemoRelationQuery { + mrq.ctx.Offset = &offset + return mrq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (mrq *MemoRelationQuery) Unique(unique bool) *MemoRelationQuery { + mrq.ctx.Unique = &unique + return mrq +} + +// Order specifies how the records should be ordered. +func (mrq *MemoRelationQuery) Order(o ...memorelation.OrderOption) *MemoRelationQuery { + mrq.order = append(mrq.order, o...) + return mrq +} + +// QueryMemo chains the current query on the "memo" edge. +func (mrq *MemoRelationQuery) QueryMemo() *MemoQuery { + query := (&MemoClient{config: mrq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := mrq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := mrq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(memorelation.Table, memorelation.FieldID, selector), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, memorelation.MemoTable, memorelation.MemoColumn), + ) + fromU = sqlgraph.SetNeighbors(mrq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryRelatedMemo chains the current query on the "related_memo" edge. +func (mrq *MemoRelationQuery) QueryRelatedMemo() *MemoQuery { + query := (&MemoClient{config: mrq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := mrq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := mrq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(memorelation.Table, memorelation.FieldID, selector), + sqlgraph.To(memo.Table, memo.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, memorelation.RelatedMemoTable, memorelation.RelatedMemoColumn), + ) + fromU = sqlgraph.SetNeighbors(mrq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first MemoRelation entity from the query. +// Returns a *NotFoundError when no MemoRelation was found. +func (mrq *MemoRelationQuery) First(ctx context.Context) (*MemoRelation, error) { + nodes, err := mrq.Limit(1).All(setContextOp(ctx, mrq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{memorelation.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (mrq *MemoRelationQuery) FirstX(ctx context.Context) *MemoRelation { + node, err := mrq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first MemoRelation ID from the query. +// Returns a *NotFoundError when no MemoRelation ID was found. +func (mrq *MemoRelationQuery) FirstID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mrq.Limit(1).IDs(setContextOp(ctx, mrq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{memorelation.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (mrq *MemoRelationQuery) FirstIDX(ctx context.Context) int { + id, err := mrq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single MemoRelation entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one MemoRelation entity is found. +// Returns a *NotFoundError when no MemoRelation entities are found. +func (mrq *MemoRelationQuery) Only(ctx context.Context) (*MemoRelation, error) { + nodes, err := mrq.Limit(2).All(setContextOp(ctx, mrq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{memorelation.Label} + default: + return nil, &NotSingularError{memorelation.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (mrq *MemoRelationQuery) OnlyX(ctx context.Context) *MemoRelation { + node, err := mrq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only MemoRelation ID in the query. +// Returns a *NotSingularError when more than one MemoRelation ID is found. +// Returns a *NotFoundError when no entities are found. +func (mrq *MemoRelationQuery) OnlyID(ctx context.Context) (id int, err error) { + var ids []int + if ids, err = mrq.Limit(2).IDs(setContextOp(ctx, mrq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{memorelation.Label} + default: + err = &NotSingularError{memorelation.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (mrq *MemoRelationQuery) OnlyIDX(ctx context.Context) int { + id, err := mrq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of MemoRelations. +func (mrq *MemoRelationQuery) All(ctx context.Context) ([]*MemoRelation, error) { + ctx = setContextOp(ctx, mrq.ctx, "All") + if err := mrq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*MemoRelation, *MemoRelationQuery]() + return withInterceptors[[]*MemoRelation](ctx, mrq, qr, mrq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (mrq *MemoRelationQuery) AllX(ctx context.Context) []*MemoRelation { + nodes, err := mrq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of MemoRelation IDs. +func (mrq *MemoRelationQuery) IDs(ctx context.Context) (ids []int, err error) { + if mrq.ctx.Unique == nil && mrq.path != nil { + mrq.Unique(true) + } + ctx = setContextOp(ctx, mrq.ctx, "IDs") + if err = mrq.Select(memorelation.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (mrq *MemoRelationQuery) IDsX(ctx context.Context) []int { + ids, err := mrq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (mrq *MemoRelationQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, mrq.ctx, "Count") + if err := mrq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, mrq, querierCount[*MemoRelationQuery](), mrq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (mrq *MemoRelationQuery) CountX(ctx context.Context) int { + count, err := mrq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (mrq *MemoRelationQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, mrq.ctx, "Exist") + switch _, err := mrq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (mrq *MemoRelationQuery) ExistX(ctx context.Context) bool { + exist, err := mrq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the MemoRelationQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (mrq *MemoRelationQuery) Clone() *MemoRelationQuery { + if mrq == nil { + return nil + } + return &MemoRelationQuery{ + config: mrq.config, + ctx: mrq.ctx.Clone(), + order: append([]memorelation.OrderOption{}, mrq.order...), + inters: append([]Interceptor{}, mrq.inters...), + predicates: append([]predicate.MemoRelation{}, mrq.predicates...), + withMemo: mrq.withMemo.Clone(), + withRelatedMemo: mrq.withRelatedMemo.Clone(), + // clone intermediate query. + sql: mrq.sql.Clone(), + path: mrq.path, + } +} + +// WithMemo tells the query-builder to eager-load the nodes that are connected to +// the "memo" edge. The optional arguments are used to configure the query builder of the edge. +func (mrq *MemoRelationQuery) WithMemo(opts ...func(*MemoQuery)) *MemoRelationQuery { + query := (&MemoClient{config: mrq.config}).Query() + for _, opt := range opts { + opt(query) + } + mrq.withMemo = query + return mrq +} + +// WithRelatedMemo tells the query-builder to eager-load the nodes that are connected to +// the "related_memo" edge. The optional arguments are used to configure the query builder of the edge. +func (mrq *MemoRelationQuery) WithRelatedMemo(opts ...func(*MemoQuery)) *MemoRelationQuery { + query := (&MemoClient{config: mrq.config}).Query() + for _, opt := range opts { + opt(query) + } + mrq.withRelatedMemo = query + return mrq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// Type string `json:"type,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.MemoRelation.Query(). +// GroupBy(memorelation.FieldType). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (mrq *MemoRelationQuery) GroupBy(field string, fields ...string) *MemoRelationGroupBy { + mrq.ctx.Fields = append([]string{field}, fields...) + grbuild := &MemoRelationGroupBy{build: mrq} + grbuild.flds = &mrq.ctx.Fields + grbuild.label = memorelation.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// Type string `json:"type,omitempty"` +// } +// +// client.MemoRelation.Query(). +// Select(memorelation.FieldType). +// Scan(ctx, &v) +func (mrq *MemoRelationQuery) Select(fields ...string) *MemoRelationSelect { + mrq.ctx.Fields = append(mrq.ctx.Fields, fields...) + sbuild := &MemoRelationSelect{MemoRelationQuery: mrq} + sbuild.label = memorelation.Label + sbuild.flds, sbuild.scan = &mrq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a MemoRelationSelect configured with the given aggregations. +func (mrq *MemoRelationQuery) Aggregate(fns ...AggregateFunc) *MemoRelationSelect { + return mrq.Select().Aggregate(fns...) +} + +func (mrq *MemoRelationQuery) prepareQuery(ctx context.Context) error { + for _, inter := range mrq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, mrq); err != nil { + return err + } + } + } + for _, f := range mrq.ctx.Fields { + if !memorelation.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if mrq.path != nil { + prev, err := mrq.path(ctx) + if err != nil { + return err + } + mrq.sql = prev + } + return nil +} + +func (mrq *MemoRelationQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*MemoRelation, error) { + var ( + nodes = []*MemoRelation{} + _spec = mrq.querySpec() + loadedTypes = [2]bool{ + mrq.withMemo != nil, + mrq.withRelatedMemo != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*MemoRelation).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &MemoRelation{config: mrq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, mrq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := mrq.withMemo; query != nil { + if err := mrq.loadMemo(ctx, query, nodes, nil, + func(n *MemoRelation, e *Memo) { n.Edges.Memo = e }); err != nil { + return nil, err + } + } + if query := mrq.withRelatedMemo; query != nil { + if err := mrq.loadRelatedMemo(ctx, query, nodes, nil, + func(n *MemoRelation, e *Memo) { n.Edges.RelatedMemo = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (mrq *MemoRelationQuery) loadMemo(ctx context.Context, query *MemoQuery, nodes []*MemoRelation, init func(*MemoRelation), assign func(*MemoRelation, *Memo)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*MemoRelation) + for i := range nodes { + fk := nodes[i].MemoID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(memo.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "memo_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (mrq *MemoRelationQuery) loadRelatedMemo(ctx context.Context, query *MemoQuery, nodes []*MemoRelation, init func(*MemoRelation), assign func(*MemoRelation, *Memo)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*MemoRelation) + for i := range nodes { + fk := nodes[i].RelatedMemoID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(memo.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "related_memo_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (mrq *MemoRelationQuery) sqlCount(ctx context.Context) (int, error) { + _spec := mrq.querySpec() + _spec.Node.Columns = mrq.ctx.Fields + if len(mrq.ctx.Fields) > 0 { + _spec.Unique = mrq.ctx.Unique != nil && *mrq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, mrq.driver, _spec) +} + +func (mrq *MemoRelationQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(memorelation.Table, memorelation.Columns, sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt)) + _spec.From = mrq.sql + if unique := mrq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if mrq.path != nil { + _spec.Unique = true + } + if fields := mrq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, memorelation.FieldID) + for i := range fields { + if fields[i] != memorelation.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if mrq.withMemo != nil { + _spec.Node.AddColumnOnce(memorelation.FieldMemoID) + } + if mrq.withRelatedMemo != nil { + _spec.Node.AddColumnOnce(memorelation.FieldRelatedMemoID) + } + } + if ps := mrq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := mrq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := mrq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := mrq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (mrq *MemoRelationQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(mrq.driver.Dialect()) + t1 := builder.Table(memorelation.Table) + columns := mrq.ctx.Fields + if len(columns) == 0 { + columns = memorelation.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if mrq.sql != nil { + selector = mrq.sql + selector.Select(selector.Columns(columns...)...) + } + if mrq.ctx.Unique != nil && *mrq.ctx.Unique { + selector.Distinct() + } + for _, p := range mrq.predicates { + p(selector) + } + for _, p := range mrq.order { + p(selector) + } + if offset := mrq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := mrq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// MemoRelationGroupBy is the group-by builder for MemoRelation entities. +type MemoRelationGroupBy struct { + selector + build *MemoRelationQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (mrgb *MemoRelationGroupBy) Aggregate(fns ...AggregateFunc) *MemoRelationGroupBy { + mrgb.fns = append(mrgb.fns, fns...) + return mrgb +} + +// Scan applies the selector query and scans the result into the given value. +func (mrgb *MemoRelationGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mrgb.build.ctx, "GroupBy") + if err := mrgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MemoRelationQuery, *MemoRelationGroupBy](ctx, mrgb.build, mrgb, mrgb.build.inters, v) +} + +func (mrgb *MemoRelationGroupBy) sqlScan(ctx context.Context, root *MemoRelationQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(mrgb.fns)) + for _, fn := range mrgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*mrgb.flds)+len(mrgb.fns)) + for _, f := range *mrgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*mrgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mrgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// MemoRelationSelect is the builder for selecting fields of MemoRelation entities. +type MemoRelationSelect struct { + *MemoRelationQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (mrs *MemoRelationSelect) Aggregate(fns ...AggregateFunc) *MemoRelationSelect { + mrs.fns = append(mrs.fns, fns...) + return mrs +} + +// Scan applies the selector query and scans the result into the given value. +func (mrs *MemoRelationSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, mrs.ctx, "Select") + if err := mrs.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*MemoRelationQuery, *MemoRelationSelect](ctx, mrs.MemoRelationQuery, mrs, mrs.inters, v) +} + +func (mrs *MemoRelationSelect) sqlScan(ctx context.Context, root *MemoRelationQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(mrs.fns)) + for _, fn := range mrs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*mrs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := mrs.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/ent/memorelation_update.go b/ent/memorelation_update.go new file mode 100644 index 0000000000000..bf3cc93907934 --- /dev/null +++ b/ent/memorelation_update.go @@ -0,0 +1,454 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +// MemoRelationUpdate is the builder for updating MemoRelation entities. +type MemoRelationUpdate struct { + config + hooks []Hook + mutation *MemoRelationMutation +} + +// Where appends a list predicates to the MemoRelationUpdate builder. +func (mru *MemoRelationUpdate) Where(ps ...predicate.MemoRelation) *MemoRelationUpdate { + mru.mutation.Where(ps...) + return mru +} + +// SetType sets the "type" field. +func (mru *MemoRelationUpdate) SetType(s string) *MemoRelationUpdate { + mru.mutation.SetType(s) + return mru +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (mru *MemoRelationUpdate) SetNillableType(s *string) *MemoRelationUpdate { + if s != nil { + mru.SetType(*s) + } + return mru +} + +// SetMemoID sets the "memo_id" field. +func (mru *MemoRelationUpdate) SetMemoID(i int) *MemoRelationUpdate { + mru.mutation.SetMemoID(i) + return mru +} + +// SetNillableMemoID sets the "memo_id" field if the given value is not nil. +func (mru *MemoRelationUpdate) SetNillableMemoID(i *int) *MemoRelationUpdate { + if i != nil { + mru.SetMemoID(*i) + } + return mru +} + +// SetRelatedMemoID sets the "related_memo_id" field. +func (mru *MemoRelationUpdate) SetRelatedMemoID(i int) *MemoRelationUpdate { + mru.mutation.SetRelatedMemoID(i) + return mru +} + +// SetNillableRelatedMemoID sets the "related_memo_id" field if the given value is not nil. +func (mru *MemoRelationUpdate) SetNillableRelatedMemoID(i *int) *MemoRelationUpdate { + if i != nil { + mru.SetRelatedMemoID(*i) + } + return mru +} + +// SetMemo sets the "memo" edge to the Memo entity. +func (mru *MemoRelationUpdate) SetMemo(m *Memo) *MemoRelationUpdate { + return mru.SetMemoID(m.ID) +} + +// SetRelatedMemo sets the "related_memo" edge to the Memo entity. +func (mru *MemoRelationUpdate) SetRelatedMemo(m *Memo) *MemoRelationUpdate { + return mru.SetRelatedMemoID(m.ID) +} + +// Mutation returns the MemoRelationMutation object of the builder. +func (mru *MemoRelationUpdate) Mutation() *MemoRelationMutation { + return mru.mutation +} + +// ClearMemo clears the "memo" edge to the Memo entity. +func (mru *MemoRelationUpdate) ClearMemo() *MemoRelationUpdate { + mru.mutation.ClearMemo() + return mru +} + +// ClearRelatedMemo clears the "related_memo" edge to the Memo entity. +func (mru *MemoRelationUpdate) ClearRelatedMemo() *MemoRelationUpdate { + mru.mutation.ClearRelatedMemo() + return mru +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (mru *MemoRelationUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, mru.sqlSave, mru.mutation, mru.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mru *MemoRelationUpdate) SaveX(ctx context.Context) int { + affected, err := mru.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (mru *MemoRelationUpdate) Exec(ctx context.Context) error { + _, err := mru.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mru *MemoRelationUpdate) ExecX(ctx context.Context) { + if err := mru.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mru *MemoRelationUpdate) check() error { + if _, ok := mru.mutation.MemoID(); mru.mutation.MemoCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "MemoRelation.memo"`) + } + if _, ok := mru.mutation.RelatedMemoID(); mru.mutation.RelatedMemoCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "MemoRelation.related_memo"`) + } + return nil +} + +func (mru *MemoRelationUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := mru.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(memorelation.Table, memorelation.Columns, sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt)) + if ps := mru.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mru.mutation.GetType(); ok { + _spec.SetField(memorelation.FieldType, field.TypeString, value) + } + if mru.mutation.MemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.MemoTable, + Columns: []string{memorelation.MemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mru.mutation.MemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.MemoTable, + Columns: []string{memorelation.MemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if mru.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.RelatedMemoTable, + Columns: []string{memorelation.RelatedMemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mru.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.RelatedMemoTable, + Columns: []string{memorelation.RelatedMemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if n, err = sqlgraph.UpdateNodes(ctx, mru.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{memorelation.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + mru.mutation.done = true + return n, nil +} + +// MemoRelationUpdateOne is the builder for updating a single MemoRelation entity. +type MemoRelationUpdateOne struct { + config + fields []string + hooks []Hook + mutation *MemoRelationMutation +} + +// SetType sets the "type" field. +func (mruo *MemoRelationUpdateOne) SetType(s string) *MemoRelationUpdateOne { + mruo.mutation.SetType(s) + return mruo +} + +// SetNillableType sets the "type" field if the given value is not nil. +func (mruo *MemoRelationUpdateOne) SetNillableType(s *string) *MemoRelationUpdateOne { + if s != nil { + mruo.SetType(*s) + } + return mruo +} + +// SetMemoID sets the "memo_id" field. +func (mruo *MemoRelationUpdateOne) SetMemoID(i int) *MemoRelationUpdateOne { + mruo.mutation.SetMemoID(i) + return mruo +} + +// SetNillableMemoID sets the "memo_id" field if the given value is not nil. +func (mruo *MemoRelationUpdateOne) SetNillableMemoID(i *int) *MemoRelationUpdateOne { + if i != nil { + mruo.SetMemoID(*i) + } + return mruo +} + +// SetRelatedMemoID sets the "related_memo_id" field. +func (mruo *MemoRelationUpdateOne) SetRelatedMemoID(i int) *MemoRelationUpdateOne { + mruo.mutation.SetRelatedMemoID(i) + return mruo +} + +// SetNillableRelatedMemoID sets the "related_memo_id" field if the given value is not nil. +func (mruo *MemoRelationUpdateOne) SetNillableRelatedMemoID(i *int) *MemoRelationUpdateOne { + if i != nil { + mruo.SetRelatedMemoID(*i) + } + return mruo +} + +// SetMemo sets the "memo" edge to the Memo entity. +func (mruo *MemoRelationUpdateOne) SetMemo(m *Memo) *MemoRelationUpdateOne { + return mruo.SetMemoID(m.ID) +} + +// SetRelatedMemo sets the "related_memo" edge to the Memo entity. +func (mruo *MemoRelationUpdateOne) SetRelatedMemo(m *Memo) *MemoRelationUpdateOne { + return mruo.SetRelatedMemoID(m.ID) +} + +// Mutation returns the MemoRelationMutation object of the builder. +func (mruo *MemoRelationUpdateOne) Mutation() *MemoRelationMutation { + return mruo.mutation +} + +// ClearMemo clears the "memo" edge to the Memo entity. +func (mruo *MemoRelationUpdateOne) ClearMemo() *MemoRelationUpdateOne { + mruo.mutation.ClearMemo() + return mruo +} + +// ClearRelatedMemo clears the "related_memo" edge to the Memo entity. +func (mruo *MemoRelationUpdateOne) ClearRelatedMemo() *MemoRelationUpdateOne { + mruo.mutation.ClearRelatedMemo() + return mruo +} + +// Where appends a list predicates to the MemoRelationUpdate builder. +func (mruo *MemoRelationUpdateOne) Where(ps ...predicate.MemoRelation) *MemoRelationUpdateOne { + mruo.mutation.Where(ps...) + return mruo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (mruo *MemoRelationUpdateOne) Select(field string, fields ...string) *MemoRelationUpdateOne { + mruo.fields = append([]string{field}, fields...) + return mruo +} + +// Save executes the query and returns the updated MemoRelation entity. +func (mruo *MemoRelationUpdateOne) Save(ctx context.Context) (*MemoRelation, error) { + return withHooks(ctx, mruo.sqlSave, mruo.mutation, mruo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (mruo *MemoRelationUpdateOne) SaveX(ctx context.Context) *MemoRelation { + node, err := mruo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (mruo *MemoRelationUpdateOne) Exec(ctx context.Context) error { + _, err := mruo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (mruo *MemoRelationUpdateOne) ExecX(ctx context.Context) { + if err := mruo.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (mruo *MemoRelationUpdateOne) check() error { + if _, ok := mruo.mutation.MemoID(); mruo.mutation.MemoCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "MemoRelation.memo"`) + } + if _, ok := mruo.mutation.RelatedMemoID(); mruo.mutation.RelatedMemoCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "MemoRelation.related_memo"`) + } + return nil +} + +func (mruo *MemoRelationUpdateOne) sqlSave(ctx context.Context) (_node *MemoRelation, err error) { + if err := mruo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(memorelation.Table, memorelation.Columns, sqlgraph.NewFieldSpec(memorelation.FieldID, field.TypeInt)) + id, ok := mruo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "MemoRelation.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := mruo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, memorelation.FieldID) + for _, f := range fields { + if !memorelation.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != memorelation.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := mruo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := mruo.mutation.GetType(); ok { + _spec.SetField(memorelation.FieldType, field.TypeString, value) + } + if mruo.mutation.MemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.MemoTable, + Columns: []string{memorelation.MemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mruo.mutation.MemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.MemoTable, + Columns: []string{memorelation.MemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if mruo.mutation.RelatedMemoCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.RelatedMemoTable, + Columns: []string{memorelation.RelatedMemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := mruo.mutation.RelatedMemoIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: memorelation.RelatedMemoTable, + Columns: []string{memorelation.RelatedMemoColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(memo.FieldID, field.TypeInt), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &MemoRelation{config: mruo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, mruo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{memorelation.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + mruo.mutation.done = true + return _node, nil +} diff --git a/ent/migrate/migrate.go b/ent/migrate/migrate.go new file mode 100644 index 0000000000000..1956a6bf6437c --- /dev/null +++ b/ent/migrate/migrate.go @@ -0,0 +1,64 @@ +// Code generated by ent, DO NOT EDIT. + +package migrate + +import ( + "context" + "fmt" + "io" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql/schema" +) + +var ( + // WithGlobalUniqueID sets the universal ids options to the migration. + // If this option is enabled, ent migration will allocate a 1<<32 range + // for the ids of each entity (table). + // Note that this option cannot be applied on tables that already exist. + WithGlobalUniqueID = schema.WithGlobalUniqueID + // WithDropColumn sets the drop column option to the migration. + // If this option is enabled, ent migration will drop old columns + // that were used for both fields and edges. This defaults to false. + WithDropColumn = schema.WithDropColumn + // WithDropIndex sets the drop index option to the migration. + // If this option is enabled, ent migration will drop old indexes + // that were defined in the schema. This defaults to false. + // Note that unique constraints are defined using `UNIQUE INDEX`, + // and therefore, it's recommended to enable this option to get more + // flexibility in the schema changes. + WithDropIndex = schema.WithDropIndex + // WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true. + WithForeignKeys = schema.WithForeignKeys +) + +// Schema is the API for creating, migrating and dropping a schema. +type Schema struct { + drv dialect.Driver +} + +// NewSchema creates a new schema client. +func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } + +// Create creates all schema resources. +func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { + return Create(ctx, s, Tables, opts...) +} + +// Create creates all table resources using the given schema driver. +func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error { + migrate, err := schema.NewMigrate(s.drv, opts...) + if err != nil { + return fmt.Errorf("ent/migrate: %w", err) + } + return migrate.Create(ctx, tables...) +} + +// WriteTo writes the schema changes to w instead of running them against the database. +// +// if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { +// log.Fatal(err) +// } +func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { + return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...) +} diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go new file mode 100644 index 0000000000000..e509a518b1dbb --- /dev/null +++ b/ent/migrate/schema.go @@ -0,0 +1,72 @@ +// Code generated by ent, DO NOT EDIT. + +package migrate + +import ( + "entgo.io/ent/dialect/sql/schema" + "entgo.io/ent/schema/field" +) + +var ( + // MemosColumns holds the columns for the "memos" table. + MemosColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "resource_name", Type: field.TypeString, Unique: true, Size: 256}, + {Name: "creator_id", Type: field.TypeInt}, + {Name: "created_ts", Type: field.TypeTime}, + {Name: "updated_ts", Type: field.TypeTime}, + {Name: "row_status", Type: field.TypeString, Size: 256}, + {Name: "content", Type: field.TypeString, Size: 2147483647, Default: ""}, + {Name: "visibility", Type: field.TypeString, Size: 256}, + } + // MemosTable holds the schema information for the "memos" table. + MemosTable = &schema.Table{ + Name: "memos", + Columns: MemosColumns, + PrimaryKey: []*schema.Column{MemosColumns[0]}, + } + // MemoRelationsColumns holds the columns for the "memo_relations" table. + MemoRelationsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt, Increment: true}, + {Name: "type", Type: field.TypeString}, + {Name: "memo_id", Type: field.TypeInt}, + {Name: "related_memo_id", Type: field.TypeInt}, + } + // MemoRelationsTable holds the schema information for the "memo_relations" table. + MemoRelationsTable = &schema.Table{ + Name: "memo_relations", + Columns: MemoRelationsColumns, + PrimaryKey: []*schema.Column{MemoRelationsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "memo_relations_memos_memo", + Columns: []*schema.Column{MemoRelationsColumns[2]}, + RefColumns: []*schema.Column{MemosColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "memo_relations_memos_related_memo", + Columns: []*schema.Column{MemoRelationsColumns[3]}, + RefColumns: []*schema.Column{MemosColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "memorelation_memo_id_related_memo_id", + Unique: true, + Columns: []*schema.Column{MemoRelationsColumns[2], MemoRelationsColumns[3]}, + }, + }, + } + // Tables holds all the tables in the schema. + Tables = []*schema.Table{ + MemosTable, + MemoRelationsTable, + } +) + +func init() { + MemoRelationsTable.ForeignKeys[0].RefTable = MemosTable + MemoRelationsTable.ForeignKeys[1].RefTable = MemosTable +} diff --git a/ent/mutation.go b/ent/mutation.go new file mode 100644 index 0000000000000..6dce99c11e0c7 --- /dev/null +++ b/ent/mutation.go @@ -0,0 +1,1435 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/memorelation" + "github.com/usememos/memos/ent/predicate" +) + +const ( + // Operation types. + OpCreate = ent.OpCreate + OpDelete = ent.OpDelete + OpDeleteOne = ent.OpDeleteOne + OpUpdate = ent.OpUpdate + OpUpdateOne = ent.OpUpdateOne + + // Node types. + TypeMemo = "Memo" + TypeMemoRelation = "MemoRelation" +) + +// MemoMutation represents an operation that mutates the Memo nodes in the graph. +type MemoMutation struct { + config + op Op + typ string + id *int + resource_name *string + creator_id *int + addcreator_id *int + created_ts *time.Time + updated_ts *time.Time + row_status *string + content *string + visibility *string + clearedFields map[string]struct{} + related_memo map[int]struct{} + removedrelated_memo map[int]struct{} + clearedrelated_memo bool + memo_relation map[int]struct{} + removedmemo_relation map[int]struct{} + clearedmemo_relation bool + done bool + oldValue func(context.Context) (*Memo, error) + predicates []predicate.Memo +} + +var _ ent.Mutation = (*MemoMutation)(nil) + +// memoOption allows management of the mutation configuration using functional options. +type memoOption func(*MemoMutation) + +// newMemoMutation creates new mutation for the Memo entity. +func newMemoMutation(c config, op Op, opts ...memoOption) *MemoMutation { + m := &MemoMutation{ + config: c, + op: op, + typ: TypeMemo, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withMemoID sets the ID field of the mutation. +func withMemoID(id int) memoOption { + return func(m *MemoMutation) { + var ( + err error + once sync.Once + value *Memo + ) + m.oldValue = func(ctx context.Context) (*Memo, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Memo.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withMemo sets the old Memo of the mutation. +func withMemo(node *Memo) memoOption { + return func(m *MemoMutation) { + m.oldValue = func(context.Context) (*Memo, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MemoMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MemoMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of Memo entities. +func (m *MemoMutation) SetID(id int) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *MemoMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *MemoMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Memo.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetResourceName sets the "resource_name" field. +func (m *MemoMutation) SetResourceName(s string) { + m.resource_name = &s +} + +// ResourceName returns the value of the "resource_name" field in the mutation. +func (m *MemoMutation) ResourceName() (r string, exists bool) { + v := m.resource_name + if v == nil { + return + } + return *v, true +} + +// OldResourceName returns the old "resource_name" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldResourceName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResourceName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResourceName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResourceName: %w", err) + } + return oldValue.ResourceName, nil +} + +// ResetResourceName resets all changes to the "resource_name" field. +func (m *MemoMutation) ResetResourceName() { + m.resource_name = nil +} + +// SetCreatorID sets the "creator_id" field. +func (m *MemoMutation) SetCreatorID(i int) { + m.creator_id = &i + m.addcreator_id = nil +} + +// CreatorID returns the value of the "creator_id" field in the mutation. +func (m *MemoMutation) CreatorID() (r int, exists bool) { + v := m.creator_id + if v == nil { + return + } + return *v, true +} + +// OldCreatorID returns the old "creator_id" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldCreatorID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatorID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatorID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatorID: %w", err) + } + return oldValue.CreatorID, nil +} + +// AddCreatorID adds i to the "creator_id" field. +func (m *MemoMutation) AddCreatorID(i int) { + if m.addcreator_id != nil { + *m.addcreator_id += i + } else { + m.addcreator_id = &i + } +} + +// AddedCreatorID returns the value that was added to the "creator_id" field in this mutation. +func (m *MemoMutation) AddedCreatorID() (r int, exists bool) { + v := m.addcreator_id + if v == nil { + return + } + return *v, true +} + +// ResetCreatorID resets all changes to the "creator_id" field. +func (m *MemoMutation) ResetCreatorID() { + m.creator_id = nil + m.addcreator_id = nil +} + +// SetCreatedTs sets the "created_ts" field. +func (m *MemoMutation) SetCreatedTs(t time.Time) { + m.created_ts = &t +} + +// CreatedTs returns the value of the "created_ts" field in the mutation. +func (m *MemoMutation) CreatedTs() (r time.Time, exists bool) { + v := m.created_ts + if v == nil { + return + } + return *v, true +} + +// OldCreatedTs returns the old "created_ts" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldCreatedTs(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedTs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedTs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedTs: %w", err) + } + return oldValue.CreatedTs, nil +} + +// ResetCreatedTs resets all changes to the "created_ts" field. +func (m *MemoMutation) ResetCreatedTs() { + m.created_ts = nil +} + +// SetUpdatedTs sets the "updated_ts" field. +func (m *MemoMutation) SetUpdatedTs(t time.Time) { + m.updated_ts = &t +} + +// UpdatedTs returns the value of the "updated_ts" field in the mutation. +func (m *MemoMutation) UpdatedTs() (r time.Time, exists bool) { + v := m.updated_ts + if v == nil { + return + } + return *v, true +} + +// OldUpdatedTs returns the old "updated_ts" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldUpdatedTs(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedTs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedTs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedTs: %w", err) + } + return oldValue.UpdatedTs, nil +} + +// ResetUpdatedTs resets all changes to the "updated_ts" field. +func (m *MemoMutation) ResetUpdatedTs() { + m.updated_ts = nil +} + +// SetRowStatus sets the "row_status" field. +func (m *MemoMutation) SetRowStatus(s string) { + m.row_status = &s +} + +// RowStatus returns the value of the "row_status" field in the mutation. +func (m *MemoMutation) RowStatus() (r string, exists bool) { + v := m.row_status + if v == nil { + return + } + return *v, true +} + +// OldRowStatus returns the old "row_status" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldRowStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRowStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRowStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRowStatus: %w", err) + } + return oldValue.RowStatus, nil +} + +// ResetRowStatus resets all changes to the "row_status" field. +func (m *MemoMutation) ResetRowStatus() { + m.row_status = nil +} + +// SetContent sets the "content" field. +func (m *MemoMutation) SetContent(s string) { + m.content = &s +} + +// Content returns the value of the "content" field in the mutation. +func (m *MemoMutation) Content() (r string, exists bool) { + v := m.content + if v == nil { + return + } + return *v, true +} + +// OldContent returns the old "content" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldContent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldContent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldContent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldContent: %w", err) + } + return oldValue.Content, nil +} + +// ResetContent resets all changes to the "content" field. +func (m *MemoMutation) ResetContent() { + m.content = nil +} + +// SetVisibility sets the "visibility" field. +func (m *MemoMutation) SetVisibility(s string) { + m.visibility = &s +} + +// Visibility returns the value of the "visibility" field in the mutation. +func (m *MemoMutation) Visibility() (r string, exists bool) { + v := m.visibility + if v == nil { + return + } + return *v, true +} + +// OldVisibility returns the old "visibility" field's value of the Memo entity. +// If the Memo object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoMutation) OldVisibility(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVisibility is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVisibility requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVisibility: %w", err) + } + return oldValue.Visibility, nil +} + +// ResetVisibility resets all changes to the "visibility" field. +func (m *MemoMutation) ResetVisibility() { + m.visibility = nil +} + +// AddRelatedMemoIDs adds the "related_memo" edge to the Memo entity by ids. +func (m *MemoMutation) AddRelatedMemoIDs(ids ...int) { + if m.related_memo == nil { + m.related_memo = make(map[int]struct{}) + } + for i := range ids { + m.related_memo[ids[i]] = struct{}{} + } +} + +// ClearRelatedMemo clears the "related_memo" edge to the Memo entity. +func (m *MemoMutation) ClearRelatedMemo() { + m.clearedrelated_memo = true +} + +// RelatedMemoCleared reports if the "related_memo" edge to the Memo entity was cleared. +func (m *MemoMutation) RelatedMemoCleared() bool { + return m.clearedrelated_memo +} + +// RemoveRelatedMemoIDs removes the "related_memo" edge to the Memo entity by IDs. +func (m *MemoMutation) RemoveRelatedMemoIDs(ids ...int) { + if m.removedrelated_memo == nil { + m.removedrelated_memo = make(map[int]struct{}) + } + for i := range ids { + delete(m.related_memo, ids[i]) + m.removedrelated_memo[ids[i]] = struct{}{} + } +} + +// RemovedRelatedMemo returns the removed IDs of the "related_memo" edge to the Memo entity. +func (m *MemoMutation) RemovedRelatedMemoIDs() (ids []int) { + for id := range m.removedrelated_memo { + ids = append(ids, id) + } + return +} + +// RelatedMemoIDs returns the "related_memo" edge IDs in the mutation. +func (m *MemoMutation) RelatedMemoIDs() (ids []int) { + for id := range m.related_memo { + ids = append(ids, id) + } + return +} + +// ResetRelatedMemo resets all changes to the "related_memo" edge. +func (m *MemoMutation) ResetRelatedMemo() { + m.related_memo = nil + m.clearedrelated_memo = false + m.removedrelated_memo = nil +} + +// AddMemoRelationIDs adds the "memo_relation" edge to the MemoRelation entity by ids. +func (m *MemoMutation) AddMemoRelationIDs(ids ...int) { + if m.memo_relation == nil { + m.memo_relation = make(map[int]struct{}) + } + for i := range ids { + m.memo_relation[ids[i]] = struct{}{} + } +} + +// ClearMemoRelation clears the "memo_relation" edge to the MemoRelation entity. +func (m *MemoMutation) ClearMemoRelation() { + m.clearedmemo_relation = true +} + +// MemoRelationCleared reports if the "memo_relation" edge to the MemoRelation entity was cleared. +func (m *MemoMutation) MemoRelationCleared() bool { + return m.clearedmemo_relation +} + +// RemoveMemoRelationIDs removes the "memo_relation" edge to the MemoRelation entity by IDs. +func (m *MemoMutation) RemoveMemoRelationIDs(ids ...int) { + if m.removedmemo_relation == nil { + m.removedmemo_relation = make(map[int]struct{}) + } + for i := range ids { + delete(m.memo_relation, ids[i]) + m.removedmemo_relation[ids[i]] = struct{}{} + } +} + +// RemovedMemoRelation returns the removed IDs of the "memo_relation" edge to the MemoRelation entity. +func (m *MemoMutation) RemovedMemoRelationIDs() (ids []int) { + for id := range m.removedmemo_relation { + ids = append(ids, id) + } + return +} + +// MemoRelationIDs returns the "memo_relation" edge IDs in the mutation. +func (m *MemoMutation) MemoRelationIDs() (ids []int) { + for id := range m.memo_relation { + ids = append(ids, id) + } + return +} + +// ResetMemoRelation resets all changes to the "memo_relation" edge. +func (m *MemoMutation) ResetMemoRelation() { + m.memo_relation = nil + m.clearedmemo_relation = false + m.removedmemo_relation = nil +} + +// Where appends a list predicates to the MemoMutation builder. +func (m *MemoMutation) Where(ps ...predicate.Memo) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the MemoMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MemoMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Memo, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *MemoMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *MemoMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Memo). +func (m *MemoMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *MemoMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.resource_name != nil { + fields = append(fields, memo.FieldResourceName) + } + if m.creator_id != nil { + fields = append(fields, memo.FieldCreatorID) + } + if m.created_ts != nil { + fields = append(fields, memo.FieldCreatedTs) + } + if m.updated_ts != nil { + fields = append(fields, memo.FieldUpdatedTs) + } + if m.row_status != nil { + fields = append(fields, memo.FieldRowStatus) + } + if m.content != nil { + fields = append(fields, memo.FieldContent) + } + if m.visibility != nil { + fields = append(fields, memo.FieldVisibility) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *MemoMutation) Field(name string) (ent.Value, bool) { + switch name { + case memo.FieldResourceName: + return m.ResourceName() + case memo.FieldCreatorID: + return m.CreatorID() + case memo.FieldCreatedTs: + return m.CreatedTs() + case memo.FieldUpdatedTs: + return m.UpdatedTs() + case memo.FieldRowStatus: + return m.RowStatus() + case memo.FieldContent: + return m.Content() + case memo.FieldVisibility: + return m.Visibility() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *MemoMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case memo.FieldResourceName: + return m.OldResourceName(ctx) + case memo.FieldCreatorID: + return m.OldCreatorID(ctx) + case memo.FieldCreatedTs: + return m.OldCreatedTs(ctx) + case memo.FieldUpdatedTs: + return m.OldUpdatedTs(ctx) + case memo.FieldRowStatus: + return m.OldRowStatus(ctx) + case memo.FieldContent: + return m.OldContent(ctx) + case memo.FieldVisibility: + return m.OldVisibility(ctx) + } + return nil, fmt.Errorf("unknown Memo field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MemoMutation) SetField(name string, value ent.Value) error { + switch name { + case memo.FieldResourceName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResourceName(v) + return nil + case memo.FieldCreatorID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatorID(v) + return nil + case memo.FieldCreatedTs: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedTs(v) + return nil + case memo.FieldUpdatedTs: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedTs(v) + return nil + case memo.FieldRowStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRowStatus(v) + return nil + case memo.FieldContent: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetContent(v) + return nil + case memo.FieldVisibility: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVisibility(v) + return nil + } + return fmt.Errorf("unknown Memo field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *MemoMutation) AddedFields() []string { + var fields []string + if m.addcreator_id != nil { + fields = append(fields, memo.FieldCreatorID) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *MemoMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case memo.FieldCreatorID: + return m.AddedCreatorID() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MemoMutation) AddField(name string, value ent.Value) error { + switch name { + case memo.FieldCreatorID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCreatorID(v) + return nil + } + return fmt.Errorf("unknown Memo numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *MemoMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *MemoMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *MemoMutation) ClearField(name string) error { + return fmt.Errorf("unknown Memo nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *MemoMutation) ResetField(name string) error { + switch name { + case memo.FieldResourceName: + m.ResetResourceName() + return nil + case memo.FieldCreatorID: + m.ResetCreatorID() + return nil + case memo.FieldCreatedTs: + m.ResetCreatedTs() + return nil + case memo.FieldUpdatedTs: + m.ResetUpdatedTs() + return nil + case memo.FieldRowStatus: + m.ResetRowStatus() + return nil + case memo.FieldContent: + m.ResetContent() + return nil + case memo.FieldVisibility: + m.ResetVisibility() + return nil + } + return fmt.Errorf("unknown Memo field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *MemoMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.related_memo != nil { + edges = append(edges, memo.EdgeRelatedMemo) + } + if m.memo_relation != nil { + edges = append(edges, memo.EdgeMemoRelation) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *MemoMutation) AddedIDs(name string) []ent.Value { + switch name { + case memo.EdgeRelatedMemo: + ids := make([]ent.Value, 0, len(m.related_memo)) + for id := range m.related_memo { + ids = append(ids, id) + } + return ids + case memo.EdgeMemoRelation: + ids := make([]ent.Value, 0, len(m.memo_relation)) + for id := range m.memo_relation { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *MemoMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + if m.removedrelated_memo != nil { + edges = append(edges, memo.EdgeRelatedMemo) + } + if m.removedmemo_relation != nil { + edges = append(edges, memo.EdgeMemoRelation) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *MemoMutation) RemovedIDs(name string) []ent.Value { + switch name { + case memo.EdgeRelatedMemo: + ids := make([]ent.Value, 0, len(m.removedrelated_memo)) + for id := range m.removedrelated_memo { + ids = append(ids, id) + } + return ids + case memo.EdgeMemoRelation: + ids := make([]ent.Value, 0, len(m.removedmemo_relation)) + for id := range m.removedmemo_relation { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *MemoMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedrelated_memo { + edges = append(edges, memo.EdgeRelatedMemo) + } + if m.clearedmemo_relation { + edges = append(edges, memo.EdgeMemoRelation) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *MemoMutation) EdgeCleared(name string) bool { + switch name { + case memo.EdgeRelatedMemo: + return m.clearedrelated_memo + case memo.EdgeMemoRelation: + return m.clearedmemo_relation + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *MemoMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Memo unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *MemoMutation) ResetEdge(name string) error { + switch name { + case memo.EdgeRelatedMemo: + m.ResetRelatedMemo() + return nil + case memo.EdgeMemoRelation: + m.ResetMemoRelation() + return nil + } + return fmt.Errorf("unknown Memo edge %s", name) +} + +// MemoRelationMutation represents an operation that mutates the MemoRelation nodes in the graph. +type MemoRelationMutation struct { + config + op Op + typ string + id *int + _type *string + clearedFields map[string]struct{} + memo *int + clearedmemo bool + related_memo *int + clearedrelated_memo bool + done bool + oldValue func(context.Context) (*MemoRelation, error) + predicates []predicate.MemoRelation +} + +var _ ent.Mutation = (*MemoRelationMutation)(nil) + +// memorelationOption allows management of the mutation configuration using functional options. +type memorelationOption func(*MemoRelationMutation) + +// newMemoRelationMutation creates new mutation for the MemoRelation entity. +func newMemoRelationMutation(c config, op Op, opts ...memorelationOption) *MemoRelationMutation { + m := &MemoRelationMutation{ + config: c, + op: op, + typ: TypeMemoRelation, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withMemoRelationID sets the ID field of the mutation. +func withMemoRelationID(id int) memorelationOption { + return func(m *MemoRelationMutation) { + var ( + err error + once sync.Once + value *MemoRelation + ) + m.oldValue = func(ctx context.Context) (*MemoRelation, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().MemoRelation.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withMemoRelation sets the old MemoRelation of the mutation. +func withMemoRelation(node *MemoRelation) memorelationOption { + return func(m *MemoRelationMutation) { + m.oldValue = func(context.Context) (*MemoRelation, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m MemoRelationMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m MemoRelationMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *MemoRelationMutation) ID() (id int, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *MemoRelationMutation) IDs(ctx context.Context) ([]int, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().MemoRelation.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetType sets the "type" field. +func (m *MemoRelationMutation) SetType(s string) { + m._type = &s +} + +// GetType returns the value of the "type" field in the mutation. +func (m *MemoRelationMutation) GetType() (r string, exists bool) { + v := m._type + if v == nil { + return + } + return *v, true +} + +// OldType returns the old "type" field's value of the MemoRelation entity. +// If the MemoRelation object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoRelationMutation) OldType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldType: %w", err) + } + return oldValue.Type, nil +} + +// ResetType resets all changes to the "type" field. +func (m *MemoRelationMutation) ResetType() { + m._type = nil +} + +// SetMemoID sets the "memo_id" field. +func (m *MemoRelationMutation) SetMemoID(i int) { + m.memo = &i +} + +// MemoID returns the value of the "memo_id" field in the mutation. +func (m *MemoRelationMutation) MemoID() (r int, exists bool) { + v := m.memo + if v == nil { + return + } + return *v, true +} + +// OldMemoID returns the old "memo_id" field's value of the MemoRelation entity. +// If the MemoRelation object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoRelationMutation) OldMemoID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMemoID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMemoID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMemoID: %w", err) + } + return oldValue.MemoID, nil +} + +// ResetMemoID resets all changes to the "memo_id" field. +func (m *MemoRelationMutation) ResetMemoID() { + m.memo = nil +} + +// SetRelatedMemoID sets the "related_memo_id" field. +func (m *MemoRelationMutation) SetRelatedMemoID(i int) { + m.related_memo = &i +} + +// RelatedMemoID returns the value of the "related_memo_id" field in the mutation. +func (m *MemoRelationMutation) RelatedMemoID() (r int, exists bool) { + v := m.related_memo + if v == nil { + return + } + return *v, true +} + +// OldRelatedMemoID returns the old "related_memo_id" field's value of the MemoRelation entity. +// If the MemoRelation object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MemoRelationMutation) OldRelatedMemoID(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRelatedMemoID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRelatedMemoID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRelatedMemoID: %w", err) + } + return oldValue.RelatedMemoID, nil +} + +// ResetRelatedMemoID resets all changes to the "related_memo_id" field. +func (m *MemoRelationMutation) ResetRelatedMemoID() { + m.related_memo = nil +} + +// ClearMemo clears the "memo" edge to the Memo entity. +func (m *MemoRelationMutation) ClearMemo() { + m.clearedmemo = true + m.clearedFields[memorelation.FieldMemoID] = struct{}{} +} + +// MemoCleared reports if the "memo" edge to the Memo entity was cleared. +func (m *MemoRelationMutation) MemoCleared() bool { + return m.clearedmemo +} + +// MemoIDs returns the "memo" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// MemoID instead. It exists only for internal usage by the builders. +func (m *MemoRelationMutation) MemoIDs() (ids []int) { + if id := m.memo; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetMemo resets all changes to the "memo" edge. +func (m *MemoRelationMutation) ResetMemo() { + m.memo = nil + m.clearedmemo = false +} + +// ClearRelatedMemo clears the "related_memo" edge to the Memo entity. +func (m *MemoRelationMutation) ClearRelatedMemo() { + m.clearedrelated_memo = true + m.clearedFields[memorelation.FieldRelatedMemoID] = struct{}{} +} + +// RelatedMemoCleared reports if the "related_memo" edge to the Memo entity was cleared. +func (m *MemoRelationMutation) RelatedMemoCleared() bool { + return m.clearedrelated_memo +} + +// RelatedMemoIDs returns the "related_memo" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// RelatedMemoID instead. It exists only for internal usage by the builders. +func (m *MemoRelationMutation) RelatedMemoIDs() (ids []int) { + if id := m.related_memo; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetRelatedMemo resets all changes to the "related_memo" edge. +func (m *MemoRelationMutation) ResetRelatedMemo() { + m.related_memo = nil + m.clearedrelated_memo = false +} + +// Where appends a list predicates to the MemoRelationMutation builder. +func (m *MemoRelationMutation) Where(ps ...predicate.MemoRelation) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the MemoRelationMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *MemoRelationMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.MemoRelation, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *MemoRelationMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *MemoRelationMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (MemoRelation). +func (m *MemoRelationMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *MemoRelationMutation) Fields() []string { + fields := make([]string, 0, 3) + if m._type != nil { + fields = append(fields, memorelation.FieldType) + } + if m.memo != nil { + fields = append(fields, memorelation.FieldMemoID) + } + if m.related_memo != nil { + fields = append(fields, memorelation.FieldRelatedMemoID) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *MemoRelationMutation) Field(name string) (ent.Value, bool) { + switch name { + case memorelation.FieldType: + return m.GetType() + case memorelation.FieldMemoID: + return m.MemoID() + case memorelation.FieldRelatedMemoID: + return m.RelatedMemoID() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *MemoRelationMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case memorelation.FieldType: + return m.OldType(ctx) + case memorelation.FieldMemoID: + return m.OldMemoID(ctx) + case memorelation.FieldRelatedMemoID: + return m.OldRelatedMemoID(ctx) + } + return nil, fmt.Errorf("unknown MemoRelation field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MemoRelationMutation) SetField(name string, value ent.Value) error { + switch name { + case memorelation.FieldType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetType(v) + return nil + case memorelation.FieldMemoID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMemoID(v) + return nil + case memorelation.FieldRelatedMemoID: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRelatedMemoID(v) + return nil + } + return fmt.Errorf("unknown MemoRelation field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *MemoRelationMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *MemoRelationMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *MemoRelationMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown MemoRelation numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *MemoRelationMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *MemoRelationMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *MemoRelationMutation) ClearField(name string) error { + return fmt.Errorf("unknown MemoRelation nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *MemoRelationMutation) ResetField(name string) error { + switch name { + case memorelation.FieldType: + m.ResetType() + return nil + case memorelation.FieldMemoID: + m.ResetMemoID() + return nil + case memorelation.FieldRelatedMemoID: + m.ResetRelatedMemoID() + return nil + } + return fmt.Errorf("unknown MemoRelation field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *MemoRelationMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.memo != nil { + edges = append(edges, memorelation.EdgeMemo) + } + if m.related_memo != nil { + edges = append(edges, memorelation.EdgeRelatedMemo) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *MemoRelationMutation) AddedIDs(name string) []ent.Value { + switch name { + case memorelation.EdgeMemo: + if id := m.memo; id != nil { + return []ent.Value{*id} + } + case memorelation.EdgeRelatedMemo: + if id := m.related_memo; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *MemoRelationMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *MemoRelationMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *MemoRelationMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedmemo { + edges = append(edges, memorelation.EdgeMemo) + } + if m.clearedrelated_memo { + edges = append(edges, memorelation.EdgeRelatedMemo) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *MemoRelationMutation) EdgeCleared(name string) bool { + switch name { + case memorelation.EdgeMemo: + return m.clearedmemo + case memorelation.EdgeRelatedMemo: + return m.clearedrelated_memo + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *MemoRelationMutation) ClearEdge(name string) error { + switch name { + case memorelation.EdgeMemo: + m.ClearMemo() + return nil + case memorelation.EdgeRelatedMemo: + m.ClearRelatedMemo() + return nil + } + return fmt.Errorf("unknown MemoRelation unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *MemoRelationMutation) ResetEdge(name string) error { + switch name { + case memorelation.EdgeMemo: + m.ResetMemo() + return nil + case memorelation.EdgeRelatedMemo: + m.ResetRelatedMemo() + return nil + } + return fmt.Errorf("unknown MemoRelation edge %s", name) +} diff --git a/ent/predicate/predicate.go b/ent/predicate/predicate.go new file mode 100644 index 0000000000000..a9f176e2b8599 --- /dev/null +++ b/ent/predicate/predicate.go @@ -0,0 +1,13 @@ +// Code generated by ent, DO NOT EDIT. + +package predicate + +import ( + "entgo.io/ent/dialect/sql" +) + +// Memo is the predicate function for memo builders. +type Memo func(*sql.Selector) + +// MemoRelation is the predicate function for memorelation builders. +type MemoRelation func(*sql.Selector) diff --git a/ent/runtime.go b/ent/runtime.go new file mode 100644 index 0000000000000..c8aff120125bb --- /dev/null +++ b/ent/runtime.go @@ -0,0 +1,82 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "github.com/usememos/memos/ent/memo" + "github.com/usememos/memos/ent/schema" +) + +// The init function reads all schema descriptors with runtime code +// (default values, validators, hooks and policies) and stitches it +// to their package variables. +func init() { + memoFields := schema.Memo{}.Fields() + _ = memoFields + // memoDescResourceName is the schema descriptor for resource_name field. + memoDescResourceName := memoFields[1].Descriptor() + // memo.ResourceNameValidator is a validator for the "resource_name" field. It is called by the builders before save. + memo.ResourceNameValidator = func() func(string) error { + validators := memoDescResourceName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(resource_name string) error { + for _, fn := range fns { + if err := fn(resource_name); err != nil { + return err + } + } + return nil + } + }() + // memoDescCreatorID is the schema descriptor for creator_id field. + memoDescCreatorID := memoFields[2].Descriptor() + // memo.CreatorIDValidator is a validator for the "creator_id" field. It is called by the builders before save. + memo.CreatorIDValidator = memoDescCreatorID.Validators[0].(func(int) error) + // memoDescRowStatus is the schema descriptor for row_status field. + memoDescRowStatus := memoFields[5].Descriptor() + // memo.RowStatusValidator is a validator for the "row_status" field. It is called by the builders before save. + memo.RowStatusValidator = func() func(string) error { + validators := memoDescRowStatus.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(row_status string) error { + for _, fn := range fns { + if err := fn(row_status); err != nil { + return err + } + } + return nil + } + }() + // memoDescContent is the schema descriptor for content field. + memoDescContent := memoFields[6].Descriptor() + // memo.DefaultContent holds the default value on creation for the content field. + memo.DefaultContent = memoDescContent.Default.(string) + // memoDescVisibility is the schema descriptor for visibility field. + memoDescVisibility := memoFields[7].Descriptor() + // memo.VisibilityValidator is a validator for the "visibility" field. It is called by the builders before save. + memo.VisibilityValidator = func() func(string) error { + validators := memoDescVisibility.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(visibility string) error { + for _, fn := range fns { + if err := fn(visibility); err != nil { + return err + } + } + return nil + } + }() + // memoDescID is the schema descriptor for id field. + memoDescID := memoFields[0].Descriptor() + // memo.IDValidator is a validator for the "id" field. It is called by the builders before save. + memo.IDValidator = memoDescID.Validators[0].(func(int) error) +} diff --git a/ent/runtime/runtime.go b/ent/runtime/runtime.go new file mode 100644 index 0000000000000..2453d356faf3c --- /dev/null +++ b/ent/runtime/runtime.go @@ -0,0 +1,10 @@ +// Code generated by ent, DO NOT EDIT. + +package runtime + +// The schema-stitching logic is generated in github.com/usememos/memos/ent/runtime.go + +const ( + Version = "v0.12.5" // Version of ent codegen. + Sum = "h1:KREM5E4CSoej4zeGa88Ou/gfturAnpUv0mzAjch1sj4=" // Sum of ent codegen. +) diff --git a/ent/tx.go b/ent/tx.go new file mode 100644 index 0000000000000..e2dafaf0f2dd7 --- /dev/null +++ b/ent/tx.go @@ -0,0 +1,213 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "sync" + + "entgo.io/ent/dialect" +) + +// Tx is a transactional client that is created by calling Client.Tx(). +type Tx struct { + config + // Memo is the client for interacting with the Memo builders. + Memo *MemoClient + // MemoRelation is the client for interacting with the MemoRelation builders. + MemoRelation *MemoRelationClient + + // lazily loaded. + client *Client + clientOnce sync.Once + // ctx lives for the life of the transaction. It is + // the same context used by the underlying connection. + ctx context.Context +} + +type ( + // Committer is the interface that wraps the Commit method. + Committer interface { + Commit(context.Context, *Tx) error + } + + // The CommitFunc type is an adapter to allow the use of ordinary + // function as a Committer. If f is a function with the appropriate + // signature, CommitFunc(f) is a Committer that calls f. + CommitFunc func(context.Context, *Tx) error + + // CommitHook defines the "commit middleware". A function that gets a Committer + // and returns a Committer. For example: + // + // hook := func(next ent.Committer) ent.Committer { + // return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error { + // // Do some stuff before. + // if err := next.Commit(ctx, tx); err != nil { + // return err + // } + // // Do some stuff after. + // return nil + // }) + // } + // + CommitHook func(Committer) Committer +) + +// Commit calls f(ctx, m). +func (f CommitFunc) Commit(ctx context.Context, tx *Tx) error { + return f(ctx, tx) +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + txDriver := tx.config.driver.(*txDriver) + var fn Committer = CommitFunc(func(context.Context, *Tx) error { + return txDriver.tx.Commit() + }) + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() + for i := len(hooks) - 1; i >= 0; i-- { + fn = hooks[i](fn) + } + return fn.Commit(tx.ctx, tx) +} + +// OnCommit adds a hook to call on commit. +func (tx *Tx) OnCommit(f CommitHook) { + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() +} + +type ( + // Rollbacker is the interface that wraps the Rollback method. + Rollbacker interface { + Rollback(context.Context, *Tx) error + } + + // The RollbackFunc type is an adapter to allow the use of ordinary + // function as a Rollbacker. If f is a function with the appropriate + // signature, RollbackFunc(f) is a Rollbacker that calls f. + RollbackFunc func(context.Context, *Tx) error + + // RollbackHook defines the "rollback middleware". A function that gets a Rollbacker + // and returns a Rollbacker. For example: + // + // hook := func(next ent.Rollbacker) ent.Rollbacker { + // return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error { + // // Do some stuff before. + // if err := next.Rollback(ctx, tx); err != nil { + // return err + // } + // // Do some stuff after. + // return nil + // }) + // } + // + RollbackHook func(Rollbacker) Rollbacker +) + +// Rollback calls f(ctx, m). +func (f RollbackFunc) Rollback(ctx context.Context, tx *Tx) error { + return f(ctx, tx) +} + +// Rollback rollbacks the transaction. +func (tx *Tx) Rollback() error { + txDriver := tx.config.driver.(*txDriver) + var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { + return txDriver.tx.Rollback() + }) + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() + for i := len(hooks) - 1; i >= 0; i-- { + fn = hooks[i](fn) + } + return fn.Rollback(tx.ctx, tx) +} + +// OnRollback adds a hook to call on rollback. +func (tx *Tx) OnRollback(f RollbackHook) { + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() +} + +// Client returns a Client that binds to current transaction. +func (tx *Tx) Client() *Client { + tx.clientOnce.Do(func() { + tx.client = &Client{config: tx.config} + tx.client.init() + }) + return tx.client +} + +func (tx *Tx) init() { + tx.Memo = NewMemoClient(tx.config) + tx.MemoRelation = NewMemoRelationClient(tx.config) +} + +// txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. +// The idea is to support transactions without adding any extra code to the builders. +// When a builder calls to driver.Tx(), it gets the same dialect.Tx instance. +// Commit and Rollback are nop for the internal builders and the user must call one +// of them in order to commit or rollback the transaction. +// +// If a closed transaction is embedded in one of the generated entities, and the entity +// applies a query, for example: Memo.QueryXXX(), the query will be executed +// through the driver which created this transaction. +// +// Note that txDriver is not goroutine safe. +type txDriver struct { + // the driver we started the transaction from. + drv dialect.Driver + // tx is the underlying transaction. + tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook +} + +// newTx creates a new transactional driver. +func newTx(ctx context.Context, drv dialect.Driver) (*txDriver, error) { + tx, err := drv.Tx(ctx) + if err != nil { + return nil, err + } + return &txDriver{tx: tx, drv: drv}, nil +} + +// Tx returns the transaction wrapper (txDriver) to avoid Commit or Rollback calls +// from the internal builders. Should be called only by the internal builders. +func (tx *txDriver) Tx(context.Context) (dialect.Tx, error) { return tx, nil } + +// Dialect returns the dialect of the driver we started the transaction from. +func (tx *txDriver) Dialect() string { return tx.drv.Dialect() } + +// Close is a nop close. +func (*txDriver) Close() error { return nil } + +// Commit is a nop commit for the internal builders. +// User must call `Tx.Commit` in order to commit the transaction. +func (*txDriver) Commit() error { return nil } + +// Rollback is a nop rollback for the internal builders. +// User must call `Tx.Rollback` in order to rollback the transaction. +func (*txDriver) Rollback() error { return nil } + +// Exec calls tx.Exec. +func (tx *txDriver) Exec(ctx context.Context, query string, args, v any) error { + return tx.tx.Exec(ctx, query, args, v) +} + +// Query calls tx.Query. +func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error { + return tx.tx.Query(ctx, query, args, v) +} + +var _ dialect.Driver = (*txDriver)(nil) From ef5b193e682f13970fdcbe114550bbda9bab8816 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Thu, 8 Feb 2024 23:09:45 +0800 Subject: [PATCH 3/3] store: init ent dbv2 support Signed-off-by: Wei Zhang --- api/v2/memo_relation_service.go | 29 +++++++++------ api/v2/memo_service.go | 65 ++++++++++++++++++++++++++++----- api/v2/v2.go | 2 + bin/memos/main.go | 10 ++++- go.mod | 9 +++++ go.sum | 30 +++++++++++++++ store/dbv2/db.go | 50 +++++++++++++++++++++++++ store/store.go | 6 ++- 8 files changed, 178 insertions(+), 23 deletions(-) create mode 100644 store/dbv2/db.go diff --git a/api/v2/memo_relation_service.go b/api/v2/memo_relation_service.go index 3c20e43f11ccb..4cbfdd45f9e54 100644 --- a/api/v2/memo_relation_service.go +++ b/api/v2/memo_relation_service.go @@ -6,6 +6,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/usememos/memos/ent" + "github.com/usememos/memos/ent/memorelation" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/store" ) @@ -44,23 +46,26 @@ func (s *APIV2Service) SetMemoRelations(ctx context.Context, request *apiv2pb.Se func (s *APIV2Service) ListMemoRelations(ctx context.Context, request *apiv2pb.ListMemoRelationsRequest) (*apiv2pb.ListMemoRelationsResponse, error) { relationList := []*apiv2pb.MemoRelation{} - tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ - MemoID: &request.Id, - }) + tempList, err := s.Store.V2.MemoRelation. + Query(). + Where(memorelation.MemoID(int(request.Id))). + All(ctx) if err != nil { return nil, err } + for _, relation := range tempList { - relationList = append(relationList, convertMemoRelationFromStore(relation)) + relationList = append(relationList, convertMemoRelationFromStoreV2(relation)) } - tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ - RelatedMemoID: &request.Id, - }) + tempList, err = s.Store.V2.MemoRelation. + Query(). + Where(memorelation.RelatedMemoID(int(request.Id))). + All(ctx) if err != nil { return nil, err } for _, relation := range tempList { - relationList = append(relationList, convertMemoRelationFromStore(relation)) + relationList = append(relationList, convertMemoRelationFromStoreV2(relation)) } response := &apiv2pb.ListMemoRelationsResponse{ @@ -69,11 +74,11 @@ func (s *APIV2Service) ListMemoRelations(ctx context.Context, request *apiv2pb.L return response, nil } -func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *apiv2pb.MemoRelation { +func convertMemoRelationFromStoreV2(memoRelation *ent.MemoRelation) *apiv2pb.MemoRelation { return &apiv2pb.MemoRelation{ - MemoId: memoRelation.MemoID, - RelatedMemoId: memoRelation.RelatedMemoID, - Type: convertMemoRelationTypeFromStore(memoRelation.Type), + MemoId: int32(memoRelation.MemoID), + RelatedMemoId: int32(memoRelation.RelatedMemoID), + Type: convertMemoRelationTypeFromStore(store.MemoRelationType(memoRelation.Type)), } } diff --git a/api/v2/memo_service.go b/api/v2/memo_service.go index ef0bff8a088af..aae5d211811f9 100644 --- a/api/v2/memo_service.go +++ b/api/v2/memo_service.go @@ -25,6 +25,8 @@ import ( storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/service/metric" "github.com/usememos/memos/store" + "github.com/usememos/memos/ent" + memotype "github.com/usememos/memos/ent/memo" ) const ( @@ -137,16 +139,18 @@ func (s *APIV2Service) ListMemos(ctx context.Context, request *apiv2pb.ListMemos } func (s *APIV2Service) GetMemo(ctx context.Context, request *apiv2pb.GetMemoRequest) (*apiv2pb.GetMemoResponse, error) { - memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ - ID: &request.Id, - }) + memo, err := s.Store.V2.Memo. + Query(). + Where(memotype.ID(int(request.Id))). + Only(ctx) if err != nil { + if ent.IsNotFound(err) { + return nil, status.Errorf(codes.NotFound, "memo not found") + } return nil, err } - if memo == nil { - return nil, status.Errorf(codes.NotFound, "memo not found") - } - if memo.Visibility != store.Public { + + if store.Visibility(memo.Visibility) != store.Public { user, err := getCurrentUser(ctx, s.Store) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user") @@ -154,12 +158,13 @@ func (s *APIV2Service) GetMemo(ctx context.Context, request *apiv2pb.GetMemoRequ if user == nil { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - if memo.Visibility == store.Private && memo.CreatorID != user.ID { + if store.Visibility(memo.Visibility) == store.Private && + int32(memo.CreatorID) != user.ID { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } } - memoMessage, err := s.convertMemoFromStore(ctx, memo) + memoMessage, err := s.convertMemoFromStoreV2(ctx, memo) if err != nil { return nil, errors.Wrap(err, "failed to convert memo") } @@ -509,6 +514,48 @@ func (s *APIV2Service) ExportMemos(ctx context.Context, request *apiv2pb.ExportM }, nil } +func (s *APIV2Service) convertMemoFromStoreV2(ctx context.Context, memo *ent.Memo) (*apiv2pb.Memo, error) { + displayTs := memo.CreatedTs + if displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx); err == nil && displayWithUpdatedTs { + displayTs = memo.UpdatedTs + } + + creatorID := int32(memo.CreatorID) + creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &creatorID}) + if err != nil { + return nil, errors.Wrap(err, "failed to get creator") + } + + memoID := int32(memo.ID) + listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &apiv2pb.ListMemoRelationsRequest{Id: memoID}) + if err != nil { + return nil, errors.Wrap(err, "failed to list memo relations") + } + + listMemoResourcesResponse, err := s.ListMemoResources(ctx, &apiv2pb.ListMemoResourcesRequest{Id: memoID}) + if err != nil { + return nil, errors.Wrap(err, "failed to list memo resources") + } + + return &apiv2pb.Memo{ + Id: int32(memo.ID), + Name: memo.ResourceName, + RowStatus: convertRowStatusFromStore(store.RowStatus(memo.RowStatus)), + Creator: fmt.Sprintf("%s%s", UserNamePrefix, creator.Username), + CreatorId: int32(memo.CreatorID), + CreateTime: timestamppb.New(memo.CreatedTs), + UpdateTime: timestamppb.New(memo.UpdatedTs), + DisplayTime: timestamppb.New(displayTs), + Content: memo.Content, + Visibility: convertVisibilityFromStore(store.Visibility(memo.Visibility)), + // TODO(kw): implement pinned + // Pinned: memo.Pinned, + // ParentId: memo.ParentID, + Relations: listMemoRelationsResponse.Relations, + Resources: listMemoResourcesResponse.Resources, + }, nil +} + func (s *APIV2Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*apiv2pb.Memo, error) { displayTs := memo.CreatedTs if displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx); err == nil && displayWithUpdatedTs { diff --git a/api/v2/v2.go b/api/v2/v2.go index 5cc1fd1499adf..595272859f7b1 100644 --- a/api/v2/v2.go +++ b/api/v2/v2.go @@ -14,6 +14,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" + "github.com/usememos/memos/ent" "github.com/usememos/memos/internal/log" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/server/profile" @@ -34,6 +35,7 @@ type APIV2Service struct { Secret string Profile *profile.Profile Store *store.Store + StoreV2 *ent.Client grpcServer *grpc.Server grpcServerPort int diff --git a/bin/memos/main.go b/bin/memos/main.go index 47a9ca6274e1b..11e591f77299f 100644 --- a/bin/memos/main.go +++ b/bin/memos/main.go @@ -20,6 +20,7 @@ import ( "github.com/usememos/memos/server/service/metric" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/dbv2" ) const ( @@ -60,7 +61,14 @@ var ( return } - storeInstance := store.New(dbDriver, profile) + dbv2, err := dbv2.NewDriver(profile) + if err != nil { + cancel() + log.Error("failed to create dbv2 driver", zap.Error(err)) + return + } + + storeInstance := store.New(dbDriver, dbv2, profile) s, err := server.NewServer(ctx, profile, storeInstance) if err != nil { cancel() diff --git a/go.mod b/go.mod index 2b28599078917..0bee15c92b7f3 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/usememos/memos go 1.22 require ( + entgo.io/ent v0.12.5 github.com/aws/aws-sdk-go-v2 v1.24.1 github.com/aws/aws-sdk-go-v2/config v1.26.6 github.com/aws/aws-sdk-go-v2/credentials v1.16.16 @@ -38,20 +39,27 @@ require ( ) require ( + ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 // indirect github.com/KyleBanks/depth v1.2.1 // indirect + github.com/agext/levenshtein v1.2.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect + github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-openapi/inflect v0.19.0 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect github.com/go-openapi/jsonreference v0.20.4 // indirect github.com/go-openapi/spec v0.20.14 // indirect github.com/go-openapi/swag v0.22.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rs/cors v1.10.1 // indirect @@ -59,6 +67,7 @@ require ( github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect + github.com/zclconf/go-cty v1.8.0 // indirect golang.org/x/image v0.15.0 // indirect golang.org/x/tools v0.17.0 // indirect google.golang.org/genproto v0.0.0-20240205150955-31a09d347014 // indirect diff --git a/go.sum b/go.sum index 3d503d3cd5615..f90df588b759b 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,14 @@ +ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 h1:JnYs/y8RJ3+MiIUp+3RgyyeO48VHLAZimqiaZYnMKk8= +ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935/go.mod h1:isZrlzJ5cpoCoKFoY9knZug7Lq4pP1cm8g3XciLZ0Pw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +entgo.io/ent v0.12.5 h1:KREM5E4CSoej4zeGa88Ou/gfturAnpUv0mzAjch1sj4= +entgo.io/ent v0.12.5/go.mod h1:Y3JVAjtlIk8xVZYSn3t3mf8xlZIn5SAOXZQxD6kKI+Q= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= @@ -10,6 +16,8 @@ github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWX github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= +github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= +github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -19,6 +27,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw= +github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -127,6 +137,8 @@ github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgO github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNPXu/4= +github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4= github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= github.com/go-openapi/jsonreference v0.20.4 h1:bKlDxQxQJgwpUSgOENiMPzCTBVuc7vTdXSSgNeAhojU= @@ -143,6 +155,8 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= +github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= @@ -162,6 +176,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -230,6 +245,8 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= +github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= @@ -277,6 +294,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zGq8= github.com/labstack/echo/v4 v4.11.4/go.mod h1:noh7EvLwqDsmh/X/HWKPUl1AjzJrhyptRyEbQJfxen8= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= @@ -312,6 +331,8 @@ github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3N github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= @@ -410,6 +431,8 @@ github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6g github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -465,10 +488,14 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/vmihailenco/msgpack/v4 v4.3.12/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= +github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/yourselfhosted/gomark v0.0.0-20240203135813-6f2bb7ded891 h1:XZsVC80Y85+XLp+PtjwDCXI7KAec9S3STmSBOgkwVqQ= github.com/yourselfhosted/gomark v0.0.0-20240203135813-6f2bb7ded891/go.mod h1:dfl9FHGIw1oISjPc16u8n6/H/dngiVfdVRtS5+WJ4Js= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zclconf/go-cty v1.8.0 h1:s4AvqaeQzJIu3ndv4gVIhplVD0krU+bgrcLSVUnaWuA= +github.com/zclconf/go-cty v1.8.0/go.mod h1:vVKLxnk3puL4qRAv72AO+W99LUD4da90g3uUAzyuvAk= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= @@ -536,6 +563,7 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -595,6 +623,7 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= @@ -630,6 +659,7 @@ google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMt google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= diff --git a/store/dbv2/db.go b/store/dbv2/db.go new file mode 100644 index 0000000000000..4ca5974a7e562 --- /dev/null +++ b/store/dbv2/db.go @@ -0,0 +1,50 @@ +package dbv2 + +import ( + "fmt" + + "github.com/go-sql-driver/mysql" + + "github.com/usememos/memos/ent" + "github.com/usememos/memos/internal/log" + "github.com/usememos/memos/server/profile" +) + +// NewDBDriver creates new db driver based on profile. +func NewDriver(profile *profile.Profile) (*ent.Client, error) { + var driver *ent.Client + var err error + + switch profile.Driver { + case "mysql": + // TODO(kw): do we still need this? + // + // Open MySQL connection with parameter. + // multiStatements=true is required for migration. + // See more in: https://github.com/go-sql-driver/mysql#multistatements + dsn, err := mergeDSN(profile.DSN) + if err != nil { + log.Error(fmt.Sprintf("DSN %s error: %v", dsn, err)) + return nil, err + } + + driver, err = ent.Open("mysql", dsn) + default: + return nil, fmt.Errorf("unknown dbv2 driver") + } + if err != nil { + return nil, fmt.Errorf("failed to create db driver: %w", err) + } + + return driver, nil +} + +func mergeDSN(baseDSN string) (string, error) { + config, err := mysql.ParseDSN(baseDSN) + if err != nil { + return "", fmt.Errorf("failed to parse DSN %s: %w", baseDSN, err) + } + + config.MultiStatements = true + return config.FormatDSN(), nil +} diff --git a/store/store.go b/store/store.go index dcaf36b67ee84..2da6b0abc368e 100644 --- a/store/store.go +++ b/store/store.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/usememos/memos/ent" "github.com/usememos/memos/server/profile" ) @@ -15,12 +16,15 @@ type Store struct { userCache sync.Map // map[int]*User userSettingCache sync.Map // map[string]*UserSetting idpCache sync.Map // map[int]*IdentityProvider + + V2 *ent.Client } // New creates a new instance of Store. -func New(driver Driver, profile *profile.Profile) *Store { +func New(driver Driver, dbv2 *ent.Client, profile *profile.Profile) *Store { return &Store{ driver: driver, + V2: dbv2, Profile: profile, } }