From fa2e166cf70a1fba1e7b0c352e7301626b9b7e33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Egelund-M=C3=BCller?= Date: Fri, 15 Dec 2023 12:29:03 +0100 Subject: [PATCH] Runtime: Refactor conn cache to contain and detect hanging opens/closes (#3666) * Runtime: Refactor conn cache to contain and detect hanging opens/closes * Extract connection cache to pkg + use a singleflight * Add tests * Make tests pass * increase test sleeps for clsoe * Fix various race conditions * Integrate singleflight with conncache's mutex * Increase timeouts * Better comments * Prevent deadlock when closing while opening * Address review comments * Remove redundant var --- runtime/connection_cache.go | 344 ++++-------------- runtime/connection_cache_test.go | 417 ---------------------- runtime/connections.go | 4 +- runtime/drivers/duckdb/duckdb.go | 2 +- runtime/pkg/conncache/conncache.go | 450 ++++++++++++++++++++++++ runtime/pkg/conncache/conncache_test.go | 268 ++++++++++++++ runtime/registry.go | 2 +- runtime/registry_test.go | 9 +- runtime/runtime.go | 7 +- 9 files changed, 806 insertions(+), 697 deletions(-) delete mode 100644 runtime/connection_cache_test.go create mode 100644 runtime/pkg/conncache/conncache.go create mode 100644 runtime/pkg/conncache/conncache_test.go diff --git a/runtime/connection_cache.go b/runtime/connection_cache.go index 882150d4ae3..e926b6db2e8 100644 --- a/runtime/connection_cache.go +++ b/runtime/connection_cache.go @@ -6,246 +6,100 @@ import ( "fmt" "slices" "strings" - "sync" "time" - "github.com/hashicorp/golang-lru/simplelru" "github.com/rilldata/rill/runtime/drivers" - "github.com/rilldata/rill/runtime/pkg/activity" + "github.com/rilldata/rill/runtime/pkg/conncache" "github.com/rilldata/rill/runtime/pkg/observability" + "go.opentelemetry.io/otel/metric" "go.uber.org/zap" "golang.org/x/exp/maps" ) -var errConnectionCacheClosed = errors.New("connectionCache: closed") - -const migrateTimeout = 2 * time.Minute - -// connectionCache is a thread-safe cache for open connections. -// Connections should preferably be opened only via the connection cache. -// -// TODO: It opens connections async, but it will close them sync when evicted. If a handle's close hangs, this can block the cache. -// We should move the closing to the background. However, it must then handle the case of trying to re-open a connection that's currently closing in the background. 
-type connectionCache struct { - size int - runtime *Runtime - logger *zap.Logger - activity activity.Client - closed bool - migrateCtx context.Context // ctx used for connection migrations - migrateCtxCancel context.CancelFunc // cancel all running migrations - lock sync.Mutex - acquired map[string]*connWithRef // items with non-zero references (in use) which should not be evicted - lru *simplelru.LRU // items with no references (opened, but not in use) ready for eviction -} +var ( + connCacheOpens = observability.Must(meter.Int64Counter("connnection_cache.opens")) + connCacheCloses = observability.Must(meter.Int64Counter("connnection_cache.closes")) + connCacheSizeTotal = observability.Must(meter.Int64UpDownCounter("connnection_cache.size_total")) + connCacheSizeLRU = observability.Must(meter.Int64UpDownCounter("connnection_cache.size_lru")) + connCacheOpenLatencyMS = observability.Must(meter.Int64Histogram("connnection_cache.open_latency", metric.WithUnit("ms"))) + connCacheCloseLatencyMS = observability.Must(meter.Int64Histogram("connnection_cache.close_latency", metric.WithUnit("ms"))) +) -type connWithRef struct { +type cachedConnectionConfig struct { instanceID string - handle drivers.Handle - err error - refs int - ready chan struct{} + driver string + shared bool + config map[string]any } -func newConnectionCache(size int, logger *zap.Logger, rt *Runtime, ac activity.Client) *connectionCache { - // LRU cache that closes evicted connections - lru, err := simplelru.NewLRU(size, func(key interface{}, value interface{}) { - // Skip if the conn has refs, since the callback also gets called when transferring to acquired cache - conn := value.(*connWithRef) - if conn.refs != 0 { - return - } - if conn.handle != nil { - if err := conn.handle.Close(); err != nil { - logger.Error("failed closing cached connection", zap.String("key", key.(string)), zap.Error(err)) - } - } +// newConnectionCache returns a concurrency-safe cache for open connections. +// Connections should preferably be opened only via the connection cache. +// It's implementation handles issues such as concurrent open/close/eviction of a connection. +// It also monitors for hanging connections. 
+func (r *Runtime) newConnectionCache() conncache.Cache { + return conncache.New(conncache.Options{ + MaxIdleConnections: r.opts.ConnectionCacheSize, + OpenTimeout: 10 * time.Minute, + CloseTimeout: 10 * time.Minute, + CheckHangingInterval: time.Minute, + OpenFunc: func(ctx context.Context, cfg any) (conncache.Connection, error) { + x := cfg.(cachedConnectionConfig) + return r.openAndMigrate(ctx, x) + }, + KeyFunc: func(cfg any) string { + x := cfg.(cachedConnectionConfig) + return generateKey(x) + }, + HangingFunc: func(cfg any, open bool) { + x := cfg.(cachedConnectionConfig) + r.logger.Error("connection cache: connection has been working for too long", zap.String("instance_id", x.instanceID), zap.String("driver", x.driver), zap.Bool("open", open)) + }, + Metrics: conncache.Metrics{ + Opens: connCacheOpens, + Closes: connCacheCloses, + SizeTotal: connCacheSizeTotal, + SizeLRU: connCacheSizeLRU, + OpenLatencyMS: connCacheOpenLatencyMS, + CloseLatencyMS: connCacheCloseLatencyMS, + }, }) - if err != nil { - panic(err) - } - - ctx, cancel := context.WithCancel(context.Background()) - return &connectionCache{ - size: size, - runtime: rt, - logger: logger, - activity: ac, - migrateCtx: ctx, - migrateCtxCancel: cancel, - acquired: make(map[string]*connWithRef), - lru: lru, - } -} - -func (c *connectionCache) Close() error { - c.lock.Lock() - defer c.lock.Unlock() - - if c.closed { - return errConnectionCacheClosed - } - c.closed = true - - // Cancel currently running migrations - c.migrateCtxCancel() - - var firstErr error - for _, key := range c.lru.Keys() { - val, ok := c.lru.Get(key) - if !ok { - continue - } - conn := val.(*connWithRef) - if conn.handle == nil { - continue - } - err := conn.handle.Close() - if err != nil { - c.logger.Error("failed closing cached connection", zap.Error(err)) - if firstErr == nil { - firstErr = err - } - } - } - - for _, value := range c.acquired { - if value.handle == nil { - continue - } - err := value.handle.Close() - if err != nil { - c.logger.Error("failed closing cached connection", zap.Error(err)) - if firstErr == nil { - firstErr = err - } - } - } - - return firstErr } -func (c *connectionCache) get(ctx context.Context, instanceID, driver string, config map[string]any, shared bool) (drivers.Handle, func(), error) { - var key string - if shared { - // not using instanceID to ensure all instances share the same handle - key = driver + generateKey(config) - } else { - key = instanceID + driver + generateKey(config) - } - - c.lock.Lock() - if c.closed { - c.lock.Unlock() - return nil, nil, errConnectionCacheClosed - } - - // Get conn from caches - conn, ok := c.acquired[key] - if ok { - conn.refs++ - } else { - var val any - val, ok = c.lru.Get(key) - if ok { - // Conn was found in LRU - move to acquired cache - conn = val.(*connWithRef) - conn.refs++ // NOTE: Must increment before call to c.lru.remove to avoid closing the conn - c.lru.Remove(key) - c.acquired[key] = conn - } - } - - // Cached conn not found, open a new one - if !ok { - conn = &connWithRef{ - instanceID: instanceID, - refs: 1, // Since refs is assumed to already have been incremented when checking conn.ready - ready: make(chan struct{}), - } - c.acquired[key] = conn - - if len(c.acquired)+c.lru.Len() > c.size { - c.logger.Warn("number of connections acquired and in LRU exceed total configured size", zap.Int("acquired", len(c.acquired)), zap.Int("lru", c.lru.Len())) - } - - // Open and migrate the connection in a separate goroutine (outside lock). 
- // Incrementing ref and releasing the conn for this operation separately to cover the case where all waiting goroutines are cancelled before the migration completes. - conn.refs++ - go func() { - handle, err := c.openAndMigrate(c.migrateCtx, instanceID, driver, shared, config) - c.lock.Lock() - conn.handle = handle - conn.err = err - c.releaseConn(key, conn) - wasClosed := c.closed - c.lock.Unlock() - close(conn.ready) - - // The cache might have been closed while the connection was being opened. - // Since we acquired the lock, the close will have already been completed, so we need to close the connection here. - if wasClosed && handle != nil { - _ = handle.Close() - } - }() - } - - // We can now release the lock and wait for the connection to be ready (it might already be) - c.lock.Unlock() - - // Wait for connection to be ready or context to be cancelled - var err error - select { - case <-conn.ready: - case <-ctx.Done(): - err = ctx.Err() // Will always be non-nil, ensuring releaseConn is called - } - - // Lock again for accessing conn - c.lock.Lock() - defer c.lock.Unlock() - - if err == nil { - err = conn.err +// getConnection returns a cached connection for the given driver configuration. +func (r *Runtime) getConnection(ctx context.Context, instanceID, driver string, config map[string]any, shared bool) (drivers.Handle, func(), error) { + cfg := cachedConnectionConfig{ + instanceID: instanceID, + driver: driver, + shared: shared, + config: config, } + handle, release, err := r.connCache.Acquire(ctx, cfg) if err != nil { - c.releaseConn(key, conn) return nil, nil, err } - release := func() { - c.lock.Lock() - c.releaseConn(key, conn) - c.lock.Unlock() - } - - return conn.handle, release, nil + return handle.(drivers.Handle), release, nil } -func (c *connectionCache) releaseConn(key string, conn *connWithRef) { - conn.refs-- - if conn.refs == 0 { - // No longer referenced. Move from acquired to LRU. - if !c.closed { - delete(c.acquired, key) - c.lru.Add(key, conn) - } - } +// evictInstanceConnections evicts all connections for the given instance. +func (r *Runtime) evictInstanceConnections(instanceID string) { + r.connCache.EvictWhere(func(cfg any) bool { + x := cfg.(cachedConnectionConfig) + return x.instanceID == instanceID + }) } -func (c *connectionCache) openAndMigrate(ctx context.Context, instanceID, driver string, shared bool, config map[string]any) (drivers.Handle, error) { - logger := c.logger - if instanceID != "default" { - logger = c.logger.With(zap.String("instance_id", instanceID), zap.String("driver", driver)) +// openAndMigrate opens a connection and migrates it. 
+func (r *Runtime) openAndMigrate(ctx context.Context, cfg cachedConnectionConfig) (drivers.Handle, error) { + logger := r.logger + if cfg.instanceID != "default" { + logger = r.logger.With(zap.String("instance_id", cfg.instanceID), zap.String("driver", cfg.driver)) } - ctx, cancel := context.WithTimeout(ctx, migrateTimeout) - defer cancel() - - activityClient := c.activity - if !shared { - inst, err := c.runtime.Instance(ctx, instanceID) + activityClient := r.activity + if !cfg.shared { + inst, err := r.Instance(ctx, cfg.instanceID) if err != nil { return nil, err } @@ -256,9 +110,9 @@ func (c *connectionCache) openAndMigrate(ctx context.Context, instanceID, driver } } - handle, err := drivers.Open(driver, config, shared, activityClient, logger) + handle, err := drivers.Open(cfg.driver, cfg.config, cfg.shared, activityClient, logger) if err == nil && ctx.Err() != nil { - err = fmt.Errorf("timed out while opening driver %q", driver) + err = fmt.Errorf("timed out while opening driver %q", cfg.driver) } if err != nil { return nil, err @@ -268,71 +122,23 @@ func (c *connectionCache) openAndMigrate(ctx context.Context, instanceID, driver if err != nil { handle.Close() if errors.Is(err, ctx.Err()) { - err = fmt.Errorf("timed out while migrating driver %q: %w", driver, err) + err = fmt.Errorf("timed out while migrating driver %q: %w", cfg.driver, err) } return nil, err } return handle, nil } -// evictAll closes all connections for an instance. -func (c *connectionCache) evictAll(ctx context.Context, instanceID string) { - c.lock.Lock() - defer c.lock.Unlock() - - if c.closed { - return - } - - for key, conn := range c.acquired { - if conn.instanceID != instanceID { - continue - } - - if conn.handle != nil { - err := conn.handle.Close() - if err != nil { - c.logger.Error("connection cache: failed to close cached connection", zap.Error(err), zap.String("instance", instanceID), observability.ZapCtx(ctx)) - } - conn.handle = nil - conn.err = fmt.Errorf("connection evicted") // Defensive, should never be accessed - } - - delete(c.acquired, key) - } - - for _, key := range c.lru.Keys() { - connT, ok := c.lru.Get(key) - if !ok { - panic("connection cache: key not found in LRU") - } - conn := connT.(*connWithRef) - - if conn.instanceID != instanceID { - continue - } - - if conn.handle != nil { - err := conn.handle.Close() - if err != nil { - c.logger.Error("connection cache: failed to close cached connection", zap.Error(err), zap.String("instance", instanceID), observability.ZapCtx(ctx)) - } - conn.handle = nil - conn.err = fmt.Errorf("connection evicted") // Defensive, should never be accessed - } - - c.lru.Remove(key) - } -} - -func generateKey(m map[string]any) string { +func generateKey(cfg cachedConnectionConfig) string { sb := strings.Builder{} - keys := maps.Keys(m) + sb.WriteString(cfg.instanceID) // Empty if cfg.shared + sb.WriteString(cfg.driver) + keys := maps.Keys(cfg.config) slices.Sort(keys) for _, key := range keys { sb.WriteString(key) sb.WriteString(":") - sb.WriteString(fmt.Sprint(m[key])) + sb.WriteString(fmt.Sprint(cfg.config[key])) sb.WriteString(" ") } return sb.String() diff --git a/runtime/connection_cache_test.go b/runtime/connection_cache_test.go deleted file mode 100644 index 299021e1749..00000000000 --- a/runtime/connection_cache_test.go +++ /dev/null @@ -1,417 +0,0 @@ -package runtime - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "testing" - "time" - - runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" - "github.com/rilldata/rill/runtime/drivers" - _ 
"github.com/rilldata/rill/runtime/drivers/sqlite" - "github.com/rilldata/rill/runtime/pkg/activity" - "github.com/stretchr/testify/require" - "go.uber.org/zap" -) - -func TestConnectionCache(t *testing.T) { - ctx := context.Background() - id := "default" - - rt := newTestRuntimeWithInst(t) - c := newConnectionCache(10, zap.NewNop(), rt, activity.NewNoopClient()) - conn1, release, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - release() - require.NotNil(t, conn1) - - conn2, release, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - release() - require.NotNil(t, conn2) - - inst := &drivers.Instance{ - ID: "default1", - OLAPConnector: "duckdb", - RepoConnector: "repo", - CatalogConnector: "catalog", - Connectors: []*runtimev1.Connector{ - { - Type: "file", - Name: "repo", - Config: map[string]string{"dsn": ""}, - }, - { - Type: "duckdb", - Name: "duckdb", - Config: map[string]string{"dsn": ""}, - }, - { - Type: "sqlite", - Name: "catalog", - Config: map[string]string{"dsn": "file:rill?mode=memory&cache=shared"}, - }, - }, - } - require.NoError(t, rt.CreateInstance(context.Background(), inst)) - - conn3, release, err := c.get(ctx, "default1", "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - release() - require.NotNil(t, conn3) - - require.True(t, conn1 == conn2) - require.False(t, conn2 == conn3) -} - -func TestConnectionCacheWithAllShared(t *testing.T) { - ctx := context.Background() - id := "default" - - c := newConnectionCache(1, zap.NewNop(), newTestRuntimeWithInst(t), activity.NewNoopClient()) - conn1, release, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, true) - require.NoError(t, err) - require.NotNil(t, conn1) - defer release() - - conn2, release, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, true) - require.NoError(t, err) - require.NotNil(t, conn2) - defer release() - - conn3, release, err := c.get(ctx, "default", "sqlite", map[string]any{"dsn": ":memory:"}, true) - require.NoError(t, err) - require.NotNil(t, conn3) - defer release() - - require.True(t, conn1 == conn2) - require.True(t, conn2 == conn3) - require.Equal(t, 1, len(c.acquired)) - require.Equal(t, 0, c.lru.Len()) -} - -func TestConnectionCacheWithAllOpen(t *testing.T) { - ctx := context.Background() - - rt := newTestRuntimeWithInst(t) - c := newConnectionCache(1, zap.NewNop(), rt, activity.NewNoopClient()) - conn1, r1, err := c.get(ctx, "default", "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - require.NotNil(t, conn1) - - createInstance(t, rt, "default1") - conn2, r2, err := c.get(ctx, "default1", "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - require.NotNil(t, conn2) - - createInstance(t, rt, "default2") - conn3, r3, err := c.get(ctx, "default2", "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - require.NotNil(t, conn3) - - require.Equal(t, 3, len(c.acquired)) - require.Equal(t, 0, c.lru.Len()) - // release all connections - r1() - r2() - r3() - require.Equal(t, 0, len(c.acquired)) - require.Equal(t, 1, c.lru.Len()) - _, val, _ := c.lru.GetOldest() - require.True(t, conn3 == val.(*connWithRef).handle) -} - -func TestConnectionCacheParallel(t *testing.T) { - ctx := context.Background() - - rt := newTestRuntimeWithInst(t) - c := newConnectionCache(5, zap.NewNop(), rt, activity.NewNoopClient()) - defer c.Close() - - var wg sync.WaitGroup - wg.Add(30) - 
// open 10 connections and do not release - go func() { - for i := 0; i < 10; i++ { - j := i - go func() { - defer wg.Done() - id := fmt.Sprintf("default%v", 100+j) - createInstance(t, rt, id) - conn, _, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, false) - require.NoError(t, err) - require.NotNil(t, conn) - time.Sleep(100 * time.Millisecond) - }() - } - }() - - // open 20 connections and release - for i := 0; i < 20; i++ { - j := i - go func() { - defer wg.Done() - id := fmt.Sprintf("default%v", 200+j) - createInstance(t, rt, id) - conn, r, err := c.get(ctx, id, "sqlite", map[string]any{"dsn": ":memory:"}, false) - defer r() - require.NoError(t, err) - require.NotNil(t, conn) - time.Sleep(100 * time.Millisecond) - }() - } - wg.Wait() - - // 10 connections were not released so should be present in in-use cache - require.Equal(t, 10, len(c.acquired)) - // 20 connections were released so 15 should be evicted - require.Equal(t, 5, c.lru.Len()) -} - -func TestConnectionCacheMultipleConfigs(t *testing.T) { - ctx := context.Background() - - c := newConnectionCache(10, zap.NewNop(), newTestRuntimeWithInst(t), activity.NewNoopClient()) - defer c.Close() - conn1, r1, err := c.get(ctx, "default", "sqlite", map[string]any{"dsn": ":memory:", "host": "localhost:8080", "allow_host_access": "true"}, true) - require.NoError(t, err) - require.NotNil(t, conn1) - - conn2, r2, err := c.get(ctx, "default", "sqlite", map[string]any{"dsn": ":memory:", "host": "localhost:8080", "allow_host_access": "true"}, true) - require.NoError(t, err) - require.NotNil(t, conn2) - - conn3, r3, err := c.get(ctx, "default", "sqlite", map[string]any{"dsn": ":memory:", "host": "localhost:8080", "allow_host_access": "true"}, true) - require.NoError(t, err) - require.NotNil(t, conn3) - - require.Equal(t, 1, len(c.acquired)) - require.Equal(t, 0, c.lru.Len()) - // release all connections - r1() - r2() - r3() - require.Equal(t, 0, len(c.acquired)) - require.Equal(t, 1, c.lru.Len()) -} - -func TestConnectionCacheParallelCalls(t *testing.T) { - ctx := context.Background() - - m := &mockDriver{} - drivers.Register("mock_driver", m) - defer func() { - delete(drivers.Drivers, "mock_driver") - }() - - rt := newTestRuntimeWithInst(t) - defer rt.Close() - - c := newConnectionCache(10, zap.NewNop(), rt, activity.NewNoopClient()) - defer c.Close() - - var wg sync.WaitGroup - wg.Add(10) - // open 10 connections and verify no error - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - conn, _, err := c.get(ctx, "default", "mock_driver", map[string]any{"sleep": int64(100)}, false) - require.NoError(t, err) - require.NotNil(t, conn) - }() - } - wg.Wait() - - require.Equal(t, int32(1), m.opened.Load()) - require.Equal(t, 1, len(c.acquired)) -} - -func TestConnectionCacheBlockingCalls(t *testing.T) { - ctx := context.Background() - - m := &mockDriver{} - drivers.Register("mock_driver", m) - defer func() { - delete(drivers.Drivers, "mock_driver") - }() - - rt := newTestRuntimeWithInst(t) - defer rt.Close() - - c := newConnectionCache(10, zap.NewNop(), rt, activity.NewNoopClient()) - defer c.Close() - - var wg sync.WaitGroup - wg.Add(12) - // open 1 slow connection - go func() { - defer wg.Done() - conn, _, err := c.get(ctx, "default", "mock_driver", map[string]any{"sleep": int64(1000)}, false) - require.NoError(t, err) - require.NotNil(t, conn) - }() - - // open 10 fast different connections(takes 10-20 ms to open) and verify not blocked - for i := 0; i < 10; i++ { - j := i - go func() { - defer wg.Done() - conn, _, err 
:= c.get(ctx, "default", "mock_driver", map[string]any{"sleep": int64(j + 10)}, false) - require.NoError(t, err) - require.NotNil(t, conn) - }() - } - - // verify that after 200 ms 11 connections have been opened - go func() { - time.Sleep(200 * time.Millisecond) - wg.Done() - }() - wg.Wait() - - require.Equal(t, int32(11), m.opened.Load()) -} - -type mockDriver struct { - opened atomic.Int32 -} - -// Drop implements drivers.Driver. -func (*mockDriver) Drop(config map[string]any, logger *zap.Logger) error { - panic("unimplemented") -} - -// HasAnonymousSourceAccess implements drivers.Driver. -func (*mockDriver) HasAnonymousSourceAccess(ctx context.Context, src map[string]any, logger *zap.Logger) (bool, error) { - panic("unimplemented") -} - -func (*mockDriver) TertiarySourceConnectors(ctx context.Context, src map[string]any, logger *zap.Logger) ([]string, error) { - return nil, nil -} - -// Open implements drivers.Driver. -func (m *mockDriver) Open(config map[string]any, shared bool, client activity.Client, logger *zap.Logger) (drivers.Handle, error) { - m.opened.Add(1) - sleep := config["sleep"].(int64) - time.Sleep(time.Duration(sleep) * time.Millisecond) - return &mockHandle{}, nil -} - -// Spec implements drivers.Driver. -func (*mockDriver) Spec() drivers.Spec { - panic("unimplemented") -} - -var _ drivers.Driver = &mockDriver{} - -type mockHandle struct { -} - -// AsCatalogStore implements drivers.Handle. -func (*mockHandle) AsCatalogStore(instanceID string) (drivers.CatalogStore, bool) { - panic("unimplemented") -} - -// AsFileStore implements drivers.Handle. -func (*mockHandle) AsFileStore() (drivers.FileStore, bool) { - panic("unimplemented") -} - -// AsOLAP implements drivers.Handle. -func (*mockHandle) AsOLAP(instanceID string) (drivers.OLAPStore, bool) { - panic("unimplemented") -} - -// AsObjectStore implements drivers.Handle. -func (*mockHandle) AsObjectStore() (drivers.ObjectStore, bool) { - panic("unimplemented") -} - -// AsRegistry implements drivers.Handle. -func (*mockHandle) AsRegistry() (drivers.RegistryStore, bool) { - panic("unimplemented") -} - -// AsRepoStore implements drivers.Handle. -func (*mockHandle) AsRepoStore(instanceID string) (drivers.RepoStore, bool) { - panic("unimplemented") -} - -// AsAdmin implements drivers.Handle. -func (*mockHandle) AsAdmin(instanceID string) (drivers.AdminService, bool) { - panic("unimplemented") -} - -// AsSQLStore implements drivers.Handle. -func (*mockHandle) AsSQLStore() (drivers.SQLStore, bool) { - panic("unimplemented") -} - -// AsTransporter implements drivers.Handle. -func (*mockHandle) AsTransporter(from drivers.Handle, to drivers.Handle) (drivers.Transporter, bool) { - panic("unimplemented") -} - -// Close implements drivers.Handle. -func (*mockHandle) Close() error { - return nil -} - -// Config implements drivers.Handle. -func (*mockHandle) Config() map[string]any { - panic("unimplemented") -} - -// Driver implements drivers.Handle. -func (*mockHandle) Driver() string { - panic("unimplemented") -} - -// Migrate implements drivers.Handle. -func (*mockHandle) Migrate(ctx context.Context) error { - return nil -} - -// MigrationStatus implements drivers.Handle. 
-func (*mockHandle) MigrationStatus(ctx context.Context) (current int, desired int, err error) { - panic("unimplemented") -} - -var _ drivers.Handle = &mockHandle{} - -func newTestRuntimeWithInst(t *testing.T) *Runtime { - rt := newTestRuntime(t) - createInstance(t, rt, "default") - return rt -} - -func createInstance(t *testing.T, rt *Runtime, instanceId string) { - inst := &drivers.Instance{ - ID: instanceId, - OLAPConnector: "duckdb", - RepoConnector: "repo", - CatalogConnector: "catalog", - Connectors: []*runtimev1.Connector{ - { - Type: "file", - Name: "repo", - Config: map[string]string{"dsn": ""}, - }, - { - Type: "duckdb", - Name: "duckdb", - Config: map[string]string{"dsn": ""}, - }, - { - Type: "sqlite", - Name: "catalog", - Config: map[string]string{"dsn": "file:rill?mode=memory&cache=shared"}, - }, - }, - } - require.NoError(t, rt.CreateInstance(context.Background(), inst)) -} diff --git a/runtime/connections.go b/runtime/connections.go index ba939063c07..b72c3618551 100644 --- a/runtime/connections.go +++ b/runtime/connections.go @@ -19,7 +19,7 @@ func (r *Runtime) AcquireSystemHandle(ctx context.Context, connector string) (dr cfg[strings.ToLower(k)] = v } cfg["allow_host_access"] = r.opts.AllowHostAccess - return r.connCache.get(ctx, "", c.Type, cfg, true) + return r.getConnection(ctx, "", c.Type, cfg, true) } } return nil, nil, fmt.Errorf("connector %s doesn't exist", connector) @@ -36,7 +36,7 @@ func (r *Runtime) AcquireHandle(ctx context.Context, instanceID, connector strin // So we take this moment to make sure the ctx gets checked for cancellation at least every once in a while. return nil, nil, ctx.Err() } - return r.connCache.get(ctx, instanceID, driver, cfg, false) + return r.getConnection(ctx, instanceID, driver, cfg, false) } func (r *Runtime) Repo(ctx context.Context, instanceID string) (drivers.RepoStore, func(), error) { diff --git a/runtime/drivers/duckdb/duckdb.go b/runtime/drivers/duckdb/duckdb.go index f2a3737613e..78f57909285 100644 --- a/runtime/drivers/duckdb/duckdb.go +++ b/runtime/drivers/duckdb/duckdb.go @@ -813,7 +813,7 @@ func (c *connection) periodicallyCheckConnDurations(d time.Duration) { c.connTimesMu.Lock() for connID, connTime := range c.connTimes { if time.Since(connTime) > maxAcquiredConnDuration { - c.logger.Error("duckdb: a connection has been held for more longer than the maximum allowed duration", zap.Int("conn_id", connID), zap.Duration("duration", time.Since(connTime))) + c.logger.Error("duckdb: a connection has been held for longer than the maximum allowed duration", zap.Int("conn_id", connID), zap.Duration("duration", time.Since(connTime))) } } c.connTimesMu.Unlock() diff --git a/runtime/pkg/conncache/conncache.go b/runtime/pkg/conncache/conncache.go new file mode 100644 index 00000000000..b746f46a163 --- /dev/null +++ b/runtime/pkg/conncache/conncache.go @@ -0,0 +1,450 @@ +package conncache + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/hashicorp/golang-lru/simplelru" + "go.opentelemetry.io/otel/metric" +) + +// Cache is a concurrency-safe cache of stateful connection objects. +// It differs from a connection pool in that it's designed for caching heterogenous connections. +// The cache will at most open one connection per key, even under concurrent access. +// The cache automatically evicts connections that are not in use ("acquired") using a least-recently-used policy. +type Cache interface { + // Acquire retrieves or opens a connection for the given config. 
The returned ReleaseFunc must be called when the connection is no longer needed. + // While a connection is acquired, it will not be closed unless EvictWhere or Close is called. + // If Acquire is called while the underlying connection is being evicted, it will wait for the close to complete and then open a new connection. + // If opening the connection fails, Acquire may return the error on subsequent calls without trying to open again until the entry is evicted. + Acquire(ctx context.Context, cfg any) (Connection, ReleaseFunc, error) + + // EvictWhere closes the connections that match the predicate. + // It immediately starts closing the connections, even those that are currently acquired. + // It returns quickly and does not wait for connections to finish closing in the background. + EvictWhere(predicate func(cfg any) bool) + + // Close closes all open connections and prevents new connections from being acquired. + // It returns when all cached connections have been closed or when the provided ctx is cancelled. + Close(ctx context.Context) error +} + +// Connection is a connection that may be cached. +type Connection interface { + Close() error +} + +// ReleaseFunc is a function that must be called when an acquired connection is no longer needed. +type ReleaseFunc func() + +// Options configures a new connection cache. +type Options struct { + // MaxIdleConnections is the maximum number of non-acquired connections that will be kept open. + MaxIdleConnections int + // OpenTimeout is the maximum amount of time to wait for a connection to open. + OpenTimeout time.Duration + // CloseTimeout is the maximum amount of time to wait for a connection to close. + CloseTimeout time.Duration + // CheckHangingInterval is the interval at which to check for hanging open/close calls. + CheckHangingInterval time.Duration + // OpenFunc opens a connection. + OpenFunc func(ctx context.Context, cfg any) (Connection, error) + // KeyFunc computes a comparable key for a connection config. + KeyFunc func(cfg any) string + // HangingFunc is called when an open or close exceeds its timeout and does not respond to context cancellation. + HangingFunc func(cfg any, open bool) + // Metrics are optional instruments for observability. + Metrics Metrics +} + +// Metrics are optional instruments for observability. If an instrument is nil, it will not be collected. +type Metrics struct { + Opens metric.Int64Counter + Closes metric.Int64Counter + SizeTotal metric.Int64UpDownCounter + SizeLRU metric.Int64UpDownCounter + OpenLatencyMS metric.Int64Histogram + CloseLatencyMS metric.Int64Histogram +} + +var _ Cache = (*cacheImpl)(nil) + +// cacheImpl implements Cache. Implementation notes: +// - It uses an LRU to pool unused connections and eventually close them. +// - It leverages a singleflight pattern to ensure at most one open/close action runs against a connection at a time. +// - It directly implements a singleflight (instead of using a library) because it needs to use the same mutex for the singleflight and the map/LRU to avoid race conditions. +// - An entry will only have entryStatusOpening or entryStatusClosing if a singleflight call is currently running for it. +// - Any code that keeps a reference to an entry after the mutex is released must call retainEntry/releaseEntry. +// - If the ctx for an open call is cancelled, the entry will continue opening in the background (and will be put in the LRU). 
+// - If attempting to open a closing entry, or close an opening entry, we wait for the singleflight to complete and then retry once. To avoid infinite loops, we don't retry more than once. +type cacheImpl struct { + opts Options + closed bool + mu sync.Mutex + entries map[string]*entry + lru *simplelru.LRU + singleflight map[string]chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +type entry struct { + cfg any + refs int + status entryStatus + since time.Time + closeAfterOpening bool + handle Connection + err error +} + +type entryStatus int + +const ( + entryStatusUnspecified entryStatus = iota + entryStatusOpening + entryStatusOpen // Also used for cases where open errored (i.e. entry.err != nil) + entryStatusClosing + entryStatusClosed +) + +func New(opts Options) Cache { + ctx, cancel := context.WithCancel(context.Background()) + c := &cacheImpl{ + opts: opts, + entries: make(map[string]*entry), + singleflight: make(map[string]chan struct{}), + ctx: ctx, + cancel: cancel, + } + + var err error + c.lru, err = simplelru.NewLRU(opts.MaxIdleConnections, c.lruEvictionHandler) + if err != nil { + panic(err) + } + + if opts.CheckHangingInterval != 0 { + go c.periodicallyCheckHangingConnections() + } + + return c +} + +func (c *cacheImpl) Acquire(ctx context.Context, cfg any) (Connection, ReleaseFunc, error) { + k := c.opts.KeyFunc(cfg) + + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, nil, errors.New("conncache: closed") + } + + e, ok := c.entries[k] + if !ok { + e = &entry{cfg: cfg, since: time.Now()} + c.entries[k] = e + if c.opts.Metrics.SizeTotal != nil { + c.opts.Metrics.SizeTotal.Add(c.ctx, 1) + } + } + + c.retainEntry(k, e) + + if e.status == entryStatusOpen { + defer c.mu.Unlock() + if e.err != nil { + c.releaseEntry(k, e) + return nil, nil, e.err + } + return e.handle, c.releaseFunc(k, e), nil + } + + ch, ok := c.singleflight[k] + + if ok && e.status == entryStatusClosing { + c.mu.Unlock() + select { + case <-ch: + case <-ctx.Done(): + c.mu.Lock() + c.releaseEntry(k, e) + c.mu.Unlock() + return nil, nil, ctx.Err() + } + c.mu.Lock() + + // Since we released the lock, need to check c.closed and e.status again. 
+ if c.closed { + c.releaseEntry(k, e) + c.mu.Unlock() + return nil, nil, errors.New("conncache: closed") + } + + if e.status == entryStatusOpen { + defer c.mu.Unlock() + if e.err != nil { + c.releaseEntry(k, e) + return nil, nil, e.err + } + return e.handle, c.releaseFunc(k, e), nil + } + + ch, ok = c.singleflight[k] + } + + if !ok { + c.retainEntry(k, e) // Retain again to count the goroutine's reference independently (in case ctx is cancelled while the Open continues in the background) + + ch = make(chan struct{}) + c.singleflight[k] = ch + + e.status = entryStatusOpening + e.since = time.Now() + e.handle = nil + e.err = nil + + go func() { + start := time.Now() + var handle Connection + var err error + if c.opts.OpenTimeout == 0 { + handle, err = c.opts.OpenFunc(c.ctx, cfg) + } else { + ctx, cancel := context.WithTimeout(c.ctx, c.opts.OpenTimeout) + handle, err = c.opts.OpenFunc(ctx, cfg) + cancel() + } + + if c.opts.Metrics.Opens != nil { + c.opts.Metrics.Opens.Add(c.ctx, 1) + } + if c.opts.Metrics.OpenLatencyMS != nil { + c.opts.Metrics.OpenLatencyMS.Record(c.ctx, time.Since(start).Milliseconds()) + } + + c.mu.Lock() + defer c.mu.Unlock() + + e.status = entryStatusOpen + e.since = time.Now() + e.handle = handle + e.err = err + + delete(c.singleflight, k) + close(ch) + + if e.closeAfterOpening { + e.closeAfterOpening = false + c.beginClose(k, e) + } + + c.releaseEntry(k, e) + }() + } + + c.mu.Unlock() + + select { + case <-ch: + case <-ctx.Done(): + c.mu.Lock() + c.releaseEntry(k, e) + c.mu.Unlock() + return nil, nil, ctx.Err() + } + + c.mu.Lock() + defer c.mu.Unlock() + + if e.status != entryStatusOpen { + c.releaseEntry(k, e) + return nil, nil, errors.New("conncache: connection was immediately closed after being opened") + } + + if e.err != nil { + c.releaseEntry(k, e) + return nil, nil, e.err + } + + return e.handle, c.releaseFunc(k, e), nil +} + +func (c *cacheImpl) EvictWhere(predicate func(cfg any) bool) { + c.mu.Lock() + defer c.mu.Unlock() + for k, e := range c.entries { + if predicate(e.cfg) { + c.beginClose(k, e) + } + } +} + +func (c *cacheImpl) Close(ctx context.Context) error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return errors.New("conncache: already closed") + } + c.closed = true + + c.cancel() + + for k, e := range c.entries { + c.beginClose(k, e) + } + + c.mu.Unlock() + + for { + c.mu.Lock() + var anyCh chan struct{} + for _, ch := range c.singleflight { + anyCh = ch + break + } + c.mu.Unlock() + + if anyCh == nil { + // all entries are closed, we can return + return nil + } + + select { + case <-anyCh: + // continue + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// beginClose must be called while c.mu is held. +func (c *cacheImpl) beginClose(k string, e *entry) { + if e.status == entryStatusClosing || e.status == entryStatusClosed { + return + } + + if e.status == entryStatusOpening { + e.closeAfterOpening = true + return + } + + c.retainEntry(k, e) + + ch, ok := c.singleflight[k] + if ok { + // Should never happen, but checking since it would be pretty bad. 
+ panic(errors.New("conncache: singleflight exists for entry that is neither opening nor closing")) + } + ch = make(chan struct{}) + c.singleflight[k] = ch + + e.status = entryStatusClosing + e.since = time.Now() + + go func() { + start := time.Now() + var err error + if e.handle != nil { + err = e.handle.Close() + } + if err == nil { + err = errors.New("conncache: connection closed") + } + + if c.opts.Metrics.Closes != nil { + c.opts.Metrics.Closes.Add(c.ctx, 1) + } + if c.opts.Metrics.CloseLatencyMS != nil { + c.opts.Metrics.CloseLatencyMS.Record(c.ctx, time.Since(start).Milliseconds()) + } + + c.mu.Lock() + defer c.mu.Unlock() + + e.status = entryStatusClosed + e.since = time.Now() + e.handle = nil + e.err = err + + delete(c.singleflight, k) + close(ch) + + c.releaseEntry(k, e) + }() +} + +func (c *cacheImpl) lruEvictionHandler(key, value any) { + k := key.(string) + e := value.(*entry) + + // The callback also gets called when removing from LRU during acquisition. + // We use conn.refs != 0 to signal that its being acquired and should not be closed. + if e.refs == 0 { + c.beginClose(k, e) + } +} + +func (c *cacheImpl) retainEntry(key string, e *entry) { + e.refs++ + if e.refs == 1 { + // NOTE: lru.Remove is safe even if it's not in the LRU (should only happen if the entry is acquired for the first time) + _ = c.lru.Remove(key) + if c.opts.Metrics.SizeLRU != nil { + c.opts.Metrics.SizeLRU.Add(c.ctx, -1) + } + } +} + +func (c *cacheImpl) releaseEntry(key string, e *entry) { + e.refs-- + if e.refs == 0 { + // If open, keep entry and put in LRU. Else remove entirely. + if e.status != entryStatusClosing && e.status != entryStatusClosed { + c.lru.Add(key, e) + if c.opts.Metrics.SizeLRU != nil { + c.opts.Metrics.SizeLRU.Add(c.ctx, 1) + } + } else { + delete(c.entries, key) + if c.opts.Metrics.SizeTotal != nil { + c.opts.Metrics.SizeTotal.Add(c.ctx, -1) + } + } + } +} + +func (c *cacheImpl) releaseFunc(key string, e *entry) ReleaseFunc { + return func() { + c.mu.Lock() + c.releaseEntry(key, e) + c.mu.Unlock() + } +} + +func (c *cacheImpl) periodicallyCheckHangingConnections() { + ticker := time.NewTicker(c.opts.CheckHangingInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.mu.Lock() + for k := range c.singleflight { + e := c.entries[k] + if c.opts.OpenTimeout != 0 && e.status == entryStatusOpening && time.Since(e.since) > c.opts.OpenTimeout { + c.opts.HangingFunc(e.cfg, true) + } + if c.opts.CloseTimeout != 0 && e.status == entryStatusClosing && time.Since(e.since) > c.opts.CloseTimeout { + c.opts.HangingFunc(e.cfg, false) + } + } + c.mu.Unlock() + case <-c.ctx.Done(): + return + } + } +} diff --git a/runtime/pkg/conncache/conncache_test.go b/runtime/pkg/conncache/conncache_test.go new file mode 100644 index 00000000000..8c818830907 --- /dev/null +++ b/runtime/pkg/conncache/conncache_test.go @@ -0,0 +1,268 @@ +package conncache + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type mockConn struct { + cfg string + closeDelay time.Duration + closeCalled atomic.Bool +} + +func (c *mockConn) Close() error { + c.closeCalled.Store(true) + time.Sleep(c.closeDelay) + return nil +} + +func TestBasic(t *testing.T) { + opens := atomic.Int64{} + + c := New(Options{ + MaxIdleConnections: 2, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + opens.Add(1) + return &mockConn{cfg: cfg.(string)}, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + }) + + // Get "foo" + 
m1, r1, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(1), opens.Load()) + + // Get "foo" again + m2, r2, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(1), opens.Load()) + + // Check that they're the same + require.Equal(t, m1, m2) + + // Release the "foo"s and get "foo" again, check it's the same + r1() + r2() + m3, r3, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(1), opens.Load()) + require.Equal(t, m1, m3) + r3() + + // Open and release two more conns, check "foo" is closed (since LRU size is 2) + for i := 0; i < 2; i++ { + _, r, err := c.Acquire(context.Background(), fmt.Sprintf("bar%d", i)) + require.NoError(t, err) + require.Equal(t, int64(1+i+1), opens.Load()) + r() + } + time.Sleep(time.Second) + require.Equal(t, true, m1.(*mockConn).closeCalled.Load()) + + // Close cache + require.NoError(t, c.Close(context.Background())) +} + +func TestConcurrentOpen(t *testing.T) { + opens := atomic.Int64{} + + c := New(Options{ + MaxIdleConnections: 2, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + opens.Add(1) + time.Sleep(time.Second) + return &mockConn{cfg: cfg.(string)}, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + }) + + var m1, m2 Connection + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + m, _, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + m1 = m + }() + go func() { + defer wg.Done() + m, _, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + m2 = m + }() + + wg.Wait() + require.NotNil(t, m1) + require.Equal(t, m1, m2) + require.Equal(t, int64(1), opens.Load()) + + // Close cache + require.NoError(t, c.Close(context.Background())) +} + +func TestOpenDuringClose(t *testing.T) { + opens := atomic.Int64{} + + c := New(Options{ + MaxIdleConnections: 2, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + opens.Add(1) + return &mockConn{ + cfg: cfg.(string), + closeDelay: time.Second, // Closes hang for 1s + }, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + }) + + // Create conn + m1, r1, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(1), opens.Load()) + r1() + + // Evict it so it starts closing + c.EvictWhere(func(cfg any) bool { return true }) + // closeCalled is set before mockConn.Close hangs, but it will take 1s to actually close + time.Sleep(100 * time.Millisecond) + require.True(t, m1.(*mockConn).closeCalled.Load()) + + // Open again, check it takes ~1s to do so + start := time.Now() + m2, r2, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Greater(t, time.Since(start), 500*time.Millisecond) + require.Equal(t, int64(2), opens.Load()) + require.NotEqual(t, m1, m2) + r2() + + // Close cache + require.NoError(t, c.Close(context.Background())) +} + +func TestCloseDuringOpen(t *testing.T) { + opens := atomic.Int64{} + m := &mockConn{cfg: "foo"} + + c := New(Options{ + MaxIdleConnections: 2, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + time.Sleep(time.Second) + opens.Add(1) + return m, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + }) + + // Start opening + go func() { + _, _, err := c.Acquire(context.Background(), "foo") + require.ErrorContains(t, err, "immediately closed") + require.Equal(t, int64(1), opens.Load()) + }() + + // Evict it so it 
starts closing + time.Sleep(100 * time.Millisecond) // Give it time to start opening + c.EvictWhere(func(cfg any) bool { return true }) + + // It will let the open finish before closing it, so will take ~1s + time.Sleep(2 * time.Second) + require.True(t, m.closeCalled.Load()) + + // Close cache + require.NoError(t, c.Close(context.Background())) +} + +func TestCloseInUse(t *testing.T) { + opens := atomic.Int64{} + + c := New(Options{ + MaxIdleConnections: 2, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + opens.Add(1) + return &mockConn{cfg: cfg.(string)}, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + }) + + // Open conn "foo" + m1, r1, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(1), opens.Load()) + + // Evict it, check it's closed even though still in use (r1 not called) + c.EvictWhere(func(cfg any) bool { return true }) + time.Sleep(time.Second) + require.Equal(t, true, m1.(*mockConn).closeCalled.Load()) + + // Open "foo" again, check it opens a new one + m2, r2, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.Equal(t, int64(2), opens.Load()) + require.NotEqual(t, m1, m2) + + // Check that releasing m1 doesn't fail (though it's been closed) + r1() + r2() +} + +func TestHanging(t *testing.T) { + hangingOpens := atomic.Int64{} + hangingCloses := atomic.Int64{} + + c := New(Options{ + MaxIdleConnections: 2, + OpenTimeout: 100 * time.Millisecond, + CloseTimeout: 100 * time.Millisecond, + CheckHangingInterval: 100 * time.Millisecond, + OpenFunc: func(ctx context.Context, cfg any) (Connection, error) { + time.Sleep(time.Second) + return &mockConn{ + cfg: cfg.(string), + closeDelay: time.Second, // Make closes hang for 1s + }, nil + }, + KeyFunc: func(cfg any) string { + return cfg.(string) + }, + HangingFunc: func(cfg any, open bool) { + if open { + hangingOpens.Add(1) + } else { + hangingCloses.Add(1) + } + }, + }) + + // Open conn "foo" + m1, r1, err := c.Acquire(context.Background(), "foo") + require.NoError(t, err) + require.GreaterOrEqual(t, hangingOpens.Load(), int64(1)) + r1() + + // Evict it, check it's closed even though still in use (r1 not called) + c.EvictWhere(func(cfg any) bool { return true }) + time.Sleep(time.Second) + require.Equal(t, true, m1.(*mockConn).closeCalled.Load()) + require.GreaterOrEqual(t, hangingCloses.Load(), int64(1)) +} diff --git a/runtime/registry.go b/runtime/registry.go index ba521929843..093a48e4c57 100644 --- a/runtime/registry.go +++ b/runtime/registry.go @@ -375,7 +375,7 @@ func (r *registryCache) restartController(iwc *instanceWithController) { // So we want to evict all open connections for that instance, but it's unsafe to do so while the controller is running. // So this is the only place where we can do it safely. 
if r.baseCtx.Err() == nil { - r.rt.connCache.evictAll(r.baseCtx, iwc.instance.ID) + r.rt.evictInstanceConnections(iwc.instance.ID) } r.mu.Lock() diff --git a/runtime/registry_test.go b/runtime/registry_test.go index 5f305499fad..746d424fd4a 100644 --- a/runtime/registry_test.go +++ b/runtime/registry_test.go @@ -331,7 +331,7 @@ func TestRuntime_EditInstance(t *testing.T) { } // Wait for controller restart - time.Sleep(500 * time.Millisecond) + time.Sleep(2 * time.Second) _, err = rt.Controller(ctx, inst.ID) require.NoError(t, err) @@ -424,8 +424,9 @@ func TestRuntime_DeleteInstance(t *testing.T) { require.Error(t, err) // verify older olap connection is closed and cache updated - require.False(t, rt.connCache.lru.Contains(inst.ID+"duckdb"+fmt.Sprintf("dsn:%s ", dbFile))) - require.False(t, rt.connCache.lru.Contains(inst.ID+"file"+fmt.Sprintf("dsn:%s ", repodsn))) + // require.False(t, rt.connCache.lru.Contains(inst.ID+"duckdb"+fmt.Sprintf("dsn:%s ", dbFile))) + // require.False(t, rt.connCache.lru.Contains(inst.ID+"file"+fmt.Sprintf("dsn:%s ", repodsn))) + time.Sleep(2 * time.Second) err = olap.Exec(context.Background(), &drivers.Statement{Query: "SELECT COUNT(*) FROM rill.migration_version"}) require.True(t, err != nil) @@ -474,7 +475,7 @@ func TestRuntime_DeleteInstance_DropCorrupted(t *testing.T) { require.NoError(t, err) // Close OLAP connection - rt.connCache.evictAll(ctx, inst.ID) + rt.evictInstanceConnections(inst.ID) // Corrupt database file err = os.WriteFile(dbpath, []byte("corrupted"), 0644) diff --git a/runtime/runtime.go b/runtime/runtime.go index 46601ac66c1..07e28bc1d5f 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -9,6 +9,7 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/pkg/activity" + "github.com/rilldata/rill/runtime/pkg/conncache" "github.com/rilldata/rill/runtime/pkg/email" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -36,7 +37,7 @@ type Runtime struct { activity activity.Client metastore drivers.Handle registryCache *registryCache - connCache *connectionCache + connCache conncache.Cache queryCache *queryCache securityEngine *securityEngine } @@ -55,7 +56,7 @@ func New(ctx context.Context, opts *Options, logger *zap.Logger, ac activity.Cli securityEngine: newSecurityEngine(opts.SecurityEngineCacheSize, logger), } - rt.connCache = newConnectionCache(opts.ConnectionCacheSize, logger, rt, ac) + rt.connCache = rt.newConnectionCache() store, _, err := rt.AcquireSystemHandle(ctx, opts.MetastoreConnector) if err != nil { @@ -88,7 +89,7 @@ func (r *Runtime) Close() error { defer cancel() err1 := r.registryCache.close(ctx) err2 := r.queryCache.close() - err3 := r.connCache.Close() // Also closes metastore // TODO: Propagate ctx cancellation + err3 := r.connCache.Close(ctx) // Also closes metastore // TODO: Propagate ctx cancellation return errors.Join(err1, err2, err3) }
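
A minimal usage sketch of the conncache package introduced by this patch, using only the API shown above (New, Options, Acquire, EvictWhere, Close). The fileConn type, the DSN-string config, and all option values are hypothetical stand-ins, not code from the patch.

	package main

	import (
		"context"
		"fmt"
		"time"

		"github.com/rilldata/rill/runtime/pkg/conncache"
	)

	// fileConn is a hypothetical connection type; it only needs to satisfy conncache.Connection.
	type fileConn struct{ dsn string }

	func (c *fileConn) Close() error { return nil }

	func main() {
		cache := conncache.New(conncache.Options{
			MaxIdleConnections:   10,
			OpenTimeout:          time.Minute,
			CloseTimeout:         time.Minute,
			CheckHangingInterval: 15 * time.Second,
			OpenFunc: func(ctx context.Context, cfg any) (conncache.Connection, error) {
				return &fileConn{dsn: cfg.(string)}, nil
			},
			KeyFunc: func(cfg any) string { return cfg.(string) },
			HangingFunc: func(cfg any, open bool) {
				fmt.Printf("hanging connection: cfg=%v open=%v\n", cfg, open)
			},
		})

		// Acquire opens (or reuses) the connection for this config; the release
		// callback must be called when the caller no longer needs the handle.
		conn, release, err := cache.Acquire(context.Background(), "file:/tmp/example.db")
		if err != nil {
			panic(err)
		}
		fmt.Println(conn.(*fileConn).dsn)
		release()

		// Evict everything matching a predicate, then close the cache.
		cache.EvictWhere(func(cfg any) bool { return true })
		_ = cache.Close(context.Background())
	}

Because Close now takes a context, shutdown can be bounded with context.WithTimeout; per the implementation above it returns ctx.Err() if connections are still closing in the background when the deadline passes.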
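
The refactored generateKey folds the instance ID and driver into the key alongside the sorted config map (the instance ID is empty for shared connections). A sketch of how this could be unit-tested from within package runtime; the concrete values and the test itself are illustrative, not part of the patch.

	package runtime

	import "testing"

	// Illustrative sketch: the cache key is instanceID + driver + sorted "key:value " pairs.
	func TestGenerateKeySketch(t *testing.T) {
		cfg := cachedConnectionConfig{
			instanceID: "default",
			driver:     "duckdb",
			config:     map[string]any{"pool_size": 4, "dsn": "main.db"},
		}
		got := generateKey(cfg)
		want := "defaultduckdbdsn:main.db pool_size:4 "
		if got != want {
			t.Fatalf("got %q, want %q", got, want)
		}
	}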
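
The implementation notes in conncache.go explain that the cache rolls its own singleflight so the in-flight call map can share the mutex that guards the entries and LRU, letting state checks and joining an in-flight call happen as one atomic step. A stripped-down, standalone sketch of that pattern with hypothetical names, not the patch's code:

	package main

	import (
		"fmt"
		"sync"
		"time"
	)

	type group struct {
		mu       sync.Mutex
		inflight map[string]chan struct{}
		results  map[string]string
	}

	func (g *group) do(key string, fn func() string) string {
		g.mu.Lock()
		if v, ok := g.results[key]; ok {
			g.mu.Unlock()
			return v
		}
		ch, ok := g.inflight[key]
		if !ok {
			// No call in flight for this key: register ourselves under the same lock.
			ch = make(chan struct{})
			g.inflight[key] = ch
			g.mu.Unlock()

			v := fn() // run the expensive call outside the lock

			g.mu.Lock()
			g.results[key] = v
			delete(g.inflight, key)
			close(ch) // wake up everyone waiting on this key
			g.mu.Unlock()
			return v
		}
		g.mu.Unlock()

		<-ch // join the in-flight call
		g.mu.Lock()
		v := g.results[key]
		g.mu.Unlock()
		return v
	}

	func main() {
		g := &group{inflight: map[string]chan struct{}{}, results: map[string]string{}}
		var wg sync.WaitGroup
		for i := 0; i < 3; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				fmt.Println(g.do("foo", func() string { time.Sleep(50 * time.Millisecond); return "opened" }))
			}()
		}
		wg.Wait()
	}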