diff --git a/enterprise/server/auth/auth.go b/enterprise/server/auth/auth.go
index 8c3f26616f2..382da7c00da 100644
--- a/enterprise/server/auth/auth.go
+++ b/enterprise/server/auth/auth.go
@@ -491,7 +491,7 @@ func lookupUserFromSubID(env environment.Env, ctx context.Context, subID string)
 			g.use_group_owned_executors,
 			g.saml_idp_metadata_url,
 			ug.role
-		FROM Groups AS g, UserGroups AS ug
+		FROM `+"`Groups`"+` AS g, UserGroups AS ug
 		WHERE g.group_id = ug.group_group_id
 		AND ug.membership_status = ?
 		AND ug.user_user_id = ?
diff --git a/enterprise/server/backends/userdb/userdb.go b/enterprise/server/backends/userdb/userdb.go
index 2823dc9fd63..21b7b55b60a 100644
--- a/enterprise/server/backends/userdb/userdb.go
+++ b/enterprise/server/backends/userdb/userdb.go
@@ -155,7 +155,12 @@ func (d *UserDB) GetAPIKeys(ctx context.Context, groupID string) ([]*tables.APIK
 		return nil, status.InvalidArgumentError("Group ID cannot be empty.")
 	}
 
-	query := d.h.Raw(`SELECT api_key_id, value, label, perms, capabilities FROM APIKeys WHERE group_id = ?`, groupID)
+	query := d.h.Raw(`
+		SELECT api_key_id, value, label, perms, capabilities
+		FROM APIKeys
+		WHERE group_id = ?
+		ORDER BY label ASC
+	`, groupID)
 	rows, err := query.Rows()
 	if err != nil {
 		return nil, err
@@ -331,7 +336,7 @@ func (d *UserDB) InsertOrUpdateGroup(ctx context.Context, g *tables.Group) (stri
 
 		groupID = g.GroupID
 		res := tx.Exec(`
-			UPDATE Groups SET name = ?, url_identifier = ?, owned_domain = ?, sharing_enabled = ?,
+			UPDATE `+"`Groups`"+` SET name = ?, url_identifier = ?, owned_domain = ?, sharing_enabled = ?,
 				use_group_owned_executors = ?
 			WHERE group_id = ?`,
 			g.Name, g.URLIdentifier, g.OwnedDomain, g.SharingEnabled, g.UseGroupOwnedExecutors,
@@ -408,6 +413,8 @@ func (d *UserDB) GetGroupUsers(ctx context.Context, groupID string, statuses []g
 	orQuery, orArgs := o.Build()
 	q = q.AddWhereClause("("+orQuery+")", orArgs...)
 
+	q.SetOrderBy(`u.email`, true /*=ascending*/)
+
 	qString, qArgs := q.Build()
 	rows, err := d.h.Raw(qString, qArgs...).Rows()
 	if err != nil {
diff --git a/enterprise/server/scheduling/scheduler_server/scheduler_server.go b/enterprise/server/scheduling/scheduler_server/scheduler_server.go
index a51ccc97e75..a3a8bff3715 100644
--- a/enterprise/server/scheduling/scheduler_server/scheduler_server.go
+++ b/enterprise/server/scheduling/scheduler_server/scheduler_server.go
@@ -70,6 +70,8 @@ const (
 	maxUnclaimedTasksTracked = 10_000
 	// TTL for sets used to track unclaimed tasks in Redis. TTL is extended when new tasks are added.
 	unclaimedTaskSetTTL = 1 * time.Hour
+	// Unclaimed tasks older than this are removed from the unclaimed tasks list.
+	unclaimedTaskMaxAge = 2 * time.Hour
 
 	unusedSchedulerClientExpiration    = 5 * time.Minute
 	unusedSchedulerClientCheckInterval = 1 * time.Minute
@@ -551,8 +553,17 @@ func (np *nodePool) AddUnclaimedTask(ctx context.Context, taskID string) error {
 
 	if n > maxUnclaimedTasksTracked {
 		// Trim the oldest tasks. We use the task insertion timestamp as the score so the oldest task is at rank 0, next
 		// oldest is at rank 1 and so on. We subtract 1 because the indexes are inclusive.
-		return np.rdb.ZRemRangeByRank(ctx, key, 0, n-maxUnclaimedTasksTracked-1).Err()
+		if err := np.rdb.ZRemRangeByRank(ctx, key, 0, n-maxUnclaimedTasksTracked-1).Err(); err != nil {
+			log.Warningf("Error trimming unclaimed tasks: %s", err)
+		}
+	}
+
+	// Also trim any stale tasks from the set. The data is stored in score order so this is a cheap operation.
+	cutoff := time.Now().Add(-unclaimedTaskMaxAge).Unix()
+	if err := np.rdb.ZRemRangeByScore(ctx, key, "0", strconv.FormatInt(cutoff, 10)).Err(); err != nil {
+		log.Warningf("Error deleting old unclaimed tasks: %s", err)
 	}
+
 	return nil
 }
@@ -1207,9 +1218,12 @@ func minInt(i, j int) int {
 }
 
 type enqueueTaskReservationOpts struct {
-	numReplicas           int
-	maxAttempts           int
-	alwaysScheduleLocally bool
+	numReplicas int
+	maxAttempts int
+	// This option determines whether tasks should be scheduled only on executors connected to this scheduler.
+	// If false, this scheduler will make RPCs to other schedulers to have them enqueue tasks on their connected
+	// executors.
+	scheduleOnConnectedExecutors bool
 }
 
 func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRequest *scpb.EnqueueTaskReservationRequest, serializedTask []byte, opts enqueueTaskReservationOpts) error {
@@ -1228,9 +1242,13 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe
 		return err
 	}
 
-	err = nodeBalancer.AddUnclaimedTask(ctx, enqueueRequest.GetTaskId())
-	if err != nil {
-		log.Warningf("Could not add task to unclaimed task list: %s", err)
+	// We only want to add the unclaimed task once on the "master" scheduler.
+	// scheduleOnConnectedExecutors implies that we are enqueuing task reservations on behalf of another scheduler.
+	if !opts.scheduleOnConnectedExecutors {
+		err = nodeBalancer.AddUnclaimedTask(ctx, enqueueRequest.GetTaskId())
+		if err != nil {
+			log.Warningf("Could not add task to unclaimed task list: %s", err)
+		}
 	}
 
 	probeCount := minInt(opts.numReplicas, nodeCount)
@@ -1277,7 +1295,7 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe
 		preferredNode = nil
 	} else {
 		nodes = nodeBalancer.nodes
-		if opts.alwaysScheduleLocally {
+		if opts.scheduleOnConnectedExecutors {
 			nodes = nodeBalancer.connectedExecutors
 		}
 		if len(nodes) == 0 {
@@ -1300,7 +1318,7 @@ func (s *SchedulerServer) enqueueTaskReservations(ctx context.Context, enqueueRe
 		enqueueRequest.ExecutorId = node.GetExecutorID()
 
 		enqueueStart := time.Now()
-		if opts.alwaysScheduleLocally {
+		if opts.scheduleOnConnectedExecutors {
 			if node.handle == nil {
 				log.Errorf("nil handle for a local executor %q", node.GetExecutorID())
 				continue
@@ -1360,8 +1378,8 @@ func (s *SchedulerServer) ScheduleTask(ctx context.Context, req *scpb.ScheduleTa
 	}
 	opts := enqueueTaskReservationOpts{
-		numReplicas:           probesPerTask,
-		alwaysScheduleLocally: false,
+		numReplicas:                  probesPerTask,
+		scheduleOnConnectedExecutors: false,
 	}
 	if err := s.enqueueTaskReservations(ctx, enqueueRequest, req.GetSerializedTask(), opts); err != nil {
 		return nil, err
 	}
@@ -1373,9 +1391,9 @@ func (s *SchedulerServer) EnqueueTaskReservation(ctx context.Context, req *scpb.
 	// TODO(vadim): verify user is authorized to use executor pool
 
 	opts := enqueueTaskReservationOpts{
-		numReplicas:           1,
-		maxAttempts:           10,
-		alwaysScheduleLocally: true,
+		numReplicas:                  1,
+		maxAttempts:                  10,
+		scheduleOnConnectedExecutors: true,
 	}
 	if err := s.enqueueTaskReservations(ctx, req, nil /*=serializedTask*/, opts); err != nil {
 		return nil, err
 	}
@@ -1405,8 +1423,8 @@ func (s *SchedulerServer) reEnqueueTask(ctx context.Context, taskID string, numR
 		SchedulingMetadata: task.metadata,
 	}
 	opts := enqueueTaskReservationOpts{
-		numReplicas:           numReplicas,
-		alwaysScheduleLocally: false,
+		numReplicas:                  numReplicas,
+		scheduleOnConnectedExecutors: false,
 	}
 	if err := s.enqueueTaskReservations(ctx, enqueueRequest, task.serializedTask, opts); err != nil {
 		return err
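Reviewer note: the unclaimed-task set works because the insertion timestamp is used as the sorted-set score, which makes both the rank-based size cap and the score-based age cap cheap. A minimal standalone sketch of the same pattern, assuming a go-redis v8-style client (`TrackUnclaimedTask` and the constants here are hypothetical names, not part of this PR):

```go
package unclaimed

import (
	"context"
	"strconv"
	"time"

	"github.com/go-redis/redis/v8"
)

const (
	maxTracked = 10_000        // size cap, mirrors maxUnclaimedTasksTracked
	maxAge     = 2 * time.Hour // age cap, mirrors unclaimedTaskMaxAge
	setTTL     = 1 * time.Hour // mirrors unclaimedTaskSetTTL
)

// TrackUnclaimedTask adds a task keyed by insertion time, then trims the set
// by rank (size cap) and by score (age cap).
func TrackUnclaimedTask(ctx context.Context, rdb *redis.Client, key, taskID string) error {
	now := time.Now()
	if err := rdb.ZAdd(ctx, key, &redis.Z{Score: float64(now.Unix()), Member: taskID}).Err(); err != nil {
		return err
	}
	// Extend the TTL whenever a new task is added.
	rdb.Expire(ctx, key, setTTL)

	n, err := rdb.ZCard(ctx, key).Result()
	if err != nil {
		return err
	}
	if n > maxTracked {
		// Oldest task sits at rank 0; indexes are inclusive, hence the -1.
		if err := rdb.ZRemRangeByRank(ctx, key, 0, n-maxTracked-1).Err(); err != nil {
			return err
		}
	}
	// Scores are Unix seconds, so an age cutoff is just a score range.
	cutoff := now.Add(-maxAge).Unix()
	return rdb.ZRemRangeByScore(ctx, key, "0", strconv.FormatInt(cutoff, 10)).Err()
}
```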
diff --git a/enterprise/server/util/cacheproxy/BUILD b/enterprise/server/util/cacheproxy/BUILD
index e533b24ab48..5e24178b391 100644
--- a/enterprise/server/util/cacheproxy/BUILD
+++ b/enterprise/server/util/cacheproxy/BUILD
@@ -14,6 +14,7 @@ go_library(
         "//server/environment",
         "//server/interfaces",
         "//server/util/alert",
+        "//server/util/bytebufferpool",
         "//server/util/devnull",
         "//server/util/grpc_client",
         "//server/util/grpc_server",
diff --git a/enterprise/server/util/cacheproxy/cacheproxy.go b/enterprise/server/util/cacheproxy/cacheproxy.go
index ebdb02b93db..fd187610f96 100644
--- a/enterprise/server/util/cacheproxy/cacheproxy.go
+++ b/enterprise/server/util/cacheproxy/cacheproxy.go
@@ -12,6 +12,7 @@ import (
 	"github.com/buildbuddy-io/buildbuddy/server/environment"
 	"github.com/buildbuddy-io/buildbuddy/server/interfaces"
 	"github.com/buildbuddy-io/buildbuddy/server/util/alert"
+	"github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool"
 	"github.com/buildbuddy-io/buildbuddy/server/util/devnull"
 	"github.com/buildbuddy-io/buildbuddy/server/util/grpc_client"
 	"github.com/buildbuddy-io/buildbuddy/server/util/grpc_server"
@@ -36,6 +37,7 @@ type CacheProxy struct {
 	env        environment.Env
 	cache      interfaces.Cache
 	log        log.Logger
+	bufferPool *bytebufferpool.Pool
 	mu         *sync.Mutex
 	server     *grpc.Server
 	clients    map[string]*dcClient
@@ -49,6 +51,7 @@ func NewCacheProxy(env environment.Env, c interfaces.Cache, listenAddr string) *
 		env:        env,
 		cache:      c,
 		log:        log.NamedSubLogger(fmt.Sprintf("CacheProxy(%s)", listenAddr)),
+		bufferPool: bytebufferpool.New(readBufSizeBytes),
 		listenAddr: listenAddr,
 		mu:         &sync.Mutex{},
 		// server goes here
@@ -255,8 +258,9 @@ func (c *CacheProxy) Read(req *dcpb.ReadRequest, stream dcpb.DistributedCache_Re
 	if d.GetSizeBytes() > 0 && d.GetSizeBytes() < bufSize {
 		bufSize = d.GetSizeBytes()
 	}
-	copyBuf := make([]byte, bufSize)
-	_, err = io.CopyBuffer(&streamWriter{stream}, reader, copyBuf)
+	copyBuf := c.bufferPool.Get(bufSize)
+	_, err = io.CopyBuffer(&streamWriter{stream}, reader, copyBuf[:bufSize])
+	c.bufferPool.Put(copyBuf)
 	c.log.Debugf("Read(%q) succeeded (user prefix: %s)", IsolationToString(req.GetIsolation())+d.GetHash(), up)
 	return err
 }
diff --git a/server/backends/github/BUILD b/server/backends/github/BUILD
index 5c3f36f1be7..e17f25cdd45 100644
--- a/server/backends/github/BUILD
+++ b/server/backends/github/BUILD
@@ -8,9 +8,11 @@ go_library(
     deps = [
         "//server/environment",
         "//server/tables",
+        "//server/util/authutil",
         "//server/util/log",
         "//server/util/perms",
         "//server/util/random",
+        "//server/util/role",
         "//server/util/status",
     ],
 )
diff --git a/server/backends/github/github.go b/server/backends/github/github.go
index 75edf753d0c..a54aec38ded 100644
--- a/server/backends/github/github.go
+++ b/server/backends/github/github.go
@@ -13,9 +13,11 @@ import (
 
 	"github.com/buildbuddy-io/buildbuddy/server/environment"
 	"github.com/buildbuddy-io/buildbuddy/server/tables"
+	"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
 	"github.com/buildbuddy-io/buildbuddy/server/util/log"
 	"github.com/buildbuddy-io/buildbuddy/server/util/perms"
 	"github.com/buildbuddy-io/buildbuddy/server/util/random"
+	"github.com/buildbuddy-io/buildbuddy/server/util/role"
 	"github.com/buildbuddy-io/buildbuddy/server/util/status"
 )
@@ -138,8 +140,13 @@ func (c *GithubClient) Link(w http.ResponseWriter, r *http.Request) {
 
 	// Restore group ID from cookie.
 	groupID := getCookie(r, groupIDCookieName)
-	if err := perms.AuthorizeGroupAccess(r.Context(), c.env, groupID); err != nil {
-		redirectWithError(w, r, status.PermissionDeniedErrorf("Group auth failed; not linking GitHub account: %s", err.Error()))
+	u, err := perms.AuthenticatedUser(r.Context(), c.env)
+	if err != nil {
+		redirectWithError(w, r, status.WrapError(err, "Failed to link GitHub account"))
+		return
+	}
+	if err := authutil.AuthorizeGroupRole(u, groupID, role.Admin); err != nil {
+		redirectWithError(w, r, status.WrapError(err, "Failed to link GitHub account"))
 		return
 	}
 
@@ -150,7 +157,7 @@ func (c *GithubClient) Link(w http.ResponseWriter, r *http.Request) {
 	}
 
 	err = dbHandle.Exec(
-		"UPDATE Groups SET github_token = ? WHERE group_id = ?",
+		`UPDATE `+"`Groups`"+` SET github_token = ? WHERE group_id = ?`,
 		accessTokenResponse.AccessToken, groupID).Error
 	if err != nil {
 		redirectWithError(w, r, status.PermissionDeniedErrorf("Error linking github account to user: %v", err))
diff --git a/server/build_event_protocol/build_event_handler/build_event_handler.go b/server/build_event_protocol/build_event_handler/build_event_handler.go
index c42ff98ef65..4fac183273c 100644
--- a/server/build_event_protocol/build_event_handler/build_event_handler.go
+++ b/server/build_event_protocol/build_event_handler/build_event_handler.go
@@ -53,6 +53,12 @@ const (
 
 	// How many workers to spin up for writing cache stats to the DB.
 	numStatsRecorderWorkers = 8
+
+	// How many workers to spin up for notifying webhooks.
+	numWebhookNotifierWorkers = 16
+
+	// How long to wait before giving up on webhook requests.
+	webhookNotifyTimeout = 1 * time.Minute
 )
 
 var (
@@ -67,11 +73,15 @@ type BuildEventHandler struct {
 
 func NewBuildEventHandler(env environment.Env) *BuildEventHandler {
 	openChannels := &sync.WaitGroup{}
-	statsRecorder := newStatsRecorder(env, openChannels)
+	statsRecorded := make(chan *inpb.Invocation, 4096)
+	statsRecorder := newStatsRecorder(env, openChannels, statsRecorded)
+	webhookNotifier := newWebhookNotifier(env, statsRecorded)
 	statsRecorder.Start()
+	webhookNotifier.Start()
 	env.GetHealthChecker().RegisterShutdownFunction(func(ctx context.Context) error {
 		statsRecorder.Stop()
+		webhookNotifier.Stop()
 		return nil
 	})
@@ -118,7 +128,7 @@ func (b *BuildEventHandler) OpenChannel(ctx context.Context, iid string) interfa
 // invocation. These tasks are enqueued to statsRecorder and executed in the
 // background.
 type recordStatsTask struct {
-	invocationID string
+	invocation *inpb.Invocation
 	// createdAt is the time at which this task was created.
 	createdAt time.Time
 }
@@ -129,22 +139,26 @@ type statsRecorder struct {
 	env          environment.Env
 	openChannels *sync.WaitGroup
-	eg           errgroup.Group
+	// statsRecorded is a channel that should be notified after the statsRecorder
+	// collects stats and flushes them to the DB.
+	statsRecorded chan<- *inpb.Invocation
+	eg            errgroup.Group
 
 	mu      sync.Mutex // protects(tasks, stopped)
 	tasks   chan *recordStatsTask
 	stopped bool
 }
 
-func newStatsRecorder(env environment.Env, openChannels *sync.WaitGroup) *statsRecorder {
+func newStatsRecorder(env environment.Env, openChannels *sync.WaitGroup, statsRecorded chan<- *inpb.Invocation) *statsRecorder {
 	return &statsRecorder{
-		env:          env,
-		openChannels: openChannels,
-		tasks:        make(chan *recordStatsTask, 4096),
+		env:           env,
+		openChannels:  openChannels,
+		statsRecorded: statsRecorded,
+		tasks:         make(chan *recordStatsTask, 4096),
 	}
 }
 
-func (r *statsRecorder) MarkFinalized(invocationID string) {
+func (r *statsRecorder) MarkFinalized(invocation *inpb.Invocation) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
@@ -152,12 +166,12 @@ func (r *statsRecorder) MarkFinalized(invocationID string) {
 		alert.UnexpectedEvent(
 			"stats_recorder_finalize_after_shutdown",
 			"Invocation %q was marked finalized after the stats recorder was shut down.",
-			invocationID)
+			invocation.GetInvocationId())
 		return
 	}
 	req := &recordStatsTask{
-		invocationID: invocationID,
-		createdAt:    time.Now(),
+		invocation: invocation,
+		createdAt:  time.Now(),
 	}
 	select {
 	case r.tasks <- req:
@@ -177,9 +191,10 @@ func (r *statsRecorder) Start() {
 			// finalized, rather than relative to now. Otherwise each worker would be
 			// unnecessarily throttled.
 			time.Sleep(time.Until(task.createdAt.Add(cacheStatsFinalizationDelay)))
-			ti := &tables.Invocation{InvocationID: task.invocationID}
-			if stats := hit_tracker.CollectCacheStats(ctx, r.env, task.invocationID); stats != nil {
+			ti := &tables.Invocation{InvocationID: task.invocation.GetInvocationId()}
+			if stats := hit_tracker.CollectCacheStats(ctx, r.env, task.invocation.GetInvocationId()); stats != nil {
 				fillInvocationFromCacheStats(stats, ti)
+				task.invocation.CacheStats = stats
 			}
 			if err := r.env.GetInvocationDB().InsertOrUpdateInvocation(ctx, ti); err != nil {
 				log.Errorf("Failed to write cache stats for invocation: %s", err)
@@ -187,7 +202,16 @@ func (r *statsRecorder) Start() {
 			// Cleanup regardless of whether the stats are flushed successfully to
 			// the DB (since we won't retry the flush and we don't need these stats
 			// for any other purpose).
-			hit_tracker.CleanupCacheStats(ctx, r.env, task.invocationID)
+			hit_tracker.CleanupCacheStats(ctx, r.env, task.invocation.GetInvocationId())
+
+			// Once cache stats are populated, notify the statsRecorded channel in a
+			// non-blocking fashion.
+			select {
+			case r.statsRecorded <- task.invocation:
+				break
+			default:
+				log.Warningf("Failed to notify stats recorder listeners: channel buffer is full")
+			}
 		}
 		return nil
 	})
@@ -200,14 +224,87 @@ func (r *statsRecorder) Stop() {
 	r.openChannels.Wait()
 
 	r.mu.Lock()
-	defer r.mu.Unlock()
-
 	r.stopped = true
 	close(r.tasks)
+	r.mu.Unlock()
 
 	if err := r.eg.Wait(); err != nil {
 		log.Error(err.Error())
 	}
+
+	close(r.statsRecorded)
+}
+
+type notifyWebhookTask struct {
+	hook       interfaces.Webhook
+	invocation *inpb.Invocation
+}
+
+func notifyWithTimeout(ctx context.Context, t *notifyWebhookTask) error {
+	ctx, cancel := context.WithTimeout(ctx, webhookNotifyTimeout)
+	defer cancel()
+	return t.hook.NotifyComplete(ctx, t.invocation)
+}
+
+// webhookNotifier listens for invocations to be finalized (including stats)
+// and notifies webhooks.
+type webhookNotifier struct {
+	env           environment.Env
+	statsRecorded <-chan *inpb.Invocation
+
+	tasks chan *notifyWebhookTask
+	eg    errgroup.Group
+}
+
+func newWebhookNotifier(env environment.Env, statsRecorded <-chan *inpb.Invocation) *webhookNotifier {
+	return &webhookNotifier{
+		env:           env,
+		statsRecorded: statsRecorded,
+		tasks:         make(chan *notifyWebhookTask, 4096),
+	}
+}
+
+func (w *webhookNotifier) Start() {
+	w.eg = errgroup.Group{}
+	ctx := context.Background()
+
+	w.eg.Go(func() error {
+		// Listen for invocations that have been finalized by the stats recorder,
+		// and start a notify webhook task for each webhook.
+		for invocation := range w.statsRecorded {
+			// Don't call webhooks for disconnected invocations.
+			if invocation.GetInvocationStatus() == inpb.Invocation_DISCONNECTED_INVOCATION_STATUS {
+				continue
+			}
+
+			for _, hook := range w.env.GetWebhooks() {
+				w.tasks <- &notifyWebhookTask{
+					hook:       hook,
+					invocation: invocation,
+				}
+			}
+		}
+		return nil
+	})
+
+	for i := 0; i < numWebhookNotifierWorkers; i++ {
+		w.eg.Go(func() error {
+			for task := range w.tasks {
+				if err := notifyWithTimeout(ctx, task); err != nil {
+					log.Warningf("Failed to notify webhook for invocation %s: %s", task.invocation.GetInvocationId(), err)
+				}
+			}
+			return nil
+		})
+	}
+}
+
+func (w *webhookNotifier) Stop() {
+	close(w.tasks)
+
+	if err := w.eg.Wait(); err != nil {
+		log.Error(err.Error())
+	}
 }
 
 func isFinalEvent(obe *pepb.OrderedBuildEvent) bool {
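Reviewer note: the new flow replaces a fire-and-forget goroutine per invocation with a bounded pipeline: the stats recorder publishes finished invocations to a buffered channel (dropping rather than blocking when full), a fan-out goroutine expands each invocation into one task per webhook, and a fixed pool of workers applies a per-request timeout. A condensed sketch of that shape, with a hypothetical `Event` type and `Notify` func standing in for the invocation proto and webhook interface:

```go
package notifier

import (
	"context"
	"log"
	"time"

	"golang.org/x/sync/errgroup"
)

// Event stands in for *inpb.Invocation; Notify for interfaces.Webhook.NotifyComplete.
type Event struct{ ID string }
type Notify func(ctx context.Context, e *Event) error

// Run consumes finalized events, fans them out to every hook, and bounds both
// queueing (buffered task channel) and concurrency (fixed worker pool).
func Run(hooks []Notify, events <-chan *Event, workers int, timeout time.Duration) {
	tasks := make(chan func(), 4096)
	var eg errgroup.Group

	eg.Go(func() error {
		for e := range events {
			for _, hook := range hooks {
				e, hook := e, hook // capture loop vars for the closure
				tasks <- func() {
					// Per-request timeout, like notifyWithTimeout above.
					ctx, cancel := context.WithTimeout(context.Background(), timeout)
					defer cancel()
					if err := hook(ctx, e); err != nil {
						log.Printf("webhook failed for %s: %s", e.ID, err)
					}
				}
			}
		}
		close(tasks) // producer done; let the workers drain and exit
		return nil
	})

	// Fixed-size worker pool bounds concurrent webhook calls.
	for i := 0; i < workers; i++ {
		eg.Go(func() error {
			for task := range tasks {
				task()
			}
			return nil
		})
	}
	eg.Wait()
}
```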
@@ -321,7 +418,7 @@ func (e *EventChannel) MarkInvocationDisconnected(ctx context.Context, iid strin
 		return err
 	}
 
-	e.statsRecorder.MarkFinalized(iid)
+	e.statsRecorder.MarkFinalized(invocation)
 
 	return nil
 }
@@ -397,19 +494,7 @@ func (e *EventChannel) FinalizeInvocation(iid string) error {
 		return err
 	}
 
-	e.statsRecorder.MarkFinalized(iid)
-
-	// Notify our webhooks, if we have any.
-	for _, hook := range e.env.GetWebhooks() {
-		hook := hook // copy loopvar to local var for closure capture
-		go func() {
-			// We use context background here because the request context will
-			// be closed soon and we don't want to block while calling webhooks.
-			if err := hook.NotifyComplete(context.Background(), invocation); err != nil {
-				log.Warningf("Error calling webhook: %s", err)
-			}
-		}()
-	}
+	e.statsRecorder.MarkFinalized(invocation)
 
 	if searcher := e.env.GetInvocationSearchService(); searcher != nil {
 		go func() {
 			if err := searcher.IndexInvocation(context.Background(), invocation); err != nil {
diff --git a/server/http/role_filter/BUILD b/server/http/role_filter/BUILD
index 2f8bc3f62e9..bd505c107cf 100644
--- a/server/http/role_filter/BUILD
+++ b/server/http/role_filter/BUILD
@@ -7,9 +7,11 @@ go_library(
     visibility = ["//visibility:public"],
     deps = [
         "//server/environment",
+        "//server/util/authutil",
         "//server/util/perms",
         "//server/util/request_context",
         "//server/util/role",
+        "//server/util/status",
     ],
 )
diff --git a/server/http/role_filter/role_filter.go b/server/http/role_filter/role_filter.go
index a93bea9b4ca..adeebd78081 100644
--- a/server/http/role_filter/role_filter.go
+++ b/server/http/role_filter/role_filter.go
@@ -4,8 +4,10 @@ import (
 	"net/http"
 
 	"github.com/buildbuddy-io/buildbuddy/server/environment"
+	"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
 	"github.com/buildbuddy-io/buildbuddy/server/util/perms"
 	"github.com/buildbuddy-io/buildbuddy/server/util/role"
+	"github.com/buildbuddy-io/buildbuddy/server/util/status"
 
 	requestcontext "github.com/buildbuddy-io/buildbuddy/server/util/request_context"
 )
@@ -99,22 +101,13 @@ func AuthorizeSelectedGroupRole(env environment.Env, next http.Handler) http.Han
 			return
 		}
 
-		uRole := role.None
-		for _, m := range u.GetGroupMemberships() {
-			if m.GroupID == reqCtx.GetGroupId() {
-				uRole = m.Role
-				break
-			}
-		}
-		if uRole == role.None {
-			// User was probably removed from their org during their current UI
-			// session.
-			http.Error(w, `You do not have access to the requested organization.`, http.StatusForbidden)
-			return
+		allowedRoles := role.Admin | role.Developer
+		if stringSliceContains(GroupAdminOnlyRPCs, rpcName) {
+			allowedRoles = role.Admin
 		}
 
-		if stringSliceContains(GroupAdminOnlyRPCs, rpcName) && (uRole&role.Admin != role.Admin) {
-			http.Error(w, `This action can only be performed by administrators of the organization.`, http.StatusForbidden)
+		if err := authutil.AuthorizeGroupRole(u, reqCtx.GetGroupId(), allowedRoles); err != nil {
+			http.Error(w, status.Message(err), http.StatusForbidden)
 			return
 		}
diff --git a/server/remote_cache/byte_stream_server/BUILD b/server/remote_cache/byte_stream_server/BUILD
index 2cb617932ab..57911708d95 100644
--- a/server/remote_cache/byte_stream_server/BUILD
+++ b/server/remote_cache/byte_stream_server/BUILD
@@ -13,6 +13,7 @@ go_library(
         "//server/remote_cache/digest",
         "//server/remote_cache/hit_tracker",
         "//server/remote_cache/namespace",
+        "//server/util/bytebufferpool",
         "//server/util/capabilities",
         "//server/util/devnull",
         "//server/util/prefix",
@@ -35,6 +36,7 @@ go_test(
         "//server/util/prefix",
         "//server/util/random",
         "//server/util/status",
+        "@com_github_stretchr_testify//require",
        "@go_googleapis//google/bytestream:bytestream_go_proto",
        "@org_golang_google_grpc//:go_default_library",
        "@org_golang_google_grpc//status",
diff --git a/server/remote_cache/byte_stream_server/byte_stream_server.go b/server/remote_cache/byte_stream_server/byte_stream_server.go
index d74deb30281..c06c6148697 100644
--- a/server/remote_cache/byte_stream_server/byte_stream_server.go
+++ b/server/remote_cache/byte_stream_server/byte_stream_server.go
@@ -12,6 +12,7 @@ import (
 	"github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest"
 	"github.com/buildbuddy-io/buildbuddy/server/remote_cache/hit_tracker"
 	"github.com/buildbuddy-io/buildbuddy/server/remote_cache/namespace"
+	"github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool"
 	"github.com/buildbuddy-io/buildbuddy/server/util/capabilities"
 	"github.com/buildbuddy-io/buildbuddy/server/util/devnull"
 	"github.com/buildbuddy-io/buildbuddy/server/util/prefix"
@@ -28,8 +29,9 @@ const (
 )
 
 type ByteStreamServer struct {
-	env   environment.Env
-	cache interfaces.Cache
+	env        environment.Env
+	cache      interfaces.Cache
+	bufferPool *bytebufferpool.Pool
 }
 
 func NewByteStreamServer(env environment.Env) (*ByteStreamServer, error) {
@@ -38,8 +40,9 @@ func NewByteStreamServer(env environment.Env) (*ByteStreamServer, error) {
 		return nil, status.FailedPreconditionError("A cache is required to enable the ByteStreamServer")
 	}
 	return &ByteStreamServer{
-		env:   env,
-		cache: cache,
+		env:        env,
+		cache:      cache,
+		bufferPool: bytebufferpool.New(readBufSizeBytes),
 	}, nil
 }
@@ -116,8 +119,9 @@ func (s *ByteStreamServer) Read(req *bspb.ReadRequest, stream bspb.ByteStream_Re
 	if d.GetSizeBytes() > 0 && d.GetSizeBytes() < bufSize {
 		bufSize = d.GetSizeBytes()
 	}
-	copyBuf := make([]byte, bufSize)
-	_, err = io.CopyBuffer(&streamWriter{stream}, reader, copyBuf)
+	copyBuf := s.bufferPool.Get(bufSize)
+	_, err = io.CopyBuffer(&streamWriter{stream}, reader, copyBuf[:bufSize])
+	s.bufferPool.Put(copyBuf)
 	if err == nil {
 		downloadTracker.Close()
 	}
diff --git a/server/remote_cache/byte_stream_server/byte_stream_server_test.go b/server/remote_cache/byte_stream_server/byte_stream_server_test.go
index 45b128891c9..eae973fcc7a 100644
--- a/server/remote_cache/byte_stream_server/byte_stream_server_test.go
+++ b/server/remote_cache/byte_stream_server/byte_stream_server_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"io"
+	"strings"
 	"testing"
 
 	"github.com/buildbuddy-io/buildbuddy/server/remote_cache/cachetools"
@@ -13,6 +14,7 @@ import (
 	"github.com/buildbuddy-io/buildbuddy/server/util/prefix"
 	"github.com/buildbuddy-io/buildbuddy/server/util/random"
 	"github.com/buildbuddy-io/buildbuddy/server/util/status"
+	"github.com/stretchr/testify/require"
 
 	"google.golang.org/grpc"
 
@@ -32,7 +34,8 @@ func runByteStreamServer(ctx context.Context, env *testenv.TestEnv, t *testing.T
 
 	go runFunc()
 
-	clientConn, err := env.LocalGRPCConn(ctx)
+	// TODO(vadim): can we remove the MsgSize override from the default options?
+	clientConn, err := env.LocalGRPCConn(ctx, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(4*1024*1024)))
 	if err != nil {
 		t.Error(err)
 	}
@@ -205,3 +208,27 @@ func TestRPCTooLongWrite(t *testing.T) {
 		t.Fatalf("Expected data loss error but got %s", err)
 	}
 }
+
+// Tests Read/Write of a blob that exceeds the default gRPC message size.
+func TestRPCReadWriteLargeBlob(t *testing.T) {
+	ctx := context.Background()
+	te := testenv.GetTestEnv(t)
+	clientConn := runByteStreamServer(ctx, te, t)
+	bsClient := bspb.NewByteStreamClient(clientConn)
+
+	blob, err := random.RandomString(10_000_000)
+	require.NoError(t, err)
+	d, err := digest.Compute(strings.NewReader(blob))
+	require.NoError(t, err)
+	instanceNameDigest := digest.NewInstanceNameDigest(d, "")
+
+	// Write
+	_, err = cachetools.UploadFromReader(ctx, bsClient, instanceNameDigest, strings.NewReader(blob))
+	require.NoError(t, err)
+
+	// Read
+	var buf bytes.Buffer
+	err = readBlob(ctx, bsClient, instanceNameDigest, &buf, 0)
+	require.NoError(t, err)
+	require.Equal(t, blob, buf.String())
+}
diff --git a/server/tables/tables.go b/server/tables/tables.go
index c377fdeeb55..3d89865f469 100644
--- a/server/tables/tables.go
+++ b/server/tables/tables.go
@@ -531,7 +531,7 @@ func PreAutoMigrate(db *gorm.DB) ([]PostAutoMigrateLogic, error) {
 	}
 	// Before creating a unique index, need to replace empty strings with NULL.
 	if !m.HasIndex("Groups", "url_identifier_unique_index") {
-		if err := db.Exec(`UPDATE Groups SET url_identifier = NULL WHERE url_identifier = ""`).Error; err != nil {
+		if err := db.Exec(`UPDATE ` + "`Groups`" + ` SET url_identifier = NULL WHERE url_identifier = ""`).Error; err != nil {
 			return nil, err
 		}
 	}
diff --git a/server/testutil/testenv/testenv.go b/server/testutil/testenv/testenv.go
index cd9a9c13e55..f2de16c48db 100644
--- a/server/testutil/testenv/testenv.go
+++ b/server/testutil/testenv/testenv.go
@@ -82,10 +82,11 @@ func (te *TestEnv) LocalGRPCServer() (*grpc.Server, func()) {
 	return te.GRPCServer(te.lis)
 }
 
-func (te *TestEnv) LocalGRPCConn(ctx context.Context) (*grpc.ClientConn, error) {
+func (te *TestEnv) LocalGRPCConn(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
 	dialOptions := grpc_client.CommonGRPCClientOptions()
 	dialOptions = append(dialOptions, grpc.WithContextDialer(te.bufDialer))
 	dialOptions = append(dialOptions, grpc.WithInsecure())
+	dialOptions = append(dialOptions, opts...)
 	return grpc.DialContext(ctx, "bufnet", dialOptions...)
 }
diff --git a/server/util/authutil/BUILD b/server/util/authutil/BUILD
new file mode 100644
index 00000000000..29fea223ff9
--- /dev/null
+++ b/server/util/authutil/BUILD
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+    name = "authutil",
+    srcs = ["authutil.go"],
+    importpath = "github.com/buildbuddy-io/buildbuddy/server/util/authutil",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//server/interfaces",
+        "//server/util/role",
+        "//server/util/status",
+    ],
+)
diff --git a/server/util/authutil/authutil.go b/server/util/authutil/authutil.go
new file mode 100644
index 00000000000..b6164807c09
--- /dev/null
+++ b/server/util/authutil/authutil.go
@@ -0,0 +1,28 @@
+package authutil
+
+import (
+	"github.com/buildbuddy-io/buildbuddy/server/interfaces"
+	"github.com/buildbuddy-io/buildbuddy/server/util/role"
+	"github.com/buildbuddy-io/buildbuddy/server/util/status"
+)
+
+// AuthorizeGroupRole checks whether the given user has any of the allowed roles
+// within the given group.
+func AuthorizeGroupRole(u interfaces.UserInfo, groupID string, allowedRoles role.Role) error {
+	r := role.None
+	for _, m := range u.GetGroupMemberships() {
+		if m.GroupID == groupID {
+			r = m.Role
+			break
+		}
+	}
+	if r == role.None {
+		// User is not a member of the group at all; they were probably removed from
+		// their org during their current UI session.
+		return status.PermissionDeniedError("You do not have access to the requested organization")
+	}
+	if r&allowedRoles == 0 {
+		return status.PermissionDeniedError("You do not have the appropriate role within this organization")
+	}
+	return nil
+}
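Reviewer note: because `role.Role` values are bit flags, `allowedRoles` expresses "any of these roles" as a bitwise union and the membership check is a single mask, as `role_filter.go` above now does. A small usage sketch (`canCall` is a hypothetical wrapper, not part of this PR):

```go
package rolecheck

import (
	"github.com/buildbuddy-io/buildbuddy/server/interfaces"
	"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
	"github.com/buildbuddy-io/buildbuddy/server/util/role"
)

// canCall mirrors the role_filter change: admins may call anything,
// developers only the RPCs that are not admin-only.
func canCall(u interfaces.UserInfo, groupID string, adminOnly bool) error {
	allowedRoles := role.Admin | role.Developer
	if adminOnly {
		allowedRoles = role.Admin
	}
	// Returns PermissionDenied if u is not a member of groupID, or if the
	// member's role bit does not intersect allowedRoles.
	return authutil.AuthorizeGroupRole(u, groupID, allowedRoles)
}
```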
diff --git a/server/util/bytebufferpool/BUILD b/server/util/bytebufferpool/BUILD
new file mode 100644
index 00000000000..025d43ee20a
--- /dev/null
+++ b/server/util/bytebufferpool/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+    name = "bytebufferpool",
+    srcs = ["bytebufferpool.go"],
+    importpath = "github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool",
+    visibility = ["//visibility:public"],
+)
+
+go_test(
+    name = "bytebufferpool_test",
+    srcs = ["bytebufferpool_test.go"],
+    deps = [
+        ":bytebufferpool",
+        "@com_github_stretchr_testify//assert",
+    ],
+)
diff --git a/server/util/bytebufferpool/bytebufferpool.go b/server/util/bytebufferpool/bytebufferpool.go
new file mode 100644
index 00000000000..34c2104ec5a
--- /dev/null
+++ b/server/util/bytebufferpool/bytebufferpool.go
@@ -0,0 +1,59 @@
+package bytebufferpool
+
+import (
+	"math/bits"
+	"sync"
+)
+
+// Pool is a wrapper around `sync.Pool` that manages pools of buffers of different lengths, bucketed by powers of 2.
+type Pool struct {
+	// pools holds pools of byte slices whose lengths are powers of 2, indexed by the exponent:
+	// index 0 holds a pool of 2^0-length slices, index 1 holds 2^1-length slices, and so forth.
+	pools         []sync.Pool
+	maxBufferSize int
+}
+
+func New(maxBufferSize int) *Pool {
+	bp := &Pool{maxBufferSize: maxBufferSize}
+	for size := 1; ; size *= 2 {
+		size := size
+		bp.pools = append(bp.pools, sync.Pool{
+			New: func() interface{} {
+				return make([]byte, size)
+			},
+		})
+		if size >= maxBufferSize {
+			break
+		}
+	}
+	return bp
+}
+
+// Get returns a byte slice of at least the specified length.
+//
+// CAUTION: In most cases the returned slice will have a greater length than requested. Don't use this package if you
+// need buffers of exact lengths.
+func (bp *Pool) Get(length int64) []byte {
+	// Calculate the smallest power of 2 exponent x where 2^x >= length.
+	idx := bits.Len64(uint64(length - 1))
+	if length == 0 {
+		idx = 0
+	}
+	if idx >= len(bp.pools) {
+		idx = len(bp.pools) - 1
+	}
+	return bp.pools[idx].Get().([]byte)
+}
+
+// Put returns a byte slice back into the pool.
+func (bp *Pool) Put(buf []byte) {
+	// Calculate the largest power of 2 exponent x where 2^x <= len(buf).
+	idx := bits.Len64(uint64(len(buf))) - 1
+	if idx < 0 {
+		return
+	}
+	if idx >= len(bp.pools) {
+		idx = len(bp.pools) - 1
+	}
+	bp.pools[idx].Put(buf)
+}
diff --git a/server/util/bytebufferpool/bytebufferpool_test.go b/server/util/bytebufferpool/bytebufferpool_test.go
new file mode 100644
index 00000000000..115c8f63625
--- /dev/null
+++ b/server/util/bytebufferpool/bytebufferpool_test.go
@@ -0,0 +1,42 @@
+package bytebufferpool_test
+
+import (
+	"testing"
+
+	"github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBufferSize(t *testing.T) {
+	bp := bytebufferpool.New(1024)
+
+	type test struct {
+		dataSize    int64
+		wantBufSize int
+	}
+	for _, testCase := range []test{
+		{dataSize: 0, wantBufSize: 1},
+		{dataSize: 1, wantBufSize: 1},
+		{dataSize: 8, wantBufSize: 8},
+		{dataSize: 12, wantBufSize: 16},
+		{dataSize: 15, wantBufSize: 16},
+		{dataSize: 16, wantBufSize: 16},
+		{dataSize: 17, wantBufSize: 32},
+	} {
+		assert.EqualValues(t, testCase.wantBufSize, len(bp.Get(testCase.dataSize)), "incorrect buffer len for data of length %d", testCase.dataSize)
+		assert.EqualValues(t, testCase.wantBufSize, cap(bp.Get(testCase.dataSize)), "incorrect buffer cap for data of length %d", testCase.dataSize)
+	}
+}
+
+func TestReuse(t *testing.T) {
+	bp := bytebufferpool.New(1024)
+
+	for i := 1; i < 20; i++ {
+		bp.Put(make([]byte, i))
+	}
+
+	for i := 1; i < 30; i++ {
+		buf := bp.Get(int64(i))
+		assert.GreaterOrEqual(t, len(buf), i, "buffer for length %d did not have sufficient length", i)
+	}
+}
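Reviewer note: a minimal usage sketch of the new pool on a read path, mirroring the `byte_stream_server` and `cacheproxy` changes above (`copyChunk` and the buffer-size constant here are hypothetical, not part of this PR). Because `Get` may return a slice longer than requested, callers must re-slice to the intended size before handing it to `io.CopyBuffer`:

```go
package readpath

import (
	"io"

	"github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool"
)

const readBufSizeBytes = 1_000_000 // assumed value; the real constant lives in each server

var bufferPool = bytebufferpool.New(readBufSizeBytes)

// copyChunk streams r to w through a pooled buffer instead of allocating a
// fresh one per request, as the Read paths above now do.
func copyChunk(w io.Writer, r io.Reader, sizeBytes int64) error {
	bufSize := int64(readBufSizeBytes)
	if sizeBytes > 0 && sizeBytes < bufSize {
		bufSize = sizeBytes
	}
	copyBuf := bufferPool.Get(bufSize)
	defer bufferPool.Put(copyBuf)
	// Get may return a longer slice; re-slice to the size we asked for.
	_, err := io.CopyBuffer(w, r, copyBuf[:bufSize])
	return err
}
```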