diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index e96cd551..dd632d5b 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -16,7 +16,7 @@ on: jobs: test: name: Test - runs-on: ubuntu-latest + runs-on: ubuntu-latest-8cpu env: GOEXPERIMENT: loopvar @@ -48,7 +48,7 @@ jobs: run: go vet - name: Run tests - run: go test ./... + run: go test -race -v ./... env: # Environment variables so that AWS resources can be created AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} diff --git a/go.mod b/go.mod index ffb54473..19f28a82 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 // Direct dependencies require ( + github.com/MrAlias/otel-schema-utils v0.2.1-alpha github.com/aws/aws-sdk-go-v2 v1.21.0 github.com/aws/aws-sdk-go-v2/config v1.18.41 github.com/aws/aws-sdk-go-v2/credentials v1.13.39 @@ -28,27 +29,29 @@ require ( github.com/iancoleman/strcase v0.2.0 github.com/nats-io/jwt/v2 v2.5.2 github.com/nats-io/nkeys v0.4.5 - github.com/overmindtech/discovery v0.24.0 + github.com/overmindtech/discovery v0.25.0 github.com/overmindtech/sdp-go v0.49.5 + github.com/overmindtech/sdpcache v1.6.0 github.com/sirupsen/logrus v1.9.3 github.com/sourcegraph/conc v0.3.0 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.16.0 - go.opentelemetry.io/contrib/detectors/aws/ec2 v1.16.1 - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 - go.opentelemetry.io/otel v1.16.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.16.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.16.0 - go.opentelemetry.io/otel/sdk v1.16.0 - go.opentelemetry.io/otel/trace v1.16.0 + go.opentelemetry.io/contrib/detectors/aws/ec2 v1.19.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 + go.opentelemetry.io/otel v1.18.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.18.0 + go.opentelemetry.io/otel/sdk v1.18.0 + go.opentelemetry.io/otel/trace v1.18.0 google.golang.org/protobuf v1.31.0 ) // Transitive dependencies require ( + github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/auth0/go-jwt-middleware/v2 v2.1.0 // indirect - github.com/aws/aws-sdk-go v1.44.285 // indirect + github.com/aws/aws-sdk-go v1.45.7 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect @@ -77,21 +80,20 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/klauspost/compress v1.16.7 // indirect + github.com/klauspost/compress v1.17.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/nats-io/nats.go v1.29.0 // indirect + github.com/nats-io/nats.go v1.30.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/overmindtech/api-client v0.14.0 // indirect - github.com/overmindtech/sdpcache v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/spf13/afero v1.9.5 // indirect github.com/spf13/cast v1.5.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect - go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.16.0 // indirect - go.opentelemetry.io/otel/metric v1.16.0 // indirect + go.opentelemetry.io/otel/metric v1.18.0 // indirect + go.opentelemetry.io/otel/schema v0.0.5 // indirect go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.13.0 // indirect @@ -100,11 +102,12 @@ require ( golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/grpc v1.57.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/grpc v1.58.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bf9689eb..a9a619a9 100644 --- a/go.sum +++ b/go.sum @@ -38,10 +38,14 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/MrAlias/otel-schema-utils v0.2.1-alpha h1:dSeMM04tO+EY1JLof0bL8rIDkaMV3yBiUFdcSeIfbqI= +github.com/MrAlias/otel-schema-utils v0.2.1-alpha/go.mod h1:i5gQR7dVLC4XxJuPITWxpWnjGRICZY50OMXXtrQTDeQ= github.com/auth0/go-jwt-middleware/v2 v2.1.0 h1:VU4LsC3aFPoqXVyEp8EixU6FNM+ZNIjECszRTvtGQI8= github.com/auth0/go-jwt-middleware/v2 v2.1.0/go.mod h1:CpzcJoleayAACpv+vt0AP8/aYn5TDngsqzLapV1nM4c= -github.com/aws/aws-sdk-go v1.44.285 h1:rgoWYl+NdmKzRgoi/fZLEtGXOjCkcWIa5jPH02Uahdo= -github.com/aws/aws-sdk-go v1.44.285/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/aws/aws-sdk-go v1.45.7 h1:k4QsvWZhm8409TYeRuTV1P6+j3lLKoe+giFA/j3VAps= +github.com/aws/aws-sdk-go v1.45.7/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go-v2 v1.21.0 h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13 h1:OPLEkmhXf6xFPiz0bLeDArZIDx1NNS4oJyG4nv3Gct0= @@ -238,8 +242,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -256,22 +260,22 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/nats-io/jwt/v2 v2.5.2 h1:DhGH+nKt+wIkDxM6qnVSKjokq5t59AZV5HRcFW0zJwU= github.com/nats-io/jwt/v2 v2.5.2/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= -github.com/nats-io/nats-server/v2 v2.9.22 h1:rzl88pqWFFrU4G00ed+JnY+uGHSLZ+3jrxDnJxzKwGA= -github.com/nats-io/nats-server/v2 v2.9.22/go.mod h1:wEjrEy9vnqIGE4Pqz4/c75v9Pmaq7My2IgFmnykc4C0= -github.com/nats-io/nats.go v1.29.0 h1:dSXZ+SZeGyTdHVYeXimeq12FsIpb9dM8CJ2IZFiHcyE= -github.com/nats-io/nats.go v1.29.0/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc= +github.com/nats-io/nats-server/v2 v2.10.1 h1:MIJ614dhOIdo71iSzY8ln78miXwrYvlvXHUyS+XdKZQ= +github.com/nats-io/nats-server/v2 v2.10.1/go.mod h1:3PMvMSu2cuK0J9YInRLWdFpFsswKKGUS77zVSAudRto= +github.com/nats-io/nats.go v1.30.0 h1:bj/rVsRCrFXxmm9mJiDhb74UKl2HhKpDwKRBtvCjZjc= +github.com/nats-io/nats.go v1.30.0/go.mod h1:dcfhUgmQNN4GJEfIb2f9R7Fow+gzBF4emzDHrVBd5qM= github.com/nats-io/nkeys v0.4.5 h1:Zdz2BUlFm4fJlierwvGK+yl20IAKUm7eV6AAZXEhkPk= github.com/nats-io/nkeys v0.4.5/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/overmindtech/api-client v0.14.0 h1:zXyjJsIeawNqoWv7FqOjwcqgFpLrDYz7l9MWqh1G9ZQ= github.com/overmindtech/api-client v0.14.0/go.mod h1:msdkTAQFlvDGOU4tQk2adk2P8j23uaMWkJ9YRX4wGWI= -github.com/overmindtech/discovery v0.24.0 h1:TlALbpl76gcR2eG5NQK+eH+7x7pm0zBNRF+BaK5akFE= -github.com/overmindtech/discovery v0.24.0/go.mod h1:SV0lAeunspCtJ/HIFRTo66Zl7/hWqZruIcRO0oLEW/0= +github.com/overmindtech/discovery v0.25.0 h1:IjGGeyph4rnBpajl8wzJ6BtsdxwKthluvoUf8PgS39o= +github.com/overmindtech/discovery v0.25.0/go.mod h1:ZZQvzWiHe09ySzt0d5Gi4u7Zc/ljmSdeedLEKk3ZkIo= github.com/overmindtech/sdp-go v0.49.5 h1:mvmUvnSM6q3PtUaLb+vJxLbq60uQiXGPB2caPaRBQMk= github.com/overmindtech/sdp-go v0.49.5/go.mod h1:q2RBDqmPidIQsYa9g/6nqOeJAyM6j3zgxn94GPCJcF8= -github.com/overmindtech/sdpcache v1.5.0 h1:QzHWQm1KGN9rNHPb/VZvz7WDCsyKOuVLlNUGF2CIFGc= -github.com/overmindtech/sdpcache v1.5.0/go.mod h1:GFMMle860EWMDQXbk6dhLVSQrV0YlEqqJ6/VNxINb0o= +github.com/overmindtech/sdpcache v1.6.0 h1:cIaLXULSltDzvf6Me91h70/ZnypiRbROmrJ+ovAGGpo= +github.com/overmindtech/sdpcache v1.6.0/go.mod h1:SHrj4t9f0x7V3WPYAEf31nukuLSIqg1YeqrSdMwLw8c= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= @@ -283,8 +287,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= @@ -312,8 +316,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -327,24 +332,26 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.opentelemetry.io/contrib/detectors/aws/ec2 v1.16.1 h1:2ZlxACYfrRHXzXRwB1mEcf6iMqhNpsm6Hzk7dNTxqCA= -go.opentelemetry.io/contrib/detectors/aws/ec2 v1.16.1/go.mod h1:+sgB/aIQoZOnLSQDC0yquT1YjBrd67MuGr7gknvRMxE= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8= -go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= -go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= -go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.16.0 h1:t4ZwRPU+emrcvM2e9DHd0Fsf0JTPVcbfa/BhTDF03d0= -go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.16.0/go.mod h1:vLarbg68dH2Wa77g71zmKQqlQ8+8Rq3GRG31uc0WcWI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.16.0 h1:cbsD4cUcviQGXdw8+bo5x2wazq10SKz8hEbtCRPcU78= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.16.0/go.mod h1:JgXSGah17croqhJfhByOLVY719k1emAXC8MVhCIJlRs= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.16.0 h1:iqjq9LAB8aK++sKVcELezzn655JnBNdsDhghU4G/So8= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.16.0/go.mod h1:hGXzO5bhhSHZnKvrDaXB82Y9DRFour0Nz/KrBh7reWw= -go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo= -go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4= -go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE= -go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= -go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= -go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= +go.opentelemetry.io/contrib/detectors/aws/ec2 v1.19.0 h1:Dh3v5W0qYTDQrZlygcQDC/Fa7prK1uQSHG46ZYMM2MQ= +go.opentelemetry.io/contrib/detectors/aws/ec2 v1.19.0/go.mod h1:+mpf+EyLaP+xNXFAyRjWguw5yRo0T6oSn6fYqjYAOa4= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 h1:KfYpVmrjI7JuToy5k8XV3nkapjWx48k4E4JOtVstzQI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0/go.mod h1:SeQhzAEccGVZVEy7aH87Nh0km+utSpo1pTv6eMMop48= +go.opentelemetry.io/otel v1.18.0 h1:TgVozPGZ01nHyDZxK5WGPFB9QexeTMXEH7+tIClWfzs= +go.opentelemetry.io/otel v1.18.0/go.mod h1:9lWqYO0Db579XzVuCKFNPDl4s73Voa+zEck3wHaAYQI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 h1:IAtl+7gua134xcV3NieDhJHjjOVeJhXAnYf/0hswjUY= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0/go.mod h1:w+pXobnBzh95MNIkeIuAKcHe/Uu/CX2PKIvBP6ipKRA= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.18.0 h1:6pu8ttx76BxHf+xz/H77AUZkPF3cwWzXqAUsXhVKI18= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.18.0/go.mod h1:IOmXxPrxoxFMXdNy7lfDmE8MzE61YPcurbUm0SMjerI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0 h1:hSWWvDjXHVLq9DkmB+77fl8v7+t+yYiS+eNkiplDK54= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.18.0/go.mod h1:zG7KQql1WjZCaUJd+L/ReSYx4bjbYJxg5ws9ws+mYes= +go.opentelemetry.io/otel/metric v1.18.0 h1:JwVzw94UYmbx3ej++CwLUQZxEODDj/pOuTCvzhtRrSQ= +go.opentelemetry.io/otel/metric v1.18.0/go.mod h1:nNSpsVDjWGfb7chbRLUNW+PBNdcSTHD4Uu5pfFMOI0k= +go.opentelemetry.io/otel/schema v0.0.5 h1:1Mfux3lhQR96w+sftg63sjRwXnThUrvAaJ0NsVIbXW4= +go.opentelemetry.io/otel/schema v0.0.5/go.mod h1:4tYaqZ/pYJE+z0U0Z6KpHMfol5bg17/Dn6HVXlijKVo= +go.opentelemetry.io/otel/sdk v1.18.0 h1:e3bAB0wB3MljH38sHzpV/qWrOTCFrdZF2ct9F8rBkcY= +go.opentelemetry.io/otel/sdk v1.18.0/go.mod h1:1RCygWV7plY2KmdskZEDDBs4tJeHG92MdHZIluiYs/M= +go.opentelemetry.io/otel/trace v1.18.0 h1:NY+czwbHbmndxojTEKiSMHkG2ClNH2PwmcHrdo0JY10= +go.opentelemetry.io/otel/trace v1.18.0/go.mod h1:T2+SGJGuYZY3bjj5rgh/hN7KIrlpWC5nS8Mjvzckz+0= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -628,12 +635,12 @@ google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e h1:Ao9GzfUMPH3zjVfzXG5rlWlk+Q8MXWKwWpwVQE1MXfw= -google.golang.org/genproto v0.0.0-20230526203410-71b5a4ffd15e/go.mod h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= +google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98 h1:Z0hjGZePRE0ZBWotvtrwxFNrNE9CUAGtplaDK5NNI/g= +google.golang.org/genproto v0.0.0-20230711160842-782d3b101e98/go.mod h1:S7mY02OqCJTD0E1OiQy1F72PWFB4bZJ87cAtLPYgDR0= +google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98 h1:FmF5cCW94Ij59cfpoLiwTgodWmm60eEV0CjlsVg2fuw= +google.golang.org/genproto/googleapis/api v0.0.0-20230711160842-782d3b101e98/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -650,8 +657,8 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw= -google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= +google.golang.org/grpc v1.58.0 h1:32JY8YpPMSR45K+c3o6b8VL73V+rR8k+DeMIr4vRH8o= +google.golang.org/grpc v1.58.0/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/sources/always_get_source.go b/sources/always_get_source.go index 0c959727..c0643cec 100644 --- a/sources/always_get_source.go +++ b/sources/always_get_source.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "sync" + "time" "github.com/getsentry/sentry-go" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" log "github.com/sirupsen/logrus" ) @@ -67,6 +69,33 @@ type AlwaysGetSource[ListInput InputType, ListOutput OutputType, GetInput InputT // of inputs to pass to the GetFunc. The input used for the ListFunc is also // included in case it is required ListFuncOutputMapper func(output ListOutput, input ListInput) ([]GetInput, error) + + CacheDuration time.Duration // How long to cache items for + cache *sdpcache.Cache // The sdpcache of this source + cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once +} + +// DefaultCacheDuration Returns the default cache duration for this source +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) DefaultCacheDuration() time.Duration { + if s.CacheDuration == 0 { + return 10 * time.Minute + } + + return s.CacheDuration +} + +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) ensureCache() { + s.cacheInitMu.Lock() + defer s.cacheInitMu.Unlock() + + if s.cache == nil { + s.cache = sdpcache.NewCache() + } +} + +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Cache() *sdpcache.Cache { + s.ensureCache() + return s.cache } // Validate Checks that the source has been set up correctly @@ -106,7 +135,7 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc } } -func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Get(ctx context.Context, scope string, query string) (*sdp.Item, error) { +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Get(ctx context.Context, scope string, query string, ignoreCache bool) (*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -121,21 +150,37 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc return nil, WrapAWSError(err) } + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) > 0 { + return cachedItems[0], nil + } else { + return nil, nil + } + } + input := s.GetInputMapper(scope, query) item, err = s.GetFunc(ctx, s.Client, scope, input) if err != nil { // TODO: How can we handle NOTFOUND? - return nil, WrapAWSError(err) + qErr := WrapAWSError(err) + s.cache.StoreError(qErr, s.CacheDuration, ck) + return nil, qErr } + s.cache.StoreItem(item, s.CacheDuration, ck) return item, nil } // List Lists all available items. This is done by running the ListFunc, then // passing these results to GetFunc in order to get the details -func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) List(ctx context.Context, scope string) ([]*sdp.Item, error) { +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -149,7 +194,27 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc return []*sdp.Item{}, nil } - return s.listInternal(ctx, scope, s.ListInput) + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + return cachedItems, nil + } + + items, err := s.listInternal(ctx, scope, s.ListInput) + if err != nil { + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err + } + + for _, item := range items { + s.cache.StoreItem(item, s.CacheDuration, ck) + } + + return items, nil } // listInternal Accepts a ListInput and runs the List logic against it @@ -248,7 +313,7 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc } // Search Searches for AWS resources by ARN -func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Search(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -256,18 +321,37 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc } } + ck := sdpcache.CacheKeyFromParts(s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query) + + var items []*sdp.Item + var err error + if s.SearchInputMapper == nil { - return s.SearchARN(ctx, scope, query) + items, err = s.SearchARN(ctx, scope, query, ignoreCache) } else { // If we should always look for ARNs first, do that if s.AlwaysSearchARNs { - if _, err := ParseARN(query); err == nil { - return s.SearchARN(ctx, scope, query) + if _, err = ParseARN(query); err == nil { + items, err = s.SearchARN(ctx, scope, query, ignoreCache) + } else { + items, err = s.SearchCustom(ctx, scope, query) } + } else { + items, err = s.SearchCustom(ctx, scope, query) } + } + + if err != nil { + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err + } - return s.SearchCustom(ctx, scope, query) + for _, item := range items { + s.cache.StoreItem(item, s.CacheDuration, ck) } + + return items, nil } // SearchCustom Searches using custom mapping logic. The SearchInputMapper is @@ -279,10 +363,23 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc return nil, WrapAWSError(err) } - return s.listInternal(ctx, scope, input) + items, err := s.listInternal(ctx, scope, input) + + ck := sdpcache.CacheKeyFromParts(s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query) + + if err != nil { + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err + } + + for _, item := range items { + s.cache.StoreItem(item, s.CacheDuration, ck) + } + return items, nil } -func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { // Parse the ARN a, err := ParseARN(query) @@ -298,8 +395,7 @@ func (s *AlwaysGetSource[ListInput, ListOutput, GetInput, GetOutput, ClientStruc } } - item, err := s.Get(ctx, scope, a.ResourceID()) - + item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) if err != nil { return nil, WrapAWSError(err) } diff --git a/sources/always_get_source_test.go b/sources/always_get_source_test.go index 16ac3094..5925ec2d 100644 --- a/sources/always_get_source_test.go +++ b/sources/always_get_source_test.go @@ -3,9 +3,11 @@ package sources import ( "context" "errors" + "fmt" "testing" "github.com/overmindtech/sdp-go" + "google.golang.org/protobuf/types/known/structpb" ) func TestMaxParallel(t *testing.T) { @@ -57,7 +59,9 @@ func TestAlwaysGetSourceGet(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -71,7 +75,7 @@ func TestAlwaysGetSourceGet(t *testing.T) { }, } - _, err := lgs.Get(context.Background(), "foo.bar", "") + _, err := lgs.Get(context.Background(), "foo.bar", "", false) if err != nil { t.Error(err) @@ -87,7 +91,9 @@ func TestAlwaysGetSourceGet(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -101,7 +107,7 @@ func TestAlwaysGetSourceGet(t *testing.T) { }, } - _, err := lgs.Get(context.Background(), "foo.bar", "") + _, err := lgs.Get(context.Background(), "foo.bar", "", false) if err == nil { t.Error("expected error") @@ -120,7 +126,9 @@ func TestAlwaysGetSourceList(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -134,7 +142,7 @@ func TestAlwaysGetSourceList(t *testing.T) { }, } - items, err := lgs.List(context.Background(), "foo.bar") + items, err := lgs.List(context.Background(), "foo.bar", false) if err != nil { t.Error(err) @@ -155,7 +163,9 @@ func TestAlwaysGetSourceList(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -169,14 +179,19 @@ func TestAlwaysGetSourceList(t *testing.T) { }, } - _, err := lgs.List(context.Background(), "foo.bar") + _, err := lgs.List(context.Background(), "foo.bar", false) if err == nil { t.Fatal("expected error but got nil") } - if err.Error() != "output mapper error" { - t.Errorf("expected output mapper error, got %v", err.Error()) + qErr := &sdp.QueryError{} + if !errors.As(err, &qErr) { + t.Errorf("expected error to be a QueryError, got %v", err) + } else { + if qErr.ErrorString != "output mapper error" { + t.Errorf("expected 'output mapper error', got '%v'", qErr.ErrorString) + } } }) @@ -190,7 +205,9 @@ func TestAlwaysGetSourceList(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -204,7 +221,7 @@ func TestAlwaysGetSourceList(t *testing.T) { }, } - items, err := lgs.List(context.Background(), "foo.bar") + items, err := lgs.List(context.Background(), "foo.bar", false) // If GetFunc fails it doesn't cause an error if err != nil { @@ -228,7 +245,9 @@ func TestAlwaysGetSourceSearch(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -247,7 +266,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { } t.Run("bad ARN", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "query") + _, err := lgs.Search(context.Background(), "foo.bar", "query", false) if err == nil { t.Error("expected error because the ARN was bad") @@ -255,7 +274,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }) t.Run("good ARN but bad scope", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:region:account:type/id") + _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:region:account:type/id", false) if err == nil { t.Error("expected error because the ARN had a bad scope") @@ -263,7 +282,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }) t.Run("good ARN", func(t *testing.T) { - _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id") + _, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false) if err != nil { t.Error(err) @@ -285,7 +304,9 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }, ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -304,7 +325,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { } t.Run("ARN", func(t *testing.T) { - items, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id") + items, err := lgs.Search(context.Background(), "foo.bar", "arn:aws:service:bar:foo:type/id", false) if err != nil { t.Error(err) @@ -316,7 +337,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }) t.Run("other search", func(t *testing.T) { - items, err := lgs.Search(context.Background(), "foo.bar", "id") + items, err := lgs.Search(context.Background(), "foo.bar", "id", false) if err != nil { t.Error(err) @@ -339,7 +360,9 @@ func TestAlwaysGetSourceSearch(t *testing.T) { ListInput: "", ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { // Returns 3 pages - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, ListFuncOutputMapper: func(output, input string) ([]string, error) { // Returns 2 gets per page @@ -357,7 +380,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { }, } - _, err := lgs.Search(context.Background(), "foo.bar", "bar") + _, err := lgs.Search(context.Background(), "foo.bar", "bar", false) if err != nil { t.Error(err) @@ -368,3 +391,163 @@ func TestAlwaysGetSourceSearch(t *testing.T) { } }) } + +func TestAlwaysGetSourceCaching(t *testing.T) { + ctx := context.Background() + generation := 0 + s := AlwaysGetSource[string, string, string, string, struct{}, struct{}]{ + ItemType: "test", + AccountID: "foo", + Region: "eu-west-2", + Client: struct{}{}, + ListInput: "", + ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { + return &TestPaginator{ + DataFunc: func() string { + generation += 1 + return fmt.Sprintf("%v", generation) + }, + MaxPages: 1, + } + }, + ListFuncOutputMapper: func(output, input string) ([]string, error) { + // Returns only 1 get per page to avoid confusing the cache with duplicate items + return []string{""}, nil + }, + GetFunc: func(ctx context.Context, client struct{}, scope, input string) (*sdp.Item, error) { + generation += 1 + return &sdp.Item{Scope: "foo.eu-west-2", + Type: "test-type", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{ + AttrStruct: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "name": structpb.NewStringValue("test-item"), + "generation": structpb.NewStringValue(fmt.Sprintf("%v%v", input, generation)), + }, + }, + }}, nil + }, + GetInputMapper: func(scope, query string) string { + return "" + }, + } + + t.Run("get", func(t *testing.T) { + // get + first, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // get again + withCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // get ignore cache + withoutCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("list", func(t *testing.T) { + // list + first, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // list again + withCache, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // list ignore cache + withoutCache, err := s.List(ctx, "foo.eu-west-2", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("search", func(t *testing.T) { + // search + first, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // search again + withCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // search ignore cache + withoutCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) +} diff --git a/sources/describe_source.go b/sources/describe_source.go index 6ac54854..7f7c559d 100644 --- a/sources/describe_source.go +++ b/sources/describe_source.go @@ -5,9 +5,12 @@ import ( "errors" "fmt" "strings" + "sync" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" ) // DescribeOnlySource Generates a source for AWS APIs that only use a `Describe` @@ -18,6 +21,10 @@ type DescribeOnlySource[Input InputType, Output OutputType, ClientStruct ClientS MaxResultsPerPage int32 // Max results per page when making API queries ItemType string // The type of items that will be returned + CacheDuration time.Duration // How long to cache items for + cache *sdpcache.Cache // The sdpcache of this source + cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once + // The function that should be used to describe the resources that this // source is related to DescribeFunc func(ctx context.Context, client ClientStruct, input Input) (Output, error) @@ -56,6 +63,29 @@ type DescribeOnlySource[Input InputType, Output OutputType, ClientStruct ClientS Client ClientStruct } +// DefaultCacheDuration Returns the default cache duration for this source +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) DefaultCacheDuration() time.Duration { + if s.CacheDuration == 0 { + return 10 * time.Minute + } + + return s.CacheDuration +} + +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) ensureCache() { + s.cacheInitMu.Lock() + defer s.cacheInitMu.Unlock() + + if s.cache == nil { + s.cache = sdpcache.NewCache() + } +} + +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Cache() *sdpcache.Cache { + s.ensureCache() + return s.cache +} + // Validate Checks that the source is correctly set up and returns an error if // not func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Validate() error { @@ -104,7 +134,7 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Scopes() []st // ctx parameter contains a golang context object which should be used to allow // this source to timeout or be cancelled when executing potentially // long-running actions -func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx context.Context, scope string, query string) (*sdp.Item, error) { +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx context.Context, scope string, query string, ignoreCache bool) (*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -118,29 +148,44 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx conte var items []*sdp.Item err = s.Validate() - if err != nil { return nil, WrapAWSError(err) } + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) > 0 { + return cachedItems[0], nil + } else { + return nil, nil + } + } + // Get the input object input, err = s.InputMapperGet(scope, query) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err } // Call the API using the object output, err = s.DescribeFunc(ctx, s.Client, input) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err } items, err = s.OutputMapper(scope, input, output) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err } numItems := len(items) @@ -154,22 +199,28 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Get(ctx conte itemNames[i] = items[i].GloballyUniqueName() } - return nil, &sdp.QueryError{ + qErr := &sdp.QueryError{ ErrorType: sdp.QueryError_OTHER, ErrorString: fmt.Sprintf("Request returned > 1 item for a GET request. Items: %v", strings.Join(itemNames, ", ")), } + s.cache.StoreError(qErr, s.CacheDuration, ck) + + return nil, qErr case numItems == 0: - return nil, &sdp.QueryError{ + qErr := &sdp.QueryError{ ErrorType: sdp.QueryError_NOTFOUND, ErrorString: fmt.Sprintf("%v %v not found", s.Type(), query), } + s.cache.StoreError(qErr, s.CacheDuration, ck) + return nil, qErr } + s.cache.StoreItem(items[0], s.CacheDuration, ck) return items[0], nil } // List Lists all items in a given scope -func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) List(ctx context.Context, scope string) ([]*sdp.Item, error) { +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -185,30 +236,44 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) List(ctx cont } err := s.Validate() - if err != nil { return nil, WrapAWSError(err) } + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + return cachedItems, nil + } + var items []*sdp.Item input, err := s.InputMapperList(scope) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err } items, err = s.describe(ctx, input, scope) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err + } + + for _, item := range items { + s.cache.StoreItem(item, s.CacheDuration, ck) } return items, nil } // Search Searches for AWS resources by ARN -func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Search(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -216,14 +281,16 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) Search(ctx co } } + ck := sdpcache.CacheKeyFromParts(s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query) + if s.InputMapperSearch == nil { - return s.searchARN(ctx, scope, query) + return s.searchARN(ctx, scope, query, ignoreCache) } else { - return s.searchCustom(ctx, scope, query) + return s.searchCustom(ctx, scope, query, ck) } } -func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchARN(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchARN(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { // Parse the ARN a, err := ParseARN(query) @@ -239,8 +306,8 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchARN(ctx } } - item, err := s.Get(ctx, scope, a.ResourceID()) - + // this already uses the cache, so needs no extra handling + item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) if err != nil { return nil, WrapAWSError(err) } @@ -249,17 +316,21 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchARN(ctx } // searchCustom Runs custom search logic using the `InputMapperSearch` function -func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchCustom(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) searchCustom(ctx context.Context, scope string, query string, ck sdpcache.CacheKey) ([]*sdp.Item, error) { input, err := s.InputMapperSearch(ctx, s.Client, scope, query) - if err != nil { return nil, WrapAWSError(err) } items, err := s.describe(ctx, input, scope) - if err != nil { - return nil, WrapAWSError(err) + err = WrapAWSError(err) + s.cache.StoreError(err, s.CacheDuration, ck) + return nil, err + } + + for _, item := range items { + s.cache.StoreItem(item, s.CacheDuration, ck) } return items, nil @@ -279,13 +350,11 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) describe(ctx for paginator.HasMorePages() { output, err = paginator.NextPage(ctx) - if err != nil { return nil, err } newItems, err = s.OutputMapper(scope, input, output) - if err != nil { return nil, err } @@ -294,13 +363,11 @@ func (s *DescribeOnlySource[Input, Output, ClientStruct, Options]) describe(ctx } } else { output, err = s.DescribeFunc(ctx, s.Client, input) - if err != nil { return nil, err } items, err = s.OutputMapper(scope, input, output) - if err != nil { return nil, err } diff --git a/sources/describe_source_test.go b/sources/describe_source_test.go index 3f8aa450..5a8a9a47 100644 --- a/sources/describe_source_test.go +++ b/sources/describe_source_test.go @@ -3,14 +3,22 @@ package sources import ( "context" "errors" + "fmt" + "os" "regexp" "testing" "github.com/aws/aws-sdk-go-v2/aws" "github.com/overmindtech/sdp-go" + "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/structpb" ) +func TestMain(m *testing.M) { + logrus.SetLevel(logrus.TraceLevel) + os.Exit(m.Run()) +} + func TestType(t *testing.T) { s := DescribeOnlySource[string, string, struct{}, struct{}]{ ItemType: "foo", @@ -81,7 +89,7 @@ func TestGet(t *testing.T) { }, } - item, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + item, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err != nil { t.Error(err) @@ -128,7 +136,7 @@ func TestGet(t *testing.T) { }, } - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error") @@ -155,7 +163,7 @@ func TestGet(t *testing.T) { }, } - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error") @@ -185,7 +193,7 @@ func TestSearchARN(t *testing.T) { }, } - items, err := s.Search(context.Background(), "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id") + items, err := s.Search(context.Background(), "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id", false) if err != nil { t.Error(err) @@ -231,7 +239,7 @@ func TestSearchCustom(t *testing.T) { }, } - items, err := s.Search(context.Background(), "account-id.region", "foo") + items, err := s.Search(context.Background(), "account-id.region", "foo", false) if err != nil { t.Fatal(err) @@ -263,7 +271,7 @@ func TestNoInputMapper(t *testing.T) { } t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -271,7 +279,7 @@ func TestNoInputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -297,7 +305,7 @@ func TestNoOutputMapper(t *testing.T) { } t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -305,7 +313,7 @@ func TestNoOutputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -333,7 +341,7 @@ func TestNoDescribeFunc(t *testing.T) { } t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -341,7 +349,7 @@ func TestNoDescribeFunc(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -374,7 +382,7 @@ func TestFailingInputMapper(t *testing.T) { fooBar := regexp.MustCompile("foobar") t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -386,7 +394,7 @@ func TestFailingInputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -421,7 +429,7 @@ func TestFailingOutputMapper(t *testing.T) { fooBar := regexp.MustCompile("foobar") t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -433,7 +441,7 @@ func TestFailingOutputMapper(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -470,7 +478,7 @@ func TestFailingDescribeFunc(t *testing.T) { fooBar := regexp.MustCompile("foobar") t.Run("Get", func(t *testing.T) { - _, err := s.Get(context.Background(), "foo.eu-west-2", "bar") + _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) if err == nil { t.Error("expected error but got nil") @@ -482,7 +490,7 @@ func TestFailingDescribeFunc(t *testing.T) { }) t.Run("List", func(t *testing.T) { - _, err := s.List(context.Background(), "foo.eu-west-2") + _, err := s.List(context.Background(), "foo.eu-west-2", false) if err == nil { t.Error("expected error but got nil") @@ -495,17 +503,24 @@ func TestFailingDescribeFunc(t *testing.T) { } type TestPaginator struct { + DataFunc func() string + + MaxPages int + page int } func (t *TestPaginator) HasMorePages() bool { - return t.page < 3 + if t.MaxPages == 0 { + t.MaxPages = 3 + } + return t.page < t.MaxPages } func (t *TestPaginator) NextPage(context.Context, ...func(struct{})) (string, error) { + data := t.DataFunc() t.page++ - - return "", nil + return data, nil } func TestPaginated(t *testing.T) { @@ -527,7 +542,9 @@ func TestPaginated(t *testing.T) { }, nil }, PaginatorBuilder: func(client struct{}, params string) Paginator[string, struct{}] { - return &TestPaginator{} + return &TestPaginator{DataFunc: func() string { + return "foo" + }} }, DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil @@ -544,8 +561,8 @@ func TestPaginated(t *testing.T) { } }) - t.Run("paginsting a List query", func(t *testing.T) { - items, err := s.List(context.Background(), "foo.eu-west-2") + t.Run("paginating a List query", func(t *testing.T) { + items, err := s.List(context.Background(), "foo.eu-west-2", false) if err != nil { t.Error(err) @@ -556,3 +573,170 @@ func TestPaginated(t *testing.T) { } }) } + +func TestDescribeOnlySourceCaching(t *testing.T) { + ctx := context.Background() + generation := 0 + s := DescribeOnlySource[string, string, struct{}, struct{}]{ + ItemType: "test-type", + MaxResultsPerPage: 1, + Config: aws.Config{ + Region: "eu-west-2", + }, + AccountID: "foo", + InputMapperGet: func(scope, query string) (string, error) { + return "input", nil + }, + InputMapperList: func(scope string) (string, error) { + return "input", nil + }, + OutputMapper: func(scope, input, output string) ([]*sdp.Item, error) { + return []*sdp.Item{ + { + Scope: "foo.eu-west-2", + Type: "test-type", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{ + AttrStruct: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "name": structpb.NewStringValue("test-item"), + "generation": structpb.NewStringValue(output), + }, + }, + }, + }, + }, nil + }, + PaginatorBuilder: func(client struct{}, params string) Paginator[string, struct{}] { + return &TestPaginator{ + DataFunc: func() string { + generation += 1 + return fmt.Sprintf("%v", generation) + }, + MaxPages: 1, + } + }, + DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { + generation += 1 + return fmt.Sprintf("%v", generation), nil + }, + } + + t.Run("get", func(t *testing.T) { + // get + first, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // get again + withCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // get ignore cache + withoutCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("list", func(t *testing.T) { + // list + first, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // list again + withCache, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // list ignore cache + withoutCache, err := s.List(ctx, "foo.eu-west-2", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("search", func(t *testing.T) { + // search + first, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // search again + withCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // search ignore cache + withoutCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) +} diff --git a/sources/ecs/capacity_provider_test.go b/sources/ecs/capacity_provider_test.go index 4741d083..02053e88 100644 --- a/sources/ecs/capacity_provider_test.go +++ b/sources/ecs/capacity_provider_test.go @@ -127,7 +127,7 @@ func TestCapacityProviderSource(t *testing.T) { // Override the client src.Client = &TestClient{} - items, err := src.List(context.Background(), "") + items, err := src.List(context.Background(), "", false) if err != nil { t.Error(err) diff --git a/sources/get_list_source.go b/sources/get_list_source.go index c0c33989..90592ad8 100644 --- a/sources/get_list_source.go +++ b/sources/get_list_source.go @@ -4,20 +4,25 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" ) // GetListSource A source for AWS APIs where the Get and List functions both // return the full item, such as many of the IAM APIs type GetListSource[AWSItem AWSItemType, ClientStruct ClientStructType, Options OptionsType] struct { - ItemType string // The type of items that will be returned - Client ClientStruct // The AWS API client - AccountID string // The AWS account ID - Region string // The AWS region this is related to - SupportGlobalResources bool // If true, this will also support resources in the "aws" scope which are global - CacheDuration time.Duration // How long to cache items for + ItemType string // The type of items that will be returned + Client ClientStruct // The AWS API client + AccountID string // The AWS account ID + Region string // The AWS region this is related to + SupportGlobalResources bool // If true, this will also support resources in the "aws" scope which are global + + CacheDuration time.Duration // How long to cache items for + cache *sdpcache.Cache // The sdpcache of this source + cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once // Disables List(), meaning all calls will return empty results. This does // not affect Search() @@ -38,6 +43,20 @@ type GetListSource[AWSItem AWSItemType, ClientStruct ClientStructType, Options O ItemMapper func(scope string, awsItem AWSItem) (*sdp.Item, error) } +func (s *GetListSource[AWSItem, ClientStruct, Options]) ensureCache() { + s.cacheInitMu.Lock() + defer s.cacheInitMu.Unlock() + + if s.cache == nil { + s.cache = sdpcache.NewCache() + } +} + +func (s *GetListSource[AWSItem, ClientStruct, Options]) Cache() *sdpcache.Cache { + s.ensureCache() + return s.cache +} + // Validate Checks that the source has been set up correctly func (s *GetListSource[AWSItem, ClientStruct, Options]) Validate() error { if s.GetFunc == nil { @@ -105,7 +124,7 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) hasScope(scope string) b return false } -func (s *GetListSource[AWSItem, ClientStruct, Options]) Get(ctx context.Context, scope string, query string) (*sdp.Item, error) { +func (s *GetListSource[AWSItem, ClientStruct, Options]) Get(ctx context.Context, scope string, query string, ignoreCache bool) (*sdp.Item, error) { if !s.hasScope(scope) { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -113,24 +132,39 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) Get(ctx context.Context, } } - awsItem, err := s.GetFunc(ctx, s.Client, scope, query) + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) == 0 { + return nil, nil + } else { + return cachedItems[0], nil + } + } + awsItem, err := s.GetFunc(ctx, s.Client, scope, query) if err != nil { + s.cache.StoreError(err, s.CacheDuration, ck) return nil, WrapAWSError(err) } item, err := s.ItemMapper(scope, awsItem) - if err != nil { + s.cache.StoreError(err, s.CacheDuration, ck) return nil, WrapAWSError(err) } + s.cache.StoreItem(item, s.CacheDuration, ck) + return item, nil } // List Lists all available items. This is done by running the ListFunc, then // passing these results to GetFunc in order to get the details -func (s *GetListSource[AWSItem, ClientStruct, Options]) List(ctx context.Context, scope string) ([]*sdp.Item, error) { +func (s *GetListSource[AWSItem, ClientStruct, Options]) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { if !s.hasScope(scope) { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -142,31 +176,37 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) List(ctx context.Context return []*sdp.Item{}, nil } - awsItems, err := s.ListFunc(ctx, s.Client, scope) + s.ensureCache() + cacheHit, ck, cachedItems, qErr := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + return cachedItems, nil + } + awsItems, err := s.ListFunc(ctx, s.Client, scope) if err != nil { return nil, WrapAWSError(err) } items := make([]*sdp.Item, 0) - - var item *sdp.Item - for _, awsItem := range awsItems { - item, err = s.ItemMapper(scope, awsItem) + item, err := s.ItemMapper(scope, awsItem) if err != nil { continue } items = append(items, item) + s.cache.StoreItem(item, s.CacheDuration, ck) } return items, nil } // Search Searches for AWS resources by ARN -func (s *GetListSource[AWSItem, ClientStruct, Options]) Search(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *GetListSource[AWSItem, ClientStruct, Options]) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { if !s.hasScope(scope) { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -175,13 +215,13 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) Search(ctx context.Conte } if s.SearchFunc != nil { - return s.SearchCustom(ctx, scope, query) + return s.SearchCustom(ctx, scope, query, ignoreCache) } else { - return s.SearchARN(ctx, scope, query) + return s.SearchARN(ctx, scope, query, ignoreCache) } } -func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchARN(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { // Parse the ARN a, err := ParseARN(query) @@ -197,7 +237,7 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchARN(ctx context.Co } } - item, err := s.Get(ctx, scope, a.ResourceID()) + item, err := s.Get(ctx, scope, a.ResourceID(), ignoreCache) if err != nil { return nil, WrapAWSError(err) @@ -206,7 +246,7 @@ func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchARN(ctx context.Co return []*sdp.Item{item}, nil } -func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *GetListSource[AWSItem, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { awsItems, err := s.SearchFunc(ctx, s.Client, scope, query) if err != nil { diff --git a/sources/get_list_source_test.go b/sources/get_list_source_test.go index 7f76d7e3..5951782c 100644 --- a/sources/get_list_source_test.go +++ b/sources/get_list_source_test.go @@ -3,9 +3,11 @@ package sources import ( "context" "errors" + "fmt" "testing" "github.com/overmindtech/sdp-go" + "google.golang.org/protobuf/types/known/structpb" ) func TestGetListSourceType(t *testing.T) { @@ -56,7 +58,7 @@ func TestGetListSourceGet(t *testing.T) { }, } - if _, err := s.Get(context.Background(), "12345.eu-west-2", ""); err != nil { + if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err != nil { t.Error(err) } }) @@ -77,7 +79,7 @@ func TestGetListSourceGet(t *testing.T) { }, } - if _, err := s.Get(context.Background(), "12345.eu-west-2", ""); err == nil { + if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { t.Error("expected error got nil") } }) @@ -98,7 +100,7 @@ func TestGetListSourceGet(t *testing.T) { }, } - if _, err := s.Get(context.Background(), "12345.eu-west-2", ""); err == nil { + if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { t.Error("expected error got nil") } }) @@ -121,7 +123,7 @@ func TestGetListSourceList(t *testing.T) { }, } - if items, err := s.List(context.Background(), "12345.eu-west-2"); err != nil { + if items, err := s.List(context.Background(), "12345.eu-west-2", false); err != nil { t.Error(err) } else { if len(items) != 2 { @@ -146,7 +148,7 @@ func TestGetListSourceList(t *testing.T) { }, } - if _, err := s.List(context.Background(), "12345.eu-west-2"); err == nil { + if _, err := s.List(context.Background(), "12345.eu-west-2", false); err == nil { t.Error("expected error got nil") } }) @@ -167,7 +169,7 @@ func TestGetListSourceList(t *testing.T) { }, } - if items, err := s.List(context.Background(), "12345.eu-west-2"); err != nil { + if items, err := s.List(context.Background(), "12345.eu-west-2", false); err != nil { t.Error(err) } else { if len(items) != 0 { @@ -195,7 +197,7 @@ func TestGetListSourceSearch(t *testing.T) { } t.Run("bad ARN", func(t *testing.T) { - _, err := s.Search(context.Background(), "12345.eu-west-2", "query") + _, err := s.Search(context.Background(), "12345.eu-west-2", "query", false) if err == nil { t.Error("expected error because the ARN was bad") @@ -203,7 +205,7 @@ func TestGetListSourceSearch(t *testing.T) { }) t.Run("good ARN but bad scope", func(t *testing.T) { - _, err := s.Search(context.Background(), "12345.eu-west-2", "arn:aws:service:region:account:type/id") + _, err := s.Search(context.Background(), "12345.eu-west-2", "arn:aws:service:region:account:type/id", false) if err == nil { t.Error("expected error because the ARN had a bad scope") @@ -211,7 +213,7 @@ func TestGetListSourceSearch(t *testing.T) { }) t.Run("good ARN", func(t *testing.T) { - _, err := s.Search(context.Background(), "12345.eu-west-2", "arn:aws:service:eu-west-2:12345:type/id") + _, err := s.Search(context.Background(), "12345.eu-west-2", "arn:aws:service:eu-west-2:12345:type/id", false) if err != nil { t.Error(err) @@ -219,3 +221,154 @@ func TestGetListSourceSearch(t *testing.T) { }) }) } + +func TestGetListSourceCaching(t *testing.T) { + ctx := context.Background() + generation := 0 + s := GetListSource[string, struct{}, struct{}]{ + ItemType: "test-type", + Region: "eu-west-2", + AccountID: "foo", + GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { + generation += 1 + return fmt.Sprintf("%v", generation), nil + }, + ListFunc: func(ctx context.Context, client struct{}, scope string) ([]string, error) { + generation += 1 + return []string{fmt.Sprintf("%v", generation)}, nil + }, + ItemMapper: func(scope string, output string) (*sdp.Item, error) { + return &sdp.Item{ + Scope: "foo.eu-west-2", + Type: "test-type", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{ + AttrStruct: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "name": structpb.NewStringValue("test-item"), + "generation": structpb.NewStringValue(output), + }, + }, + }, + }, nil + }, + } + + t.Run("get", func(t *testing.T) { + // get + first, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // get again + withCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // get ignore cache + withoutCache, err := s.Get(ctx, "foo.eu-west-2", "test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache.Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("list", func(t *testing.T) { + // list + first, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // list again + withCache, err := s.List(ctx, "foo.eu-west-2", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // list ignore cache + withoutCache, err := s.List(ctx, "foo.eu-west-2", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) + + t.Run("search", func(t *testing.T) { + // search + first, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + firstGen, err := first[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + // search again + withCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", false) + if err != nil { + t.Fatal(err) + } + withCacheGen, err := withCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + + if firstGen != withCacheGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withCacheGen) + } + + // search ignore cache + withoutCache, err := s.Search(ctx, "foo.eu-west-2", "arn:aws:test-type:eu-west-2:foo:test-item", true) + if err != nil { + t.Fatal(err) + } + withoutCacheGen, err := withoutCache[0].Attributes.Get("generation") + if err != nil { + t.Fatal(err) + } + if withoutCacheGen == firstGen { + t.Errorf("with cache: expected generation %v, got %v", firstGen, withoutCacheGen) + } + }) +} diff --git a/sources/iam/policy_test.go b/sources/iam/policy_test.go index a5cd5eb1..85811402 100644 --- a/sources/iam/policy_test.go +++ b/sources/iam/policy_test.go @@ -227,7 +227,7 @@ func TestNewPolicySource(t *testing.T) { t.Parallel() // This item shouldn't be found since it lives globally - _, err := source.Get(ctx, sources.FormatScope(account, ""), "ReadOnlyAccess") + _, err := source.Get(ctx, sources.FormatScope(account, ""), "ReadOnlyAccess", false) if err == nil { t.Error("expected error, got nil") @@ -240,7 +240,7 @@ func TestNewPolicySource(t *testing.T) { t.Parallel() // This item shouldn't be found since it lives globally - item, err := source.Get(ctx, "aws", "ReadOnlyAccess") + item, err := source.Get(ctx, "aws", "ReadOnlyAccess", false) if err != nil { t.Error(err) @@ -255,7 +255,7 @@ func TestNewPolicySource(t *testing.T) { ctx, span := tracer.Start(context.Background(), t.Name()) defer span.End() - items, err := source.List(ctx, sources.FormatScope(account, "")) + items, err := source.List(ctx, sources.FormatScope(account, ""), false) if err != nil { t.Error(err) @@ -287,7 +287,7 @@ func TestNewPolicySource(t *testing.T) { arn, _ := items[0].Attributes.Get("arn") - _, err := source.Search(ctx, sources.FormatScope(account, ""), arn.(string)) + _, err := source.Search(ctx, sources.FormatScope(account, ""), arn.(string), false) if err != nil { t.Error(err) @@ -302,7 +302,7 @@ func TestNewPolicySource(t *testing.T) { arn, _ := items[0].Attributes.Get("arn") - _, err := source.Search(ctx, "aws", arn.(string)) + _, err := source.Search(ctx, "aws", arn.(string), false) if err == nil { t.Error("expected error, got nil") @@ -314,7 +314,7 @@ func TestNewPolicySource(t *testing.T) { ctx, span := tracer.Start(context.Background(), t.Name()) defer span.End() - items, err := source.List(ctx, "aws") + items, err := source.List(ctx, "aws", false) if err != nil { t.Error(err) @@ -346,7 +346,7 @@ func TestNewPolicySource(t *testing.T) { arn, _ := items[0].Attributes.Get("arn") - _, err := source.Search(ctx, sources.FormatScope(account, ""), arn.(string)) + _, err := source.Search(ctx, sources.FormatScope(account, ""), arn.(string), false) if err == nil { t.Error("expected error, got nil") @@ -361,7 +361,7 @@ func TestNewPolicySource(t *testing.T) { arn, _ := items[0].Attributes.Get("arn") - _, err := source.Search(ctx, "aws", arn.(string)) + _, err := source.Search(ctx, "aws", arn.(string), false) if err != nil { t.Error(err) diff --git a/sources/iam/role.go b/sources/iam/role.go index 727723e2..7b011434 100644 --- a/sources/iam/role.go +++ b/sources/iam/role.go @@ -214,7 +214,7 @@ func roleListFunc(ctx context.Context, client IAMClient, scope string, limit *so Role: role, } - err = enrichRole(ctx, client, &details, limit) + err := enrichRole(ctx, client, &details, limit) if err != nil { return nil, err diff --git a/sources/iam/tracing.go b/sources/iam/tracing.go index a3d89179..37007822 100644 --- a/sources/iam/tracing.go +++ b/sources/iam/tracing.go @@ -2,12 +2,12 @@ package iam import ( "go.opentelemetry.io/otel" - semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + semconv "go.opentelemetry.io/otel/semconv/v1.21.0" "go.opentelemetry.io/otel/trace" ) const ( - instrumentationName = "github.com/overmindtech/gateway/cmd" + instrumentationName = "github.com/overmindtech/aws-source/sources/iam" instrumentationVersion = "0.0.1" ) diff --git a/sources/s3/s3.go b/sources/s3/s3.go index 54af4013..b1d7fa03 100644 --- a/sources/s3/s3.go +++ b/sources/s3/s3.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" @@ -11,8 +12,11 @@ import ( "github.com/getsentry/sentry-go" "github.com/overmindtech/aws-source/sources" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" ) +const CacheDuration = 10 * time.Minute + // NewS3Source Creates a new S3 source func NewS3Source(config aws.Config, accountID string) *S3Source { return &S3Source{ @@ -62,6 +66,24 @@ type S3Source struct { client *s3.Client clientCreated bool clientMutex sync.Mutex + + CacheDuration time.Duration // How long to cache items for + cache *sdpcache.Cache // The sdpcache of this source + cacheInitMu sync.Mutex // Mutex to ensure cache is only initialised once +} + +func (s *S3Source) ensureCache() { + s.cacheInitMu.Lock() + defer s.cacheInitMu.Unlock() + + if s.cache == nil { + s.cache = sdpcache.NewCache() + } +} + +func (s *S3Source) Cache() *sdpcache.Cache { + s.ensureCache() + return s.cache } func (s *S3Source) Client() *s3.Client { @@ -155,7 +177,7 @@ type Bucket struct { // ctx parameter contains a golang context object which should be used to allow // this source to timeout or be cancelled when executing potentially // long-running actions -func (s *S3Source) Get(ctx context.Context, scope string, query string) (*sdp.Item, error) { +func (s *S3Source) Get(ctx context.Context, scope string, query string, ignoreCache bool) (*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -164,10 +186,23 @@ func (s *S3Source) Get(ctx context.Context, scope string, query string) (*sdp.It } } - return getImpl(ctx, s.Client(), scope, query) + s.ensureCache() + return getImpl(ctx, s.cache, s.Client(), scope, query, ignoreCache) } -func getImpl(ctx context.Context, client S3Client, scope string, query string) (*sdp.Item, error) { +func getImpl(ctx context.Context, cache *sdpcache.Cache, client S3Client, scope string, query string, ignoreCache bool) (*sdp.Item, error) { + cacheHit, ck, cachedItems, qErr := cache.Lookup(ctx, "aws-s3-source", sdp.QueryMethod_GET, scope, "s3-bucket", query, ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) > 0 { + return cachedItems[0], nil + } else { + return nil, nil + } + } + var location *s3.GetBucketLocationOutput var wg sync.WaitGroup var err error @@ -179,7 +214,9 @@ func getImpl(ctx context.Context, client S3Client, scope string, query string) ( }) if err != nil { - return nil, sources.WrapAWSError(err) + err = sources.WrapAWSError(err) + cache.StoreError(err, CacheDuration, ck) + return nil, err } bucket := Bucket{ @@ -344,11 +381,13 @@ func getImpl(ctx context.Context, client S3Client, scope string, query string) ( attributes, err := sources.ToAttributesCase(bucket) if err != nil { - return nil, &sdp.QueryError{ + err = &sdp.QueryError{ ErrorType: sdp.QueryError_OTHER, ErrorString: err.Error(), Scope: scope, } + cache.StoreError(err, CacheDuration, ck) + return nil, err } item := sdp.Item{ @@ -544,11 +583,13 @@ func getImpl(ctx context.Context, client S3Client, scope string, query string) ( } } + cache.StoreItem(&item, CacheDuration, ck) + return &item, nil } // List Lists all items in a given scope -func (s *S3Source) List(ctx context.Context, scope string) ([]*sdp.Item, error) { +func (s *S3Source) List(ctx context.Context, scope string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -557,20 +598,35 @@ func (s *S3Source) List(ctx context.Context, scope string) ([]*sdp.Item, error) } } - return listImpl(ctx, s.Client(), scope) + s.ensureCache() + return listImpl(ctx, s.cache, s.Client(), scope, ignoreCache) } -func listImpl(ctx context.Context, client S3Client, scope string) ([]*sdp.Item, error) { +func listImpl(ctx context.Context, cache *sdpcache.Cache, client S3Client, scope string, ignoreCache bool) ([]*sdp.Item, error) { + cacheHit, ck, cachedItems, qErr := cache.Lookup(ctx, "aws-s3-source", sdp.QueryMethod_LIST, scope, "s3-bucket", "", ignoreCache) + if qErr != nil { + return nil, qErr + } + if cacheHit { + if len(cachedItems) > 0 { + return cachedItems, nil + } else { + return nil, nil + } + } + items := make([]*sdp.Item, 0) buckets, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) if err != nil { - return nil, sdp.NewQueryError(err) + err = sdp.NewQueryError(err) + cache.StoreError(err, CacheDuration, ck) + return nil, err } for _, bucket := range buckets.Buckets { - item, err := getImpl(ctx, client, scope, *bucket.Name) + item, err := getImpl(ctx, cache, client, scope, *bucket.Name, ignoreCache) if err != nil { continue @@ -579,11 +635,14 @@ func listImpl(ctx context.Context, client S3Client, scope string) ([]*sdp.Item, items = append(items, item) } + for _, item := range items { + cache.StoreItem(item, CacheDuration, ck) + } return items, nil } // Search Searches for an S3 bucket by ARN rather than name -func (s *S3Source) Search(ctx context.Context, scope string, query string) ([]*sdp.Item, error) { +func (s *S3Source) Search(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { if scope != s.Scopes()[0] { return nil, &sdp.QueryError{ ErrorType: sdp.QueryError_NOSCOPE, @@ -592,10 +651,11 @@ func (s *S3Source) Search(ctx context.Context, scope string, query string) ([]*s } } - return searchImpl(ctx, s.Client(), scope, query) + s.ensureCache() + return searchImpl(ctx, s.cache, s.Client(), scope, query, ignoreCache) } -func searchImpl(ctx context.Context, client S3Client, scope string, query string) ([]*sdp.Item, error) { +func searchImpl(ctx context.Context, cache *sdpcache.Cache, client S3Client, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { // Parse the ARN a, err := sources.ParseARN(query) @@ -612,10 +672,9 @@ func searchImpl(ctx context.Context, client S3Client, scope string, query string } // If the ARN was parsed we can just ask Get for the item - item, err := getImpl(ctx, client, scope, a.ResourceID()) - + item, err := getImpl(ctx, cache, client, scope, a.ResourceID(), ignoreCache) if err != nil { - return nil, sdp.NewQueryError(err) + return nil, err } return []*sdp.Item{item}, nil diff --git a/sources/s3/s3_test.go b/sources/s3/s3_test.go index ec19d817..517a7016 100644 --- a/sources/s3/s3_test.go +++ b/sources/s3/s3_test.go @@ -2,6 +2,7 @@ package s3 import ( "context" + "errors" "testing" "time" @@ -9,11 +10,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/overmindtech/aws-source/sources" "github.com/overmindtech/sdp-go" + "github.com/overmindtech/sdpcache" ) func TestS3SearchImpl(t *testing.T) { + cache := sdpcache.NewCache() t.Run("with a good ARN", func(t *testing.T) { - items, err := searchImpl(context.Background(), TestS3Client{}, "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id") + items, err := searchImpl(context.Background(), cache, TestS3Client{}, "account-id.region", "arn:partition:service:region:account-id:resource-type:resource-id", false) if err != nil { t.Error(err) @@ -24,7 +27,7 @@ func TestS3SearchImpl(t *testing.T) { }) t.Run("with a bad ARN", func(t *testing.T) { - _, err := searchImpl(context.Background(), TestS3Client{}, "account-id.region", "foo") + _, err := searchImpl(context.Background(), cache, TestS3Client{}, "account-id.region", "foo", false) if err == nil { t.Error("expected error") @@ -40,7 +43,7 @@ func TestS3SearchImpl(t *testing.T) { }) t.Run("with an ARN in another scope", func(t *testing.T) { - _, err := searchImpl(context.Background(), TestS3Client{}, "account-id.region", "arn:partition:service:region:account-id-2:resource-type:resource-id") + _, err := searchImpl(context.Background(), cache, TestS3Client{}, "account-id.region", "arn:partition:service:region:account-id-2:resource-type:resource-id", false) if err == nil { t.Error("expected error") @@ -57,7 +60,8 @@ func TestS3SearchImpl(t *testing.T) { } func TestS3ListImpl(t *testing.T) { - items, err := listImpl(context.Background(), TestS3Client{}, "foo") + cache := sdpcache.NewCache() + items, err := listImpl(context.Background(), cache, TestS3Client{}, "foo", false) if err != nil { t.Error(err) @@ -68,7 +72,8 @@ func TestS3ListImpl(t *testing.T) { } func TestS3GetImpl(t *testing.T) { - item, err := getImpl(context.Background(), TestS3Client{}, "foo", "bar") + cache := sdpcache.NewCache() + item, err := getImpl(context.Background(), cache, TestS3Client{}, "foo", "bar", false) if err != nil { t.Fatal(err) @@ -128,6 +133,37 @@ func TestS3GetImpl(t *testing.T) { tests.Execute(t, item) } +func TestS3SourceCaching(t *testing.T) { + cache := sdpcache.NewCache() + first, err := getImpl(context.Background(), cache, TestS3Client{}, "foo", "bar", false) + if err != nil { + t.Fatal(err) + } + if first == nil { + t.Fatal("expected first item") + } + + second, err := getImpl(context.Background(), cache, TestS3FailClient{}, "foo", "bar", false) + if err != nil { + t.Fatal(err) + } + if second == nil { + t.Fatal("expected second item") + } + + third, err := getImpl(context.Background(), cache, TestS3Client{}, "foo", "bar", true) + if err != nil { + t.Fatal(err) + } + if third == nil { + t.Fatal("expected third item") + } + + if third == second { + t.Errorf("expected third item (%v) to be different to second item (%v)", third, second) + } +} + var owner = types.Owner{ DisplayName: sources.PtrString("dylan"), ID: sources.PtrString("id"), @@ -489,6 +525,147 @@ func (t TestS3Client) GetBucketWebsite(ctx context.Context, params *s3.GetBucket }, nil } +type TestS3FailClient struct{} + +func (t TestS3FailClient) ListBuckets(ctx context.Context, params *s3.ListBucketsInput, optFns ...func(*s3.Options)) (*s3.ListBucketsOutput, error) { + return nil, errors.New("failed to list buckets") +} + +func (t TestS3FailClient) GetBucketAcl(ctx context.Context, params *s3.GetBucketAclInput, optFns ...func(*s3.Options)) (*s3.GetBucketAclOutput, error) { + return nil, errors.New("failed to get bucket ACL") +} +func (t TestS3FailClient) GetBucketAnalyticsConfiguration(ctx context.Context, params *s3.GetBucketAnalyticsConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketAnalyticsConfigurationOutput, error) { + return nil, errors.New("failed to get bucket ACL") +} + +func (t TestS3FailClient) GetBucketCors(ctx context.Context, params *s3.GetBucketCorsInput, optFns ...func(*s3.Options)) (*s3.GetBucketCorsOutput, error) { + return nil, errors.New("failed to get bucket CORS") +} + +func (t TestS3FailClient) GetBucketEncryption(ctx context.Context, params *s3.GetBucketEncryptionInput, optFns ...func(*s3.Options)) (*s3.GetBucketEncryptionOutput, error) { + return nil, errors.New("failed to get bucket CORS") +} + +func (t TestS3FailClient) GetBucketIntelligentTieringConfiguration(ctx context.Context, params *s3.GetBucketIntelligentTieringConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketIntelligentTieringConfigurationOutput, error) { + return nil, errors.New("failed to get bucket CORS") +} + +func (t TestS3FailClient) GetBucketInventoryConfiguration(ctx context.Context, params *s3.GetBucketInventoryConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketInventoryConfigurationOutput, error) { + return nil, errors.New("failed to get bucket CORS") +} + +func (t TestS3FailClient) GetBucketLifecycleConfiguration(ctx context.Context, params *s3.GetBucketLifecycleConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketLifecycleConfigurationOutput, error) { + return nil, errors.New("failed to get bucket lifecycle configuration") +} + +func (t TestS3FailClient) GetBucketLocation(ctx context.Context, params *s3.GetBucketLocationInput, optFns ...func(*s3.Options)) (*s3.GetBucketLocationOutput, error) { + return nil, errors.New("failed to get bucket location") +} + +func (t TestS3FailClient) GetBucketLogging(ctx context.Context, params *s3.GetBucketLoggingInput, optFns ...func(*s3.Options)) (*s3.GetBucketLoggingOutput, error) { + return nil, errors.New("failed to get bucket logging") +} + +func (t TestS3FailClient) GetBucketMetricsConfiguration(ctx context.Context, params *s3.GetBucketMetricsConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketMetricsConfigurationOutput, error) { + return nil, errors.New("failed to get bucket logging") +} + +func (t TestS3FailClient) GetBucketNotificationConfiguration(ctx context.Context, params *s3.GetBucketNotificationConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketNotificationConfigurationOutput, error) { + return nil, errors.New("failed to get bucket notification configuration") +} + +func (t TestS3FailClient) GetBucketOwnershipControls(ctx context.Context, params *s3.GetBucketOwnershipControlsInput, optFns ...func(*s3.Options)) (*s3.GetBucketOwnershipControlsOutput, error) { + return nil, errors.New("failed to get bucket policy") +} + +func (t TestS3FailClient) GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error) { + return nil, errors.New("failed to get bucket policy") +} + +func (t TestS3FailClient) GetBucketPolicyStatus(ctx context.Context, params *s3.GetBucketPolicyStatusInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyStatusOutput, error) { + return nil, errors.New("failed to get bucket policy") +} + +func (t TestS3FailClient) GetBucketReplication(ctx context.Context, params *s3.GetBucketReplicationInput, optFns ...func(*s3.Options)) (*s3.GetBucketReplicationOutput, error) { + return nil, errors.New("failed to get bucket replication") +} + +func (t TestS3FailClient) GetBucketRequestPayment(ctx context.Context, params *s3.GetBucketRequestPaymentInput, optFns ...func(*s3.Options)) (*s3.GetBucketRequestPaymentOutput, error) { + return nil, errors.New("failed to get bucket request payment") +} + +func (t TestS3FailClient) GetBucketTagging(ctx context.Context, params *s3.GetBucketTaggingInput, optFns ...func(*s3.Options)) (*s3.GetBucketTaggingOutput, error) { + return nil, errors.New("failed to get bucket tagging") +} + +func (t TestS3FailClient) GetBucketVersioning(ctx context.Context, params *s3.GetBucketVersioningInput, optFns ...func(*s3.Options)) (*s3.GetBucketVersioningOutput, error) { + return nil, errors.New("failed to get bucket versioning") +} + +func (t TestS3FailClient) GetBucketWebsite(ctx context.Context, params *s3.GetBucketWebsiteInput, optFns ...func(*s3.Options)) (*s3.GetBucketWebsiteOutput, error) { + return nil, errors.New("failed to get bucket website") +} + +func (t TestS3FailClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + return nil, errors.New("failed to get object") +} + +func (t TestS3FailClient) HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) { + return nil, errors.New("failed to head bucket") +} + +func (t TestS3FailClient) HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { + return nil, errors.New("failed to head object") +} + +func (t TestS3FailClient) PutBucketAcl(ctx context.Context, params *s3.PutBucketAclInput, optFns ...func(*s3.Options)) (*s3.PutBucketAclOutput, error) { + return nil, errors.New("failed to put bucket ACL") +} + +func (t TestS3FailClient) PutBucketCors(ctx context.Context, params *s3.PutBucketCorsInput, optFns ...func(*s3.Options)) (*s3.PutBucketCorsOutput, error) { + return nil, errors.New("failed to put bucket CORS") +} + +func (t TestS3FailClient) PutBucketLifecycleConfiguration(ctx context.Context, params *s3.PutBucketLifecycleConfigurationInput, optFns ...func(*s3.Options)) (*s3.PutBucketLifecycleConfigurationOutput, error) { + return nil, errors.New("failed to put bucket lifecycle configuration") +} + +func (t TestS3FailClient) PutBucketLogging(ctx context.Context, params *s3.PutBucketLoggingInput, optFns ...func(*s3.Options)) (*s3.PutBucketLoggingOutput, error) { + return nil, errors.New("failed to put bucket logging") +} + +func (t TestS3FailClient) PutBucketNotificationConfiguration(ctx context.Context, params *s3.PutBucketNotificationConfigurationInput, optFns ...func(*s3.Options)) (*s3.PutBucketNotificationConfigurationOutput, error) { + return nil, errors.New("failed to put bucket notification configuration") +} + +func (t TestS3FailClient) PutBucketPolicy(ctx context.Context, params *s3.PutBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.PutBucketPolicyOutput, error) { + return nil, errors.New("failed to put bucket policy") +} + +func (t TestS3FailClient) PutBucketReplication(ctx context.Context, params *s3.PutBucketReplicationInput, optFns ...func(*s3.Options)) (*s3.PutBucketReplicationOutput, error) { + return nil, errors.New("failed to put bucket replication") +} + +func (t TestS3FailClient) PutBucketRequestPayment(ctx context.Context, params *s3.PutBucketRequestPaymentInput, optFns ...func(*s3.Options)) (*s3.PutBucketRequestPaymentOutput, error) { + return nil, errors.New("failed to put bucket request payment") +} + +func (t TestS3FailClient) PutBucketTagging(ctx context.Context, params *s3.PutBucketTaggingInput, optFns ...func(*s3.Options)) (*s3.PutBucketTaggingOutput, error) { + return nil, errors.New("failed to put bucket tagging") +} + +func (t TestS3FailClient) PutBucketVersioning(ctx context.Context, params *s3.PutBucketVersioningInput, optFns ...func(*s3.Options)) (*s3.PutBucketVersioningOutput, error) { + return nil, errors.New("failed to put bucket versioning") +} + +func (t TestS3FailClient) PutBucketWebsite(ctx context.Context, params *s3.PutBucketWebsiteInput, optFns ...func(*s3.Options)) (*s3.PutBucketWebsiteOutput, error) { + return nil, errors.New("failed to put bucket website") +} + +func (t TestS3FailClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + return nil, errors.New("failed to put object") +} + func TestNewS3Source(t *testing.T) { config, account, _ := sources.GetAutoConfig(t) diff --git a/sources/util.go b/sources/util.go index 4971f74e..dc282b23 100644 --- a/sources/util.go +++ b/sources/util.go @@ -143,7 +143,7 @@ func (e E2ETest) Run(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - items, err := searchSrc.Search(ctx, scope, *e.GoodSearchQuery) + items, err := searchSrc.Search(ctx, scope, *e.GoodSearchQuery, false) if err != nil { t.Error(err) @@ -169,7 +169,7 @@ func (e E2ETest) Run(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - items, err := e.Source.List(ctx, scope) + items, err := e.Source.List(ctx, scope, false) if err != nil { t.Error(err) @@ -205,7 +205,7 @@ func (e E2ETest) Run(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - item, err := e.Source.Get(ctx, scope, query) + item, err := e.Source.Get(ctx, scope, query, false) if err != nil { t.Fatal(err) @@ -230,7 +230,7 @@ func (e E2ETest) Run(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), e.Timeout) defer cancel() - _, err := e.Source.Get(ctx, scope, "this is a known bad get query") + _, err := e.Source.Get(ctx, scope, "this is a known bad get query", false) if err == nil { t.Error("expected error, got nil") diff --git a/tracing/tracing.go b/tracing/tracing.go index 20e4a0ec..12ec623d 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -6,6 +6,7 @@ import ( "os" "time" + "github.com/MrAlias/otel-schema-utils/schema" "github.com/getsentry/sentry-go" log "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -16,7 +17,7 @@ import ( "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/resource" sdktrace "go.opentelemetry.io/otel/sdk/trace" - semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + semconv "go.opentelemetry.io/otel/semconv/v1.21.0" ) // for stdout debugging of traces @@ -38,23 +39,42 @@ import ( func tracingResource() *resource.Resource { // Identify your application using resource detection - detectors := []resource.Detector{} + resources := []*resource.Resource{} // the EC2 detector takes ~10s to time out outside EC2 // disable it if we're running from a git checkout _, err := os.Stat(".git") if os.IsNotExist(err) { - detectors = append(detectors, ec2.NewResourceDetector()) + ec2Res, err := resource.New(context.Background(), resource.WithDetectors(ec2.NewResourceDetector())) + if err != nil { + log.WithError(err).Error("error initialising EC2 resource detector") + return nil + } + resources = append(resources, ec2Res) } - res, err := resource.New(context.Background(), - resource.WithDetectors(detectors...), - // Keep the default detectors + // Needs https://github.com/open-telemetry/opentelemetry-go-contrib/issues/1856 fixed first + // // the EKS detector is temperamental and doesn't like running outside of kube + // // hence we need to keep it from running when we know there's no kube + // if !viper.GetBool("disable-kube") { + // // Use the AWS resource detector to detect information about the runtime environment + // detectors = append(detectors, eks.NewResourceDetector()) + // } + + hostRes, err := resource.New(context.Background(), resource.WithHost(), resource.WithOS(), resource.WithProcess(), resource.WithContainer(), resource.WithTelemetrySDK(), + ) + if err != nil { + log.WithError(err).Error("error initialising host resource") + return nil + } + resources = append(resources, hostRes) + + localRes, err := resource.New(context.Background(), resource.WithSchemaURL(semconv.SchemaURL), // Add your own custom attributes to identify your application resource.WithAttributes( @@ -63,7 +83,16 @@ func tracingResource() *resource.Resource { ), ) if err != nil { - log.Errorf("resource.New: %v", err) + log.WithError(err).Error("error initialising local resource") + return nil + } + resources = append(resources, localRes) + + conv := schema.NewConverter(schema.NewLocalClient()) + res, err := conv.MergeResources(context.Background(), semconv.SchemaURL, resources...) + + if err != nil { + log.WithError(err).Error("error merging resource") return nil } return res