From a22d8ee29bac0e4ddf576f0ba1c75fe3198bdb3a Mon Sep 17 00:00:00 2001 From: Kleonikos Kyriakis Date: Fri, 17 May 2024 11:56:53 +0300 Subject: [PATCH] Add ci workflow + LICENSE (#20) * tests & refactorings --- .github/workflows/check-clean-branch.sh | 5 + .github/workflows/ci.yml | 66 ++++ .golangci.yml | 57 ++++ LICENSE | 29 ++ LICENSE.header | 2 + cmd/camino-messenger-bot/main.go | 4 +- config/config.go | 12 +- config/flags.go | 5 +- examples/rpc/client.go | 14 +- examples/rpc/partner-plugin/server.go | 19 +- go.mod | 4 +- internal/app/app.go | 54 ++- internal/compression/compress.go | 4 +- internal/compression/decompress.go | 8 +- internal/compression/mock_decompress.go | 54 +++ internal/matrix/matrix_compressor.go | 20 +- internal/matrix/matrix_compressor_test.go | 25 +- internal/matrix/matrix_decompressor.go | 62 ---- internal/matrix/matrix_messenger.go | 20 +- internal/matrix/mock_room_handler.go | 118 +++++++ internal/matrix/msg_assembler.go | 77 ++++- internal/matrix/msg_assembler_test.go | 196 +++++++++++ internal/matrix/room_handler.go | 49 ++- internal/matrix/room_handler_test.go | 155 +++++++++ internal/matrix/types.go | 77 +++-- internal/messaging/messenger.go | 5 + internal/messaging/mock_list_grpc.pb.go | 62 ++++ internal/messaging/mock_messenger.go | 111 ++++++ internal/messaging/mock_service_registry.go | 68 ++++ internal/messaging/noop_response_handler.go | 7 +- internal/messaging/processor.go | 73 ++-- internal/messaging/processor_test.go | 360 ++++++++++++++++++++ internal/messaging/response_handler.go | 126 ++++--- internal/messaging/service.go | 166 ++++----- internal/messaging/service_registry.go | 23 +- internal/messaging/types.go | 105 +++--- internal/metadata/metadata.go | 7 +- internal/rpc/client/client.go | 7 +- internal/rpc/server/server.go | 85 +++-- internal/tracing/exporter.go | 12 +- internal/tracing/nooptracer.go | 4 +- internal/tracing/tracer.go | 1 + internal/tvm/client.go | 2 +- scripts/build_test.sh | 13 + scripts/lint.sh | 26 ++ scripts/mock.gen.sh | 27 ++ scripts/mocks.mockgen.txt | 5 + utils/tls/tls.go | 4 +- 48 files changed, 1920 insertions(+), 515 deletions(-) create mode 100755 .github/workflows/check-clean-branch.sh create mode 100644 .github/workflows/ci.yml create mode 100644 .golangci.yml create mode 100644 LICENSE create mode 100644 LICENSE.header create mode 100644 internal/compression/mock_decompress.go delete mode 100644 internal/matrix/matrix_decompressor.go create mode 100644 internal/matrix/mock_room_handler.go create mode 100644 internal/matrix/msg_assembler_test.go create mode 100644 internal/matrix/room_handler_test.go create mode 100644 internal/messaging/mock_list_grpc.pb.go create mode 100644 internal/messaging/mock_messenger.go create mode 100644 internal/messaging/mock_service_registry.go create mode 100644 internal/messaging/processor_test.go create mode 100755 scripts/build_test.sh create mode 100755 scripts/lint.sh create mode 100755 scripts/mock.gen.sh create mode 100644 scripts/mocks.mockgen.txt diff --git a/.github/workflows/check-clean-branch.sh b/.github/workflows/check-clean-branch.sh new file mode 100755 index 00000000..578c3fba --- /dev/null +++ b/.github/workflows/check-clean-branch.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -euo pipefail + +git update-index --really-refresh >> /dev/null +git diff-index --quiet HEAD diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..f9c94154 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,66 @@ +name: CI + +on: + pull_request: + tags-ignore: ["*"] + branches: [c4t, dev] + push: + branches: [c4t, dev] + workflow_dispatch: + +# Cancel ongoing workflow runs if a new one is started +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + go_version: '~1.20.12' + +jobs: + unit: + name: Unit Tests + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Update dependencies + run: git submodule update --init + - uses: actions/setup-go@v3 + with: + go-version: ${{ env.go_version }} + - name: Install libolm + run: sudo apt update && sudo apt-get install -y libolm-dev + - name: build_test + shell: bash + run: ./scripts/build_test.sh + lint: + name: Static Analysis + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Update dependencies + run: git submodule update --init + - uses: actions/setup-go@v3 + with: + go-version: ${{ env.go_version }} + - name: Install libolm + run: sudo apt update && sudo apt-get install -y libolm-dev + - name: GolangCI-Lint + shell: bash + run: ./scripts/lint.sh + go_mod_tidy: + name: Check state of go.mod and go.sum + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Update dependencies + run: git submodule update --init + - uses: actions/setup-go@v3 + with: + go-version: ${{ env.go_version }} + - name: Go Mod Tidy + shell: bash + run: go mod tidy + - name: Check Clean Branch + shell: bash + run: .github/workflows/check-clean-branch.sh + diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..666a0d1c --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,57 @@ +# Copyright (C) 2024, Chain4Travel AG. All rights reserved. +# See the file LICENSE for licensing terms. + +# https://golangci-lint.run/usage/configuration/ +run: + skip-dirs: + - camino-matrix-go + timeout: 5m + +linters: + disable-all: true + enable: + - asciicheck + - bodyclose +# - cyclop TODO enable + - depguard + - dupl + - errcheck + - errorlint + - exportloopref + - goconst + - gocritic + - gofmt + - gofumpt + - goimports + - goprintffuncname + - gosec + - gosimple + - govet + - ineffassign + - misspell + - nakedret + - noctx + - prealloc + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - whitespace + - revive + +linters-settings: + depguard: + rules: + packages: + deny: + - pkg: "io/ioutil" + desc: io/ioutil is deprecated. Use package io or os instead. + - pkg: "github.com/stretchr/testify/assert" + desc: github.com/stretchr/testify/require should be used instead. + - pkg: "github.com/golang/mock/gomock" + desc: go.uber.org/mock/gomock should be used instead. +issues: + # Maximum count of issues with the same text. Set to 0 to disable. Default: 3. + max-same-issues: 0 diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..bcc1420d --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (C) 2024, Chain4Travel AG. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/LICENSE.header b/LICENSE.header new file mode 100644 index 00000000..3bf4882c --- /dev/null +++ b/LICENSE.header @@ -0,0 +1,2 @@ +Copyright (C) 2024, Chain4Travel AG. All rights reserved. +See the file LICENSE for licensing terms. \ No newline at end of file diff --git a/cmd/camino-messenger-bot/main.go b/cmd/camino-messenger-bot/main.go index 46648f09..29430126 100644 --- a/cmd/camino-messenger-bot/main.go +++ b/cmd/camino-messenger-bot/main.go @@ -19,9 +19,11 @@ func main() { defer stop() app, err := app.NewApp(cfg) + if err != nil { + panic(err) + } err = app.Run(ctx) if err != nil { panic(err) } - } diff --git a/config/config.go b/config/config.go index 9824b8ed..2ceb5889 100644 --- a/config/config.go +++ b/config/config.go @@ -58,7 +58,7 @@ type TracingConfig struct { KeyFile string `mapstructure:"tracing_key_file"` } type Config struct { - AppConfig `mapstructure:",squash"` //TODO use nested yaml structure + AppConfig `mapstructure:",squash"` // TODO use nested yaml structure MatrixConfig `mapstructure:",squash"` RPCServerConfig `mapstructure:",squash"` PartnerPluginConfig `mapstructure:",squash"` @@ -86,7 +86,7 @@ func ReadConfig() (*Config, error) { readAppConfig(cfg.AppConfig, fs) readMatrixConfig(cfg.MatrixConfig, fs) readRPCServerConfig(cfg.RPCServerConfig, fs) - readPartnerRpcServerConfig(cfg.PartnerPluginConfig, fs) + readPartnerRPCServerConfig(cfg.PartnerPluginConfig, fs) readMessengerConfig(cfg.ProcessorConfig, fs) readTvmConfig(cfg.TvmConfig, fs) readTracingConfig(cfg.TracingConfig, fs) @@ -99,7 +99,12 @@ func ReadConfig() (*Config, error) { return nil, err } - viper.ReadInConfig() // ignore config-file-reading-errors as we have env vars as fallback configuration + // read configuration file if provided, otherwise rely on env vars + if configFile != "" { + if err := viper.ReadInConfig(); err != nil { + return cfg, err + } + } if err := viper.Unmarshal(cfg); err != nil { return nil, err @@ -110,6 +115,7 @@ func ReadConfig() (*Config, error) { func (i *SupportedRequestTypesFlag) String() string { return "[" + strings.Join(*i, ",") + "]" } + func (i *SupportedRequestTypesFlag) Contains(requestType string) bool { return slices.Contains(*i, requestType) } diff --git a/config/flags.go b/config/flags.go index 0fffa34a..d877d3aa 100644 --- a/config/flags.go +++ b/config/flags.go @@ -6,7 +6,6 @@ func readAppConfig(cfg AppConfig, fs *flag.FlagSet) { fs.BoolVar(&cfg.DeveloperMode, DeveloperMode, false, "Sets developer mode") fs.Var(&cfg.SupportedRequestTypes, SupportedRequestTypesKey, "The list of supported request types") flag.Parse() - } func readMatrixConfig(cfg MatrixConfig, fs *flag.FlagSet) { @@ -20,15 +19,13 @@ func readRPCServerConfig(cfg RPCServerConfig, fs *flag.FlagSet) { fs.BoolVar(&cfg.Unencrypted, RPCUnencryptedKey, false, "Whether the RPC server should be unencrypted") fs.StringVar(&cfg.ServerCertFile, RPCServerCertFileKey, "", "The server certificate file") fs.StringVar(&cfg.ServerKeyFile, RPCServerKeyFileKey, "", "The server key file") - } -func readPartnerRpcServerConfig(cfg PartnerPluginConfig, fs *flag.FlagSet) { +func readPartnerRPCServerConfig(cfg PartnerPluginConfig, fs *flag.FlagSet) { fs.StringVar(&cfg.Host, PartnerPluginHostKey, "", "The partner plugin RPC server host") fs.IntVar(&cfg.Port, PartnerPluginPortKey, 50051, "The partner plugin RPC server port") fs.BoolVar(&cfg.Unencrypted, PartnerPluginUnencryptedKey, false, "Whether the RPC client should initiate an unencrypted connection with the server") fs.StringVar(&cfg.CACertFile, PartnerPluginCAFileKey, "", "The partner plugin RPC server CA certificate file") - } func readMessengerConfig(cfg ProcessorConfig, fs *flag.FlagSet) { diff --git a/examples/rpc/client.go b/examples/rpc/client.go index 68bed3f4..894f99cb 100755 --- a/examples/rpc/client.go +++ b/examples/rpc/client.go @@ -1,7 +1,6 @@ package main import ( - typesv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/types/v1alpha" "context" "encoding/csv" "flag" @@ -12,6 +11,8 @@ import ( "sync" "time" + typesv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/types/v1alpha" + "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/accommodation/v1alpha/accommodationv1alphagrpc" accommodationv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/accommodation/v1alpha" internalmetadata "github.com/chain4travel/camino-messenger-bot/internal/metadata" @@ -74,7 +75,6 @@ func createClientAndRunRequest(i int, ppConfig config.PartnerPluginConfig, sLogg request := &accommodationv1alpha.AccommodationSearchRequest{ Header: nil, SearchParametersGeneric: &typesv1alpha.SearchParameters{ - Currency: typesv1alpha.Currency_CURRENCY_EUR, Language: typesv1alpha.Language_LANGUAGE_UG, Market: 1, MaxOptions: 2, @@ -82,7 +82,7 @@ func createClientAndRunRequest(i int, ppConfig config.PartnerPluginConfig, sLogg Queries: []*accommodationv1alpha.AccommodationSearchQuery{ { SearchParametersAccommodation: &accommodationv1alpha.AccommodationSearchParameters{ - RatePlan: []*typesv1alpha.RatePlan{{RatePlan: "economy"}}, + SupplierCodes: []*typesv1alpha.SupplierProductCode{{SupplierCode: "supplier1"}}, }, }, }, @@ -139,12 +139,11 @@ func addToDataset(counter int64, totalTime int64, resp *accommodationv1alpha.Acc if entry.Key == "request-gateway-request" { lastValue = entry.Value - continue //skip + continue // skip } if entry.Key == "processor-request" { - - //lastValue = entry.Value - continue //skip + // lastValue = entry.Value + continue // skip } fmt.Printf("%d|%s|%s|%d|%.2f\n", entry.Value, entry.Key, resp.Metadata.SearchId.GetValue(), entry.Value-lastValue, float32(entry.Value-lastValue)/float32(totalTime)) @@ -156,6 +155,7 @@ func addToDataset(counter int64, totalTime int64, resp *accommodationv1alpha.Acc loadTestData[counter] = data mu.Unlock() } + func persistToCSV(dataset [][]string) { // Open a new CSV file file, err := os.Create("load_test_data.csv") diff --git a/examples/rpc/partner-plugin/server.go b/examples/rpc/partner-plugin/server.go index 90c068d9..365cf5a1 100644 --- a/examples/rpc/partner-plugin/server.go +++ b/examples/rpc/partner-plugin/server.go @@ -41,7 +41,7 @@ type partnerPlugin struct { transportv1alphagrpc.TransportSearchServiceServer } -func (p *partnerPlugin) Mint(ctx context.Context, request *bookv1alpha.MintRequest) (*bookv1alpha.MintResponse, error) { +func (p *partnerPlugin) Mint(ctx context.Context, _ *bookv1alpha.MintRequest) (*bookv1alpha.MintResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -64,7 +64,7 @@ func (p *partnerPlugin) Mint(ctx context.Context, request *bookv1alpha.MintReque return &response, nil } -func (p *partnerPlugin) Validation(ctx context.Context, request *bookv1alpha.ValidationRequest) (*bookv1alpha.ValidationResponse, error) { +func (p *partnerPlugin) Validation(ctx context.Context, _ *bookv1alpha.ValidationRequest) (*bookv1alpha.ValidationResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -82,7 +82,7 @@ func (p *partnerPlugin) Validation(ctx context.Context, request *bookv1alpha.Val return &response, nil } -func (p *partnerPlugin) ActivitySearch(ctx context.Context, request *activityv1alpha.ActivitySearchRequest) (*activityv1alpha.ActivitySearchResponse, error) { +func (p *partnerPlugin) ActivitySearch(ctx context.Context, _ *activityv1alpha.ActivitySearchRequest) (*activityv1alpha.ActivitySearchResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -98,7 +98,8 @@ func (p *partnerPlugin) ActivitySearch(ctx context.Context, request *activityv1a grpc.SendHeader(ctx, md.ToGrpcMD()) return &response, nil } -func (p *partnerPlugin) AccommodationProductInfo(ctx context.Context, request *accommodationv1alpha.AccommodationProductInfoRequest) (*accommodationv1alpha.AccommodationProductInfoResponse, error) { + +func (p *partnerPlugin) AccommodationProductInfo(ctx context.Context, _ *accommodationv1alpha.AccommodationProductInfoRequest) (*accommodationv1alpha.AccommodationProductInfoResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -113,7 +114,8 @@ func (p *partnerPlugin) AccommodationProductInfo(ctx context.Context, request *a grpc.SendHeader(ctx, md.ToGrpcMD()) return &response, nil } -func (p *partnerPlugin) AccommodationProductList(ctx context.Context, request *accommodationv1alpha.AccommodationProductListRequest) (*accommodationv1alpha.AccommodationProductListResponse, error) { + +func (p *partnerPlugin) AccommodationProductList(ctx context.Context, _ *accommodationv1alpha.AccommodationProductListRequest) (*accommodationv1alpha.AccommodationProductListResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -128,7 +130,8 @@ func (p *partnerPlugin) AccommodationProductList(ctx context.Context, request *a grpc.SendHeader(ctx, md.ToGrpcMD()) return &response, nil } -func (p *partnerPlugin) AccommodationSearch(ctx context.Context, request *accommodationv1alpha.AccommodationSearchRequest) (*accommodationv1alpha.AccommodationSearchResponse, error) { + +func (p *partnerPlugin) AccommodationSearch(ctx context.Context, _ *accommodationv1alpha.AccommodationSearchRequest) (*accommodationv1alpha.AccommodationSearchResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { @@ -163,6 +166,7 @@ func (p *partnerPlugin) GetNetworkFee(ctx context.Context, request *networkv1alp grpc.SendHeader(ctx, md.ToGrpcMD()) return &response, nil } + func (p *partnerPlugin) GetPartnerConfiguration(ctx context.Context, request *partnerv1alpha.GetPartnerConfigurationRequest) (*partnerv1alpha.GetPartnerConfigurationResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) @@ -194,7 +198,8 @@ func (p *partnerPlugin) Ping(ctx context.Context, request *pingv1alpha.PingReque PingMessage: fmt.Sprintf("Ping response to [%s] with request ID: %s", request.PingMessage, md.RequestID), }, nil } -func (p *partnerPlugin) TransportSearch(ctx context.Context, request *transportv1alpha.TransportSearchRequest) (*transportv1alpha.TransportSearchResponse, error) { + +func (p *partnerPlugin) TransportSearch(ctx context.Context, _ *transportv1alpha.TransportSearchRequest) (*transportv1alpha.TransportSearchResponse, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { diff --git a/go.mod b/go.mod index 12fd49ab..d29252ad 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.11.2 go.opentelemetry.io/otel/sdk v1.11.2 go.opentelemetry.io/otel/trace v1.11.2 + go.uber.org/mock v0.4.0 go.uber.org/zap v1.26.0 golang.org/x/sync v0.5.0 google.golang.org/grpc v1.59.0 @@ -82,7 +83,6 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.11.2 // indirect go.opentelemetry.io/otel/exporters/zipkin v1.11.2 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect - go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.18.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect @@ -96,7 +96,7 @@ require ( require ( github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/golang/protobuf v1.5.3 + github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect diff --git a/internal/app/app.go b/internal/app/app.go index 6f44666e..db90e93e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -39,7 +39,7 @@ func NewApp(cfg *config.Config) (*App, error) { logger, _ = cfg.Build() } app.logger = logger.Sugar() - defer logger.Sync() + defer logger.Sync() //nolint:errcheck return app, nil } @@ -52,16 +52,22 @@ func (a *App) Run(ctx context.Context) error { serviceRegistry := messaging.NewServiceRegistry(a.logger) // start rpc client if host is provided, otherwise bot serves as a distributor bot (rpc server) if a.cfg.PartnerPluginConfig.Host != "" { - a.startRPCClient(g, *serviceRegistry, gCtx) + a.startRPCClient(gCtx, g, serviceRegistry) } else { a.logger.Infof("No host for partner plugin provided, bot will serve as a distributor bot.") serviceRegistry.RegisterServices(a.cfg.SupportedRequestTypes, nil) } // start messenger (receiver) - messenger, userIDUpdatedChan := a.startMessenger(g, gCtx) + messenger, userIDUpdatedChan := a.startMessenger(gCtx, g) // initiate tvm client - responseHandler := a.initTVMClient() + tvmClient, err := tvm.NewClient(a.cfg.TvmConfig) // TODO make client init conditional based on provided config + if err != nil { + a.logger.Warn(err) + } + + // create response handler + responseHandler := a.newResponseHandler(tvmClient) // start msg processor msgProcessor := a.startMessageProcessor(ctx, messenger, serviceRegistry, responseHandler, g, userIDUpdatedChan) @@ -75,7 +81,7 @@ func (a *App) Run(ctx context.Context) error { }() // start rpc server - a.startRPCServer(msgProcessor, serviceRegistry, g, gCtx) + a.startRPCServer(gCtx, msgProcessor, serviceRegistry, g) if err := g.Wait(); err != nil { a.logger.Error(err) @@ -83,6 +89,13 @@ func (a *App) Run(ctx context.Context) error { return nil } +func (a *App) newResponseHandler(tvmClient *tvm.Client) messaging.ResponseHandler { + if tvmClient != nil { + return messaging.NewResponseHandler(tvmClient, a.logger) + } + return messaging.NoopResponseHandler{} +} + func (a *App) initTracer() tracing.Tracer { var ( tracer tracing.Tracer @@ -100,21 +113,7 @@ func (a *App) initTracer() tracing.Tracer { return tracer } -func (a *App) initTVMClient() messaging.ResponseHandler { - var responseHandler messaging.ResponseHandler - // TODO make client init conditional based on provided config - tvmClient, err := tvm.NewClient(a.cfg.TvmConfig) - if err != nil { - // do no return error here, let the bot continue - a.logger.Warnf("Failed to create tvm client: %v", err) - responseHandler = messaging.NoopResponseHandler{} - } else { - responseHandler = messaging.NewResponseHandler(tvmClient, a.logger) - } - return responseHandler -} - -func (a *App) startRPCClient(g *errgroup.Group, serviceRegistry messaging.ServiceRegistry, gCtx context.Context) { +func (a *App) startRPCClient(ctx context.Context, g *errgroup.Group, serviceRegistry messaging.ServiceRegistry) { rpcClient := client.NewClient(&a.cfg.PartnerPluginConfig, a.logger) g.Go(func() error { a.logger.Info("Starting gRPC client...") @@ -126,12 +125,12 @@ func (a *App) startRPCClient(g *errgroup.Group, serviceRegistry messaging.Servic return nil }) g.Go(func() error { - <-gCtx.Done() + <-ctx.Done() return rpcClient.Shutdown() }) } -func (a *App) startMessenger(g *errgroup.Group, gCtx context.Context) (messaging.Messenger, chan string) { +func (a *App) startMessenger(ctx context.Context, g *errgroup.Group) (messaging.Messenger, chan string) { messenger := matrix.NewMessenger(&a.cfg.MatrixConfig, a.logger) userIDUpdatedChan := make(chan string) // Channel to pass the userID g.Go(func() error { @@ -144,27 +143,26 @@ func (a *App) startMessenger(g *errgroup.Group, gCtx context.Context) (messaging return nil }) g.Go(func() error { - <-gCtx.Done() + <-ctx.Done() return messenger.StopReceiver() }) return messenger, userIDUpdatedChan } -func (a *App) startRPCServer(msgProcessor messaging.Processor, serviceRegistry *messaging.ServiceRegistry, g *errgroup.Group, gCtx context.Context) { +func (a *App) startRPCServer(ctx context.Context, msgProcessor messaging.Processor, serviceRegistry messaging.ServiceRegistry, g *errgroup.Group) { rpcServer := server.NewServer(&a.cfg.RPCServerConfig, a.logger, a.tracer, msgProcessor, serviceRegistry) g.Go(func() error { a.logger.Info("Starting gRPC server...") - rpcServer.Start() - return nil + return rpcServer.Start() }) g.Go(func() error { - <-gCtx.Done() + <-ctx.Done() rpcServer.Stop() return nil }) } -func (a *App) startMessageProcessor(ctx context.Context, messenger messaging.Messenger, serviceRegistry *messaging.ServiceRegistry, responseHandler messaging.ResponseHandler, g *errgroup.Group, userIDUpdated chan string) messaging.Processor { +func (a *App) startMessageProcessor(ctx context.Context, messenger messaging.Messenger, serviceRegistry messaging.ServiceRegistry, responseHandler messaging.ResponseHandler, g *errgroup.Group, userIDUpdated chan string) messaging.Processor { msgProcessor := messaging.NewProcessor(messenger, a.logger, a.cfg.ProcessorConfig, serviceRegistry, responseHandler) g.Go(func() error { // Wait for userID to be passed diff --git a/internal/compression/compress.go b/internal/compression/compress.go index e8318710..90e57966 100644 --- a/internal/compression/compress.go +++ b/internal/compression/compress.go @@ -21,10 +21,10 @@ var encoder, _ = zstd.NewWriter(nil) // Compressor interface defines basic compression functionality type Compressor[T any, R any] interface { - // Compress takes a byte array as input and returns the compressed data as a byte array Compress(data T) (R, error) } -func Compress(src []byte) []byte { +// CompressBytes takes a byte array as input and returns the compressed data as a byte array +func CompressBytes(src []byte) []byte { return encoder.EncodeAll(src, make([]byte, 0, len(src))) } diff --git a/internal/compression/decompress.go b/internal/compression/decompress.go index 7421c199..00253f3d 100644 --- a/internal/compression/decompress.go +++ b/internal/compression/decompress.go @@ -11,6 +11,12 @@ import ( var decoder, _ = zstd.NewReader(nil) -func Decompress(src []byte) ([]byte, error) { +type Decompressor interface { + Decompress(src []byte) ([]byte, error) +} + +type ZSTDDecompressor struct{} + +func (d *ZSTDDecompressor) Decompress(src []byte) ([]byte, error) { return decoder.DecodeAll(src, nil) } diff --git a/internal/compression/mock_decompress.go b/internal/compression/mock_decompress.go new file mode 100644 index 00000000..f2bd3521 --- /dev/null +++ b/internal/compression/mock_decompress.go @@ -0,0 +1,54 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/internal/compression (interfaces: Decompressor) +// +// Generated by this command: +// +// mockgen -package=compression -destination=internal/compression/mock_decompress.go github.com/chain4travel/camino-messenger-bot/internal/compression Decompressor +// + +// Package compression is a generated GoMock package. +package compression + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockDecompressor is a mock of Decompressor interface. +type MockDecompressor struct { + ctrl *gomock.Controller + recorder *MockDecompressorMockRecorder +} + +// MockDecompressorMockRecorder is the mock recorder for MockDecompressor. +type MockDecompressorMockRecorder struct { + mock *MockDecompressor +} + +// NewMockDecompressor creates a new mock instance. +func NewMockDecompressor(ctrl *gomock.Controller) *MockDecompressor { + mock := &MockDecompressor{ctrl: ctrl} + mock.recorder = &MockDecompressorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDecompressor) EXPECT() *MockDecompressorMockRecorder { + return m.recorder +} + +// Decompress mocks base method. +func (m *MockDecompressor) Decompress(arg0 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decompress", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decompress indicates an expected call of Decompress. +func (mr *MockDecompressorMockRecorder) Decompress(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decompress", reflect.TypeOf((*MockDecompressor)(nil).Decompress), arg0) +} diff --git a/internal/matrix/matrix_compressor.go b/internal/matrix/matrix_compressor.go index 5f461605..ac6a06fc 100644 --- a/internal/matrix/matrix_compressor.go +++ b/internal/matrix/matrix_compressor.go @@ -16,21 +16,21 @@ import ( ) var ( - _ compression.Compressor[messaging.Message, []CaminoMatrixMessage] = (*MatrixChunkingCompressor)(nil) + _ compression.Compressor[messaging.Message, []CaminoMatrixMessage] = (*ChunkingCompressor)(nil) ErrCompressionProducedNoChunks = errors.New("compression produced no chunks") ErrEncodingMsg = errors.New("error while encoding msg for compression") ) -// MatrixChunkingCompressor is a concrete implementation of Compressor with chunking functionality -type MatrixChunkingCompressor struct { +// ChunkingCompressor is a concrete implementation of Compressor with chunking functionality +type ChunkingCompressor struct { maxChunkSize int } -// Compress implements the Compressor interface for MatrixChunkingCompressor -func (c *MatrixChunkingCompressor) Compress(msg messaging.Message) ([]CaminoMatrixMessage, error) { +// Compress implements the Compressor interface for ChunkingCompressor +func (c *ChunkingCompressor) Compress(msg messaging.Message) ([]CaminoMatrixMessage, error) { var matrixMessages []CaminoMatrixMessage - // 1. Compress the message + // 1. CompressBytes the message compressedContent, err := compress(msg) if err != nil { return matrixMessages, err @@ -45,7 +45,8 @@ func (c *MatrixChunkingCompressor) Compress(msg messaging.Message) ([]CaminoMatr // 3. Create CaminoMatrixMessage objects for each chunk return splitCaminoMatrixMsg(msg, splitCompressedContent) } -func (c *MatrixChunkingCompressor) split(bytes []byte) ([][]byte, error) { + +func (c *ChunkingCompressor) split(bytes []byte) ([][]byte, error) { splitCompressedContent := splitByteArray(bytes, c.maxChunkSize) if len(splitCompressedContent) == 0 { @@ -63,10 +64,11 @@ func compress(msg messaging.Message) ([]byte, error) { if err != nil { return nil, fmt.Errorf("%w: %w", ErrEncodingMsg, err) } - return compression.Compress(bytes), nil + return compression.CompressBytes(bytes), nil } + func splitCaminoMatrixMsg(msg messaging.Message, splitCompressedContent [][]byte) ([]CaminoMatrixMessage, error) { - var messages []CaminoMatrixMessage + messages := make([]CaminoMatrixMessage, 0, len(splitCompressedContent)) // add first chunk to messages slice { diff --git a/internal/matrix/matrix_compressor_test.go b/internal/matrix/matrix_compressor_test.go index 4c795615..d39a7566 100644 --- a/internal/matrix/matrix_compressor_test.go +++ b/internal/matrix/matrix_compressor_test.go @@ -6,15 +6,16 @@ package matrix import ( + "testing" + activityv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/activity/v1alpha" "github.com/chain4travel/camino-messenger-bot/internal/messaging" "github.com/chain4travel/camino-messenger-bot/internal/metadata" "github.com/stretchr/testify/require" "maunium.net/go/mautrix/event" - "testing" ) -func TestMatrixChunkingCompressorCompress(t *testing.T) { +func TestChunkingCompressorCompress(t *testing.T) { type args struct { msg messaging.Message maxSize int @@ -38,7 +39,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { Type: messaging.ActivitySearchResponse, Content: messaging.MessageContent{ ResponseContent: messaging.ResponseContent{ - ActivitySearchResponse: activityv1alpha.ActivitySearchResponse{ + ActivitySearchResponse: &activityv1alpha.ActivitySearchResponse{ Results: []*activityv1alpha.ActivitySearchResult{ {Info: &activityv1alpha.Activity{ServiceCode: "test"}}, }, @@ -57,6 +58,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { NumberOfChunks: 1, ChunkIndex: 0, }, + CompressedContent: []byte{40, 181, 47, 253, 4, 0, 81, 0, 0, 26, 8, 18, 6, 42, 4, 116, 101, 115, 116, 39, 101, 69, 66}, }, }, }, @@ -66,7 +68,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { Type: messaging.ActivitySearchResponse, Content: messaging.MessageContent{ ResponseContent: messaging.ResponseContent{ - ActivitySearchResponse: activityv1alpha.ActivitySearchResponse{ + ActivitySearchResponse: &activityv1alpha.ActivitySearchResponse{ Results: []*activityv1alpha.ActivitySearchResult{ {Info: &activityv1alpha.Activity{ServiceCode: "test"}}, }, @@ -85,6 +87,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { NumberOfChunks: 1, ChunkIndex: 0, }, + CompressedContent: []byte{40, 181, 47, 253, 4, 0, 81, 0, 0, 26, 8, 18, 6, 42, 4, 116, 101, 115, 116, 39, 101, 69, 66}, }, }, }, @@ -94,7 +97,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { Type: messaging.ActivitySearchResponse, Content: messaging.MessageContent{ ResponseContent: messaging.ResponseContent{ - ActivitySearchResponse: activityv1alpha.ActivitySearchResponse{ + ActivitySearchResponse: &activityv1alpha.ActivitySearchResponse{ Results: []*activityv1alpha.ActivitySearchResult{ {Info: &activityv1alpha.Activity{ServiceCode: "test"}}, }, @@ -113,6 +116,7 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { NumberOfChunks: 2, ChunkIndex: 0, }, + CompressedContent: []byte{40, 181, 47, 253, 4, 0, 81, 0, 0, 26, 8, 18, 6, 42, 4, 116, 101, 115, 116, 39, 101, 69}, }, { MessageEventContent: event.MessageEventContent{ @@ -122,22 +126,17 @@ func TestMatrixChunkingCompressorCompress(t *testing.T) { NumberOfChunks: 2, ChunkIndex: 1, }, + CompressedContent: []byte{66}, }, }, }, } for tc, tt := range tests { t.Run(tc, func(t *testing.T) { - c := &MatrixChunkingCompressor{tt.args.maxSize} + c := &ChunkingCompressor{tt.args.maxSize} got, err := c.Compress(tt.args.msg) require.ErrorIs(t, err, tt.err) - require.Equal(t, len(got), len(tt.want)) - for i, msg := range got { - require.Equal(t, msg.MessageEventContent.MsgType, tt.want[i].MsgType) - require.Equal(t, msg.Metadata.NumberOfChunks, tt.want[i].Metadata.NumberOfChunks) - require.Equal(t, msg.Metadata.ChunkIndex, tt.want[i].Metadata.ChunkIndex) - require.NotNil(t, msg.CompressedContent) - } + require.Equal(t, tt.want, got) }) } } diff --git a/internal/matrix/matrix_decompressor.go b/internal/matrix/matrix_decompressor.go deleted file mode 100644 index 8e3d703f..00000000 --- a/internal/matrix/matrix_decompressor.go +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (C) 2022-2023, Chain4Travel AG. All rights reserved. - * See the file LICENSE for licensing terms. - */ - -package matrix - -import ( - "fmt" - "sort" - - "github.com/chain4travel/camino-messenger-bot/internal/compression" - "github.com/chain4travel/camino-messenger-bot/internal/messaging" -) - -func assembleAndDecompressCaminoMatrixMessages(messages []CaminoMatrixMessage) (CaminoMatrixMessage, error) { - var compressedPayloads [][]byte - - // chunks have to be sorted - sort.Sort(ByChunkIndex(messages)) - for _, msg := range messages { - compressedPayloads = append(compressedPayloads, msg.CompressedContent) - } - - // assemble chunks and decompress content - originalContent, err := assembleAndDecompress(compressedPayloads) - if err != nil { - return CaminoMatrixMessage{}, fmt.Errorf("failed to assemble and decompress camino matrix msg: %v", err) - } - - msg := CaminoMatrixMessage{ - MessageEventContent: messages[0].MessageEventContent, - Metadata: messages[0].Metadata, - } - switch messaging.MessageType(msg.MsgType).Category() { - case messaging.Request, - messaging.Response: - msg.UnmarshalContent(originalContent) - default: - return CaminoMatrixMessage{}, fmt.Errorf("could not categorize unknown message type: %v", msg.MsgType) - } - - return msg, nil -} -func assembleAndDecompress(src [][]byte) ([]byte, error) { - return compression.Decompress(assembleByteArray(src)) -} - -func assembleByteArray(src [][]byte) []byte { - totalLength := 0 - for _, slice := range src { - totalLength += len(slice) - } - - result := make([]byte, totalLength) - index := 0 - for _, slice := range src { - copy(result[index:], slice) - index += len(slice) - } - return result -} diff --git a/internal/matrix/matrix_messenger.go b/internal/matrix/matrix_messenger.go index ef558e53..6bea60d7 100644 --- a/internal/matrix/matrix_messenger.go +++ b/internal/matrix/matrix_messenger.go @@ -22,7 +22,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" //nolint:revive ) var _ messaging.Messenger = (*messenger)(nil) @@ -49,7 +49,7 @@ type messenger struct { compressor compression.Compressor[messaging.Message, []CaminoMatrixMessage] } -func NewMessenger(cfg *config.MatrixConfig, logger *zap.SugaredLogger) *messenger { +func NewMessenger(cfg *config.MatrixConfig, logger *zap.SugaredLogger) messaging.Messenger { c, err := mautrix.NewClient(cfg.Host, "", "") if err != nil { panic(err) @@ -60,11 +60,12 @@ func NewMessenger(cfg *config.MatrixConfig, logger *zap.SugaredLogger) *messenge logger: logger, tracer: otel.GetTracerProvider().Tracer(""), client: client{Client: c}, - roomHandler: NewRoomHandler(c, logger), - msgAssembler: NewMessageAssembler(logger), - compressor: &MatrixChunkingCompressor{maxChunkSize: compression.MaxChunkSize}, + roomHandler: NewRoomHandler(NewClient(c), logger), + msgAssembler: NewMessageAssembler(), + compressor: &ChunkingCompressor{maxChunkSize: compression.MaxChunkSize}, } } + func (m *messenger) Checkpoint() string { return "messenger-gateway" } @@ -83,7 +84,7 @@ func (m *messenger) StartReceiver() (string, error) { _, span := m.tracer.Start(ctx, "messenger.OnC4TMessageReceive", trace.WithSpanKind(trace.SpanKindConsumer), trace.WithAttributes(attribute.String("type", evt.Type.Type))) defer span.End() t := time.Now() - completeMsg, err, completed := m.msgAssembler.AssembleMessage(*msg) + completeMsg, completed, err := m.msgAssembler.AssembleMessage(msg) if err != nil { m.logger.Errorf("failed to assemble message: %v", err) return @@ -114,7 +115,7 @@ func (m *messenger) StartReceiver() (string, error) { } }) - cryptoHelper, err := cryptohelper.NewCryptoHelper(m.client.Client, []byte("meow"), m.cfg.Store) //TODO refactor + cryptoHelper, err := cryptohelper.NewCryptoHelper(m.client.Client, []byte("meow"), m.cfg.Store) // TODO refactor if err != nil { return "", err } @@ -139,7 +140,7 @@ func (m *messenger) StartReceiver() (string, error) { if err != nil { return "", err } - // Set the client crypto helper in order to automatically encrypt outgoing messages + // Set the wrappedClient crypto helper in order to automatically encrypt outgoing messages m.client.Crypto = cryptoHelper m.client.cryptoHelper = cryptoHelper // nikos: we need the struct cause stop method is not available on the interface level @@ -159,6 +160,7 @@ func (m *messenger) StartReceiver() (string, error) { return m.client.UserID.String(), nil } + func (m *messenger) StopReceiver() error { m.logger.Info("Stopping matrix syncer...") if m.client.cancelSync != nil { @@ -191,7 +193,7 @@ func (m *messenger) SendAsync(ctx context.Context, msg messaging.Message) error } func (m *messenger) sendMessageEvents(ctx context.Context, roomID id.RoomID, eventType event.Type, messages []CaminoMatrixMessage) error { - //TODO add retry logic? + // TODO add retry logic? for _, msg := range messages { _, err := m.client.SendMessageEvent(ctx, roomID, eventType, msg) if err != nil { diff --git a/internal/matrix/mock_room_handler.go b/internal/matrix/mock_room_handler.go new file mode 100644 index 00000000..b3555a1d --- /dev/null +++ b/internal/matrix/mock_room_handler.go @@ -0,0 +1,118 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/internal/matrix (interfaces: Client) +// +// Generated by this command: +// +// mockgen -package=matrix -destination=internal/matrix/mock_room_handler.go github.com/chain4travel/camino-messenger-bot/internal/matrix Client +// + +// Package matrix is a generated GoMock package. +package matrix + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + mautrix "maunium.net/go/mautrix" + event "maunium.net/go/mautrix/event" + id "maunium.net/go/mautrix/id" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// CreateRoom mocks base method. +func (m *MockClient) CreateRoom(arg0 context.Context, arg1 *mautrix.ReqCreateRoom) (*mautrix.RespCreateRoom, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateRoom", arg0, arg1) + ret0, _ := ret[0].(*mautrix.RespCreateRoom) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateRoom indicates an expected call of CreateRoom. +func (mr *MockClientMockRecorder) CreateRoom(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRoom", reflect.TypeOf((*MockClient)(nil).CreateRoom), arg0, arg1) +} + +// IsEncrypted mocks base method. +func (m *MockClient) IsEncrypted(arg0 context.Context, arg1 id.RoomID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsEncrypted", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsEncrypted indicates an expected call of IsEncrypted. +func (mr *MockClientMockRecorder) IsEncrypted(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsEncrypted", reflect.TypeOf((*MockClient)(nil).IsEncrypted), arg0, arg1) +} + +// JoinedMembers mocks base method. +func (m *MockClient) JoinedMembers(arg0 context.Context, arg1 id.RoomID) (*mautrix.RespJoinedMembers, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "JoinedMembers", arg0, arg1) + ret0, _ := ret[0].(*mautrix.RespJoinedMembers) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// JoinedMembers indicates an expected call of JoinedMembers. +func (mr *MockClientMockRecorder) JoinedMembers(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "JoinedMembers", reflect.TypeOf((*MockClient)(nil).JoinedMembers), arg0, arg1) +} + +// JoinedRooms mocks base method. +func (m *MockClient) JoinedRooms(arg0 context.Context) (*mautrix.RespJoinedRooms, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "JoinedRooms", arg0) + ret0, _ := ret[0].(*mautrix.RespJoinedRooms) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// JoinedRooms indicates an expected call of JoinedRooms. +func (mr *MockClientMockRecorder) JoinedRooms(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "JoinedRooms", reflect.TypeOf((*MockClient)(nil).JoinedRooms), arg0) +} + +// SendStateEvent mocks base method. +func (m *MockClient) SendStateEvent(arg0 context.Context, arg1 id.RoomID, arg2 event.Type, arg3 string, arg4 any) (*mautrix.RespSendEvent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendStateEvent", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(*mautrix.RespSendEvent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendStateEvent indicates an expected call of SendStateEvent. +func (mr *MockClientMockRecorder) SendStateEvent(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendStateEvent", reflect.TypeOf((*MockClient)(nil).SendStateEvent), arg0, arg1, arg2, arg3, arg4) +} diff --git a/internal/matrix/msg_assembler.go b/internal/matrix/msg_assembler.go index b1f37056..13413231 100644 --- a/internal/matrix/msg_assembler.go +++ b/internal/matrix/msg_assembler.go @@ -6,41 +6,92 @@ package matrix import ( + "errors" + "fmt" + "sort" "sync" - "go.uber.org/zap" + "github.com/chain4travel/camino-messenger-bot/internal/compression" +) + +var ( + ErrDecompressFailed = errors.New("failed to decompress assembled camino matrix msg") + ErrUnmarshalContent = errors.New("failed to unmarshal content") ) type MessageAssembler interface { - AssembleMessage(msg CaminoMatrixMessage) (CaminoMatrixMessage, error, bool) // returns assembled message and true if message is complete. Otherwise, it returns an empty message and false + AssembleMessage(msg *CaminoMatrixMessage) (assembledMsg *CaminoMatrixMessage, complete bool, err error) // returns assembled message and true if message is complete. Otherwise, it returns an empty message and false } type messageAssembler struct { - logger *zap.SugaredLogger - partialMessages map[string][]CaminoMatrixMessage + partialMessages map[string][]*CaminoMatrixMessage + decompressor compression.Decompressor mu sync.RWMutex } -func NewMessageAssembler(logger *zap.SugaredLogger) MessageAssembler { - return &messageAssembler{logger: logger, partialMessages: make(map[string][]CaminoMatrixMessage)} +func NewMessageAssembler() MessageAssembler { + return &messageAssembler{decompressor: &compression.ZSTDDecompressor{}, partialMessages: make(map[string][]*CaminoMatrixMessage)} } -func (a *messageAssembler) AssembleMessage(msg CaminoMatrixMessage) (CaminoMatrixMessage, error, bool) { + +func (a *messageAssembler) AssembleMessage(msg *CaminoMatrixMessage) (*CaminoMatrixMessage, bool, error) { if msg.Metadata.NumberOfChunks == 1 { - decompressedCaminoMsg, err := assembleAndDecompressCaminoMatrixMessages([]CaminoMatrixMessage{msg}) - return decompressedCaminoMsg, err, true + decompressedCaminoMsg, err := a.assembleAndDecompressCaminoMatrixMessages([]*CaminoMatrixMessage{msg}) + return decompressedCaminoMsg, err == nil, err } a.mu.Lock() defer a.mu.Unlock() id := msg.Metadata.RequestID if _, ok := a.partialMessages[id]; !ok { - a.partialMessages[id] = []CaminoMatrixMessage{} + a.partialMessages[id] = []*CaminoMatrixMessage{} } a.partialMessages[id] = append(a.partialMessages[id], msg) if len(a.partialMessages[id]) == int(msg.Metadata.NumberOfChunks) { - decompressedCaminoMsg, err := assembleAndDecompressCaminoMatrixMessages(a.partialMessages[id]) + decompressedCaminoMsg, err := a.assembleAndDecompressCaminoMatrixMessages(a.partialMessages[id]) delete(a.partialMessages, id) - return decompressedCaminoMsg, err, true + return decompressedCaminoMsg, err == nil, err + } + return nil, false, nil +} + +func (a *messageAssembler) assembleAndDecompressCaminoMatrixMessages(messages []*CaminoMatrixMessage) (*CaminoMatrixMessage, error) { + compressedPayloads := make([][]byte, 0, len(messages)) + + // chunks have to be sorted + sort.Sort(ByChunkIndex(messages)) + for _, msg := range messages { + compressedPayloads = append(compressedPayloads, msg.CompressedContent) + } + + // assemble chunks and decompress + originalContent, err := a.decompressor.Decompress(assemble(compressedPayloads)) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrDecompressFailed, err) + } + + msg := CaminoMatrixMessage{ + MessageEventContent: messages[0].MessageEventContent, + Metadata: messages[0].Metadata, + } + err = msg.UnmarshalContent(originalContent) + if err != nil { + return nil, fmt.Errorf("%w: %w %v", ErrUnmarshalContent, err, msg.MsgType) + } + + return &msg, nil +} + +func assemble(src [][]byte) []byte { + totalLength := 0 + for _, slice := range src { + totalLength += len(slice) + } + + result := make([]byte, totalLength) + index := 0 + for _, slice := range src { + copy(result[index:], slice) + index += len(slice) } - return CaminoMatrixMessage{}, nil, false + return result } diff --git a/internal/matrix/msg_assembler_test.go b/internal/matrix/msg_assembler_test.go new file mode 100644 index 00000000..93572ca6 --- /dev/null +++ b/internal/matrix/msg_assembler_test.go @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2024, Chain4Travel AG. All rights reserved. + * See the file LICENSE for licensing terms. + */ + +package matrix + +import ( + "testing" + + "maunium.net/go/mautrix/event" + + activityv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/activity/v1alpha" + "github.com/chain4travel/camino-messenger-bot/internal/compression" + "github.com/chain4travel/camino-messenger-bot/internal/messaging" + "github.com/chain4travel/camino-messenger-bot/internal/metadata" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestAssembleMessage(t *testing.T) { + plainActivitySearchResponseMsg := messaging.Message{ + Type: messaging.ActivitySearchResponse, + Content: messaging.MessageContent{ + ResponseContent: messaging.ResponseContent{ + ActivitySearchResponse: &activityv1alpha.ActivitySearchResponse{ + Results: []*activityv1alpha.ActivitySearchResult{ + {Info: &activityv1alpha.Activity{ServiceCode: "test"}}, + }, + }, + }, + }, + } + type fields struct { + partialMessages map[string][]*CaminoMatrixMessage + } + + type args struct { + msg *CaminoMatrixMessage + } + + // mocks + ctrl := gomock.NewController(t) + mockedDecompressor := compression.NewMockDecompressor(ctrl) + + tests := map[string]struct { + fields fields + args args + prepare func() + want *CaminoMatrixMessage + isComplete bool + err error + }{ + "err: decoder failed to decompress": { + args: args{ + msg: &CaminoMatrixMessage{ + Metadata: metadata.Metadata{ + RequestID: "test", + NumberOfChunks: 1, + }, + }, + }, + prepare: func() { + mockedDecompressor.EXPECT().Decompress(gomock.Any()).Times(1).Return(nil, ErrDecompressFailed) + }, + isComplete: false, + err: ErrDecompressFailed, + }, + "err: unknown message type": { + args: args{ + msg: &CaminoMatrixMessage{ + Metadata: metadata.Metadata{ + RequestID: "test", + NumberOfChunks: 1, + }, + }, + }, + prepare: func() { + mockedDecompressor.EXPECT().Decompress(gomock.Any()).Times(1).Return([]byte{}, nil) + }, + isComplete: false, + err: ErrUnmarshalContent, + }, + "empty input": { + fields: fields{ + partialMessages: map[string][]*CaminoMatrixMessage{}, + }, + args: args{ + msg: &CaminoMatrixMessage{}, + }, + isComplete: false, + err: nil, + }, + "partial message delivery [metadata number fo chunks do not match provided messages]": { + fields: fields{ + partialMessages: map[string][]*CaminoMatrixMessage{}, + }, + args: args{ + msg: &CaminoMatrixMessage{ + Metadata: metadata.Metadata{ + RequestID: "test", + NumberOfChunks: 2, + }, + }, + }, + isComplete: false, + err: nil, + }, + "success: single chunk message": { + fields: fields{ + partialMessages: map[string][]*CaminoMatrixMessage{}, + }, + args: args{ + msg: &CaminoMatrixMessage{ + MessageEventContent: event.MessageEventContent{ + MsgType: event.MessageType(messaging.ActivitySearchResponse), + }, + Metadata: metadata.Metadata{ + RequestID: "id", + NumberOfChunks: 1, + }, + }, // last message + }, + prepare: func() { + msg := plainActivitySearchResponseMsg + msgBytes, err := msg.MarshalContent() + require.NoError(t, err) + mockedDecompressor.EXPECT().Decompress(gomock.Any()).Times(1).Return(msgBytes, nil) + }, + want: &CaminoMatrixMessage{ + Metadata: metadata.Metadata{ + RequestID: "id", + NumberOfChunks: 1, + }, + MessageEventContent: event.MessageEventContent{ + MsgType: event.MessageType(messaging.ActivitySearchResponse), + }, + Content: plainActivitySearchResponseMsg.Content, + }, + isComplete: true, + err: nil, + }, + "success: multi-chunk message": { + fields: fields{ + partialMessages: map[string][]*CaminoMatrixMessage{"id": { + // only 2 chunks because the last one is passed as the last argument triggering the call of AssembleMessage + // msgType is necessary only for 1st chunk + {MessageEventContent: event.MessageEventContent{MsgType: event.MessageType(messaging.ActivitySearchResponse)}}, {}, + }}, + }, + args: args{ + msg: &CaminoMatrixMessage{ + Metadata: metadata.Metadata{ + RequestID: "id", + NumberOfChunks: 3, + }, + }, // last message + }, + prepare: func() { + msg := plainActivitySearchResponseMsg + msgBytes, err := msg.MarshalContent() + require.NoError(t, err) + mockedDecompressor.EXPECT().Decompress(gomock.Any()).Times(1).Return(msgBytes, nil) + }, + want: &CaminoMatrixMessage{ + MessageEventContent: event.MessageEventContent{ + MsgType: event.MessageType(messaging.ActivitySearchResponse), + }, + Content: plainActivitySearchResponseMsg.Content, + }, + isComplete: true, + err: nil, + }, + } + for tc, tt := range tests { + t.Run(tc, func(t *testing.T) { + a := &messageAssembler{ + partialMessages: tt.fields.partialMessages, + decompressor: mockedDecompressor, + } + if tt.prepare != nil { + tt.prepare() + } + got, isComplete, err := a.AssembleMessage(tt.args.msg) + require.ErrorIs(t, err, tt.err) + require.Equal(t, tt.isComplete, isComplete, "AssembleMessage() isComplete = %v, expRoomID %v", isComplete, tt.isComplete) + + // Reset the response content to avoid comparisons of pb fields like sizeCache + if tt.want != nil && got != nil { + tt.want.Content.ResponseContent.ActivitySearchResponse.Reset() + got.Content.ResponseContent.ActivitySearchResponse.Reset() + } + require.Equal(t, tt.want, got, "AssembleMessage() got = %v, expRoomID %v", got, tt.want) + }) + } +} diff --git a/internal/matrix/room_handler.go b/internal/matrix/room_handler.go index 7dd9093a..bb4946a9 100644 --- a/internal/matrix/room_handler.go +++ b/internal/matrix/room_handler.go @@ -10,47 +10,63 @@ import ( "maunium.net/go/mautrix/id" ) +type Client interface { + IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) + CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (*mautrix.RespCreateRoom, error) + SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content interface{}) (*mautrix.RespSendEvent, error) + JoinedRooms(ctx context.Context) (*mautrix.RespJoinedRooms, error) + JoinedMembers(ctx context.Context, roomID id.RoomID) (*mautrix.RespJoinedMembers, error) +} + +// wrappedClient is a wrapper around mautrix.Client to abstract away concrete implementations and thus facilitate testing and mocking +type wrappedClient struct { + *mautrix.Client +} + +func (c *wrappedClient) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { + return c.StateStore.IsEncrypted(ctx, roomID) +} + +func NewClient(mautrixClient *mautrix.Client) Client { + return &wrappedClient{mautrixClient} +} + type RoomHandler interface { - GetOrCreateRoomForRecipient(context.Context, id.UserID) (id.RoomID, error) - CreateRoomAndInviteUser(context.Context, id.UserID) (id.RoomID, error) - EnableEncryptionForRoom(context.Context, id.RoomID) error - GetEncryptedRoomForRecipient(context.Context, id.UserID) (id.RoomID, bool) + GetOrCreateRoomForRecipient(ctx context.Context, recipient id.UserID) (id.RoomID, error) } type roomHandler struct { - client *mautrix.Client + client Client logger *zap.SugaredLogger rooms map[id.UserID]id.RoomID mu sync.RWMutex } -func NewRoomHandler(client *mautrix.Client, logger *zap.SugaredLogger) RoomHandler { +func NewRoomHandler(client Client, logger *zap.SugaredLogger) RoomHandler { return &roomHandler{client: client, logger: logger, rooms: make(map[id.UserID]id.RoomID)} } func (r *roomHandler) GetOrCreateRoomForRecipient(ctx context.Context, recipient id.UserID) (id.RoomID, error) { - // check if room already established with recipient - roomID, found := r.GetEncryptedRoomForRecipient(ctx, recipient) + roomID, found := r.getEncryptedRoomForRecipient(ctx, recipient) var err error // if not create room and invite recipient if !found { - roomID, err = r.CreateRoomAndInviteUser(ctx, recipient) + roomID, err = r.createRoomAndInviteUser(ctx, recipient) if err != nil { return "", err } // enable encryption for room - err = r.EnableEncryptionForRoom(ctx, roomID) + err = r.enableEncryptionForRoom(ctx, roomID) if err != nil { return "", err } } - // return room id return roomID, nil } -func (r *roomHandler) CreateRoomAndInviteUser(ctx context.Context, userID id.UserID) (id.RoomID, error) { +func (r *roomHandler) createRoomAndInviteUser(ctx context.Context, userID id.UserID) (id.RoomID, error) { r.logger.Debugf("Creating room and inviting user %v", userID) req := mautrix.ReqCreateRoom{ Visibility: "private", @@ -65,15 +81,15 @@ func (r *roomHandler) CreateRoomAndInviteUser(ctx context.Context, userID id.Use return resp.RoomID, nil } -func (r *roomHandler) EnableEncryptionForRoom(ctx context.Context, roomID id.RoomID) error { +func (r *roomHandler) enableEncryptionForRoom(ctx context.Context, roomID id.RoomID) error { r.logger.Debugf("Enabling encryption for room %s", roomID) _, err := r.client.SendStateEvent(ctx, roomID, event.StateEncryption, "", event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1}) return err } -func (r *roomHandler) GetEncryptedRoomForRecipient(ctx context.Context, recipient id.UserID) (id.RoomID, bool) { - roomID := r.fetchCachedRoom(recipient) +func (r *roomHandler) getEncryptedRoomForRecipient(ctx context.Context, recipient id.UserID) (roomID id.RoomID, found bool) { + roomID = r.fetchCachedRoom(recipient) if roomID != "" { return roomID, true } @@ -83,7 +99,7 @@ func (r *roomHandler) GetEncryptedRoomForRecipient(ctx context.Context, recipien return "", false } for _, roomID := range rooms.JoinedRooms { - if encrypted, err := r.client.StateStore.IsEncrypted(ctx, roomID); err != nil || !encrypted { + if encrypted, err := r.client.IsEncrypted(ctx, roomID); err != nil || !encrypted { continue } members, err := r.client.JoinedMembers(ctx, roomID) @@ -105,6 +121,7 @@ func (r *roomHandler) fetchCachedRoom(recipient id.UserID) id.RoomID { defer r.mu.RUnlock() return r.rooms[recipient] } + func (r *roomHandler) cacheRoom(recipient id.UserID, roomID id.RoomID) { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/matrix/room_handler_test.go b/internal/matrix/room_handler_test.go new file mode 100644 index 00000000..16d6b6c2 --- /dev/null +++ b/internal/matrix/room_handler_test.go @@ -0,0 +1,155 @@ +/* + * Copyright (C) 2024, Chain4Travel AG. All rights reserved. + * See the file LICENSE for licensing terms. + */ + +package matrix + +import ( + "context" + "errors" + "testing" + + "maunium.net/go/mautrix/event" + + "maunium.net/go/mautrix" + + "go.uber.org/mock/gomock" + + "github.com/stretchr/testify/require" + + "go.uber.org/zap" + "maunium.net/go/mautrix/id" +) + +func TestGetOrCreateRoomForRecipient(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockRoomClient := NewMockClient(mockCtrl) + defer mockCtrl.Finish() + + userID := id.UserID("userID") + roomID := id.RoomID("roomID") + newRoomID := id.RoomID("newRoomID") + + errCreateRoomFailed := errors.New("create-room-failed") + errEnableEncryptionFailed := errors.New("enable-encryption-failed") + + type fields struct { + rooms map[id.UserID]id.RoomID + } + type args struct { + recipient id.UserID + } + tests := map[string]struct { + fields fields + args args + want id.RoomID + mocks func(r *roomHandler) + err error + }{ + "err: create new encrypted room fails": { + fields: fields{ + rooms: map[id.UserID]id.RoomID{}, + }, + mocks: func(*roomHandler) { + mockRoomClient.EXPECT().JoinedRooms(gomock.Any()).Times(1).Return(&mautrix.RespJoinedRooms{JoinedRooms: []id.RoomID{roomID}}, nil) + mockRoomClient.EXPECT().IsEncrypted(gomock.Any(), roomID).Times(1).Return(false, nil) + mockRoomClient.EXPECT().CreateRoom(gomock.Any(), &mautrix.ReqCreateRoom{ + Visibility: "private", + Preset: "private_chat", + Invite: []id.UserID{userID}, + }).Times(1).Return(nil, errCreateRoomFailed) + }, + args: args{recipient: userID}, + err: errCreateRoomFailed, + }, + "err: room exists but is unencrypted so create new encrypted room created but enable encryption fails": { //nolint:dupl + fields: fields{ + rooms: map[id.UserID]id.RoomID{}, + }, + mocks: func(*roomHandler) { + mockRoomClient.EXPECT().JoinedRooms(gomock.Any()).Times(1).Return(&mautrix.RespJoinedRooms{JoinedRooms: []id.RoomID{roomID}}, nil) + mockRoomClient.EXPECT().IsEncrypted(gomock.Any(), roomID).Times(1).Return(false, nil) + mockRoomClient.EXPECT().CreateRoom(gomock.Any(), &mautrix.ReqCreateRoom{ + Visibility: "private", + Preset: "private_chat", + Invite: []id.UserID{userID}, + }).Times(1).Return(&mautrix.RespCreateRoom{RoomID: newRoomID}, nil) + mockRoomClient.EXPECT().SendStateEvent(gomock.Any(), newRoomID, event.StateEncryption, "", + event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1}).Times(1).Return(nil, errEnableEncryptionFailed) + }, + args: args{recipient: userID}, + err: errEnableEncryptionFailed, + }, + "success: room already established and cached": { + fields: fields{ + rooms: map[id.UserID]id.RoomID{userID: roomID}, + }, + args: args{recipient: userID}, + want: roomID, + }, + "success: room already established but not cached": { + fields: fields{ + rooms: map[id.UserID]id.RoomID{}, + }, + mocks: func(*roomHandler) { + mockRoomClient.EXPECT().JoinedRooms(gomock.Any()).Times(1).Return(&mautrix.RespJoinedRooms{JoinedRooms: []id.RoomID{roomID}}, nil) + mockRoomClient.EXPECT().IsEncrypted(gomock.Any(), roomID).Times(1).Return(true, nil) + mockRoomClient.EXPECT().JoinedMembers(gomock.Any(), roomID).Times(1).Return(&mautrix.RespJoinedMembers{Joined: map[id.UserID]mautrix.JoinedMember{userID: {}}}, nil) + }, + args: args{recipient: userID}, + want: roomID, + }, + "success: room exists but recipient is not member so create new encrypted room created and invite user": { //nolint:dupl + fields: fields{ + rooms: map[id.UserID]id.RoomID{}, + }, + mocks: func(*roomHandler) { + mockRoomClient.EXPECT().JoinedRooms(gomock.Any()).Times(1).Return(&mautrix.RespJoinedRooms{JoinedRooms: []id.RoomID{}}, nil) + mockRoomClient.EXPECT().CreateRoom(gomock.Any(), &mautrix.ReqCreateRoom{ + Visibility: "private", + Preset: "private_chat", + Invite: []id.UserID{userID}, + }).Times(1).Return(&mautrix.RespCreateRoom{RoomID: newRoomID}, nil) + mockRoomClient.EXPECT().SendStateEvent(gomock.Any(), newRoomID, event.StateEncryption, "", + event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1}).Times(1).Return(nil, nil) + }, + args: args{recipient: userID}, + want: newRoomID, + }, + "success: room exists but is unencrypted so create new encrypted room created and invite user": { //nolint:dupl + fields: fields{ + rooms: map[id.UserID]id.RoomID{}, + }, + mocks: func(*roomHandler) { + mockRoomClient.EXPECT().JoinedRooms(gomock.Any()).Times(1).Return(&mautrix.RespJoinedRooms{JoinedRooms: []id.RoomID{roomID}}, nil) + mockRoomClient.EXPECT().IsEncrypted(gomock.Any(), roomID).Times(1).Return(false, nil) + mockRoomClient.EXPECT().CreateRoom(gomock.Any(), &mautrix.ReqCreateRoom{ + Visibility: "private", + Preset: "private_chat", + Invite: []id.UserID{userID}, + }).Times(1).Return(&mautrix.RespCreateRoom{RoomID: newRoomID}, nil) + mockRoomClient.EXPECT().SendStateEvent(gomock.Any(), newRoomID, event.StateEncryption, "", + event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1}).Times(1).Return(nil, nil) + }, + args: args{recipient: userID}, + want: newRoomID, + }, + } + for tc, tt := range tests { + t.Run(tc, func(t *testing.T) { + r := &roomHandler{ + client: mockRoomClient, + logger: zap.NewNop().Sugar(), + rooms: tt.fields.rooms, + } + if tt.mocks != nil { + tt.mocks(r) + } + + got, err := r.GetOrCreateRoomForRecipient(context.Background(), tt.args.recipient) + require.ErrorIs(t, err, tt.err, "GetOrCreateRoomForRecipient() error = %w, wantErr %w", err, tt.err) + require.Equal(t, got, tt.want, "GetOrCreateRoomForRecipient() got = %v, expRoomID %v", got, tt.want) + }) + } +} diff --git a/internal/matrix/types.go b/internal/matrix/types.go index e16d19c2..dcf91de2 100644 --- a/internal/matrix/types.go +++ b/internal/matrix/types.go @@ -1,9 +1,16 @@ package matrix import ( + accommodationv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/accommodation/v1alpha" + activityv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/activity/v1alpha" + bookv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/book/v1alpha" + networkv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/network/v1alpha" + partnerv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/partner/v1alpha" + pingv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/ping/v1alpha" + transportv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/transport/v1alpha" "github.com/chain4travel/camino-messenger-bot/internal/messaging" "github.com/chain4travel/camino-messenger-bot/internal/metadata" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "maunium.net/go/mautrix/event" ) @@ -15,7 +22,7 @@ type CaminoMatrixMessage struct { Metadata metadata.Metadata `json:"metadata"` } -type ByChunkIndex []CaminoMatrixMessage +type ByChunkIndex []*CaminoMatrixMessage func (b ByChunkIndex) Len() int { return len(b) } func (b ByChunkIndex) Less(i, j int) bool { @@ -26,49 +33,71 @@ func (b ByChunkIndex) Swap(i, j int) { b[i], b[j] = b[j], b[i] } func (m *CaminoMatrixMessage) UnmarshalContent(src []byte) error { switch messaging.MessageType(m.MsgType) { case messaging.ActivityProductListRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.ActivityProductListRequest) + m.Content.RequestContent.ActivityProductListRequest = &activityv1alpha.ActivityProductListRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.ActivityProductListRequest) case messaging.ActivityProductListResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.ActivityProductListResponse) + m.Content.ResponseContent.ActivityProductListResponse = &activityv1alpha.ActivityProductListResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.ActivityProductListResponse) case messaging.ActivitySearchRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.ActivitySearchRequest) + m.Content.RequestContent.ActivitySearchRequest = &activityv1alpha.ActivitySearchRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.ActivitySearchRequest) case messaging.ActivitySearchResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.ActivitySearchResponse) + m.Content.ResponseContent.ActivitySearchResponse = &activityv1alpha.ActivitySearchResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.ActivitySearchResponse) case messaging.AccommodationProductInfoRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.AccommodationProductInfoRequest) + m.Content.RequestContent.AccommodationProductInfoRequest = &accommodationv1alpha.AccommodationProductInfoRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.AccommodationProductInfoRequest) case messaging.AccommodationProductInfoResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.AccommodationProductInfoResponse) + m.Content.ResponseContent.AccommodationProductInfoResponse = &accommodationv1alpha.AccommodationProductInfoResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.AccommodationProductInfoResponse) case messaging.AccommodationProductListRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.AccommodationProductListRequest) + m.Content.RequestContent.AccommodationProductListRequest = &accommodationv1alpha.AccommodationProductListRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.AccommodationProductListRequest) case messaging.AccommodationProductListResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.AccommodationProductListResponse) + m.Content.ResponseContent.AccommodationProductListResponse = &accommodationv1alpha.AccommodationProductListResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.AccommodationProductListResponse) case messaging.AccommodationSearchRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.AccommodationSearchRequest) + m.Content.RequestContent.AccommodationSearchRequest = &accommodationv1alpha.AccommodationSearchRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.AccommodationSearchRequest) case messaging.AccommodationSearchResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.AccommodationSearchResponse) + m.Content.ResponseContent.AccommodationSearchResponse = &accommodationv1alpha.AccommodationSearchResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.AccommodationSearchResponse) case messaging.GetNetworkFeeRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.GetNetworkFeeRequest) + m.Content.RequestContent.GetNetworkFeeRequest = &networkv1alpha.GetNetworkFeeRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.GetNetworkFeeRequest) case messaging.GetNetworkFeeResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.GetNetworkFeeResponse) + m.Content.ResponseContent.GetNetworkFeeResponse = &networkv1alpha.GetNetworkFeeResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.GetNetworkFeeResponse) case messaging.GetPartnerConfigurationRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.GetPartnerConfigurationRequest) + m.Content.RequestContent.GetPartnerConfigurationRequest = &partnerv1alpha.GetPartnerConfigurationRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.GetPartnerConfigurationRequest) case messaging.GetPartnerConfigurationResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.GetPartnerConfigurationResponse) + m.Content.ResponseContent.GetPartnerConfigurationResponse = &partnerv1alpha.GetPartnerConfigurationResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.GetPartnerConfigurationResponse) case messaging.MintRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.MintRequest) + m.Content.RequestContent.MintRequest = &bookv1alpha.MintRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.MintRequest) case messaging.MintResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.MintResponse) + m.Content.ResponseContent.MintResponse = &bookv1alpha.MintResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.MintResponse) case messaging.ValidationRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.ValidationRequest) + m.Content.RequestContent.ValidationRequest = &bookv1alpha.ValidationRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.ValidationRequest) case messaging.ValidationResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.ValidationResponse) + m.Content.ResponseContent.ValidationResponse = &bookv1alpha.ValidationResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.ValidationResponse) case messaging.PingRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.PingRequest) + m.Content.RequestContent.PingRequest = &pingv1alpha.PingRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.PingRequest) case messaging.PingResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.PingResponse) + m.Content.ResponseContent.PingResponse = &pingv1alpha.PingResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.PingResponse) case messaging.TransportSearchRequest: - return proto.Unmarshal(src, &m.Content.RequestContent.TransportSearchRequest) + m.Content.RequestContent.TransportSearchRequest = &transportv1alpha.TransportSearchRequest{} + return proto.Unmarshal(src, m.Content.RequestContent.TransportSearchRequest) case messaging.TransportSearchResponse: - return proto.Unmarshal(src, &m.Content.ResponseContent.TransportSearchResponse) + m.Content.ResponseContent.TransportSearchResponse = &transportv1alpha.TransportSearchResponse{} + return proto.Unmarshal(src, m.Content.ResponseContent.TransportSearchResponse) default: return messaging.ErrUnknownMessageType } diff --git a/internal/messaging/messenger.go b/internal/messaging/messenger.go index d443f777..42f4598c 100644 --- a/internal/messaging/messenger.go +++ b/internal/messaging/messenger.go @@ -1,3 +1,8 @@ +/* + * Copyright (C) 2024, Chain4Travel AG. All rights reserved. + * See the file LICENSE for licensing terms. + */ + package messaging import ( diff --git a/internal/messaging/mock_list_grpc.pb.go b/internal/messaging/mock_list_grpc.pb.go new file mode 100644 index 00000000..c7b77eb3 --- /dev/null +++ b/internal/messaging/mock_list_grpc.pb.go @@ -0,0 +1,62 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/activity/v1alpha/activityv1alphagrpc (interfaces: ActivityProductListServiceClient) +// +// Generated by this command: +// +// mockgen -package=messaging -destination=internal/messaging/mock_list_grpc.pb.go buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/activity/v1alpha/activityv1alphagrpc ActivityProductListServiceClient +// + +// Package messaging is a generated GoMock package. +package messaging + +import ( + context "context" + reflect "reflect" + + activityv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/activity/v1alpha" + gomock "go.uber.org/mock/gomock" + grpc "google.golang.org/grpc" +) + +// MockActivityProductListServiceClient is a mock of ActivityProductListServiceClient interface. +type MockActivityProductListServiceClient struct { + ctrl *gomock.Controller + recorder *MockActivityProductListServiceClientMockRecorder +} + +// MockActivityProductListServiceClientMockRecorder is the mock recorder for MockActivityProductListServiceClient. +type MockActivityProductListServiceClientMockRecorder struct { + mock *MockActivityProductListServiceClient +} + +// NewMockActivityProductListServiceClient creates a new mock instance. +func NewMockActivityProductListServiceClient(ctrl *gomock.Controller) *MockActivityProductListServiceClient { + mock := &MockActivityProductListServiceClient{ctrl: ctrl} + mock.recorder = &MockActivityProductListServiceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockActivityProductListServiceClient) EXPECT() *MockActivityProductListServiceClientMockRecorder { + return m.recorder +} + +// ActivityProductList mocks base method. +func (m *MockActivityProductListServiceClient) ActivityProductList(arg0 context.Context, arg1 *activityv1alpha.ActivityProductListRequest, arg2 ...grpc.CallOption) (*activityv1alpha.ActivityProductListResponse, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ActivityProductList", varargs...) + ret0, _ := ret[0].(*activityv1alpha.ActivityProductListResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ActivityProductList indicates an expected call of ActivityProductList. +func (mr *MockActivityProductListServiceClientMockRecorder) ActivityProductList(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityProductList", reflect.TypeOf((*MockActivityProductListServiceClient)(nil).ActivityProductList), varargs...) +} diff --git a/internal/messaging/mock_messenger.go b/internal/messaging/mock_messenger.go new file mode 100644 index 00000000..126f9941 --- /dev/null +++ b/internal/messaging/mock_messenger.go @@ -0,0 +1,111 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/internal/messaging (interfaces: Messenger) +// +// Generated by this command: +// +// mockgen -package=messaging -destination=internal/messaging/mock_messenger.go github.com/chain4travel/camino-messenger-bot/internal/messaging Messenger +// + +// Package messaging is a generated GoMock package. +package messaging + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockMessenger is a mock of Messenger interface. +type MockMessenger struct { + ctrl *gomock.Controller + recorder *MockMessengerMockRecorder +} + +// MockMessengerMockRecorder is the mock recorder for MockMessenger. +type MockMessengerMockRecorder struct { + mock *MockMessenger +} + +// NewMockMessenger creates a new mock instance. +func NewMockMessenger(ctrl *gomock.Controller) *MockMessenger { + mock := &MockMessenger{ctrl: ctrl} + mock.recorder = &MockMessengerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMessenger) EXPECT() *MockMessengerMockRecorder { + return m.recorder +} + +// Checkpoint mocks base method. +func (m *MockMessenger) Checkpoint() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Checkpoint") + ret0, _ := ret[0].(string) + return ret0 +} + +// Checkpoint indicates an expected call of Checkpoint. +func (mr *MockMessengerMockRecorder) Checkpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Checkpoint", reflect.TypeOf((*MockMessenger)(nil).Checkpoint)) +} + +// Inbound mocks base method. +func (m *MockMessenger) Inbound() chan Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Inbound") + ret0, _ := ret[0].(chan Message) + return ret0 +} + +// Inbound indicates an expected call of Inbound. +func (mr *MockMessengerMockRecorder) Inbound() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Inbound", reflect.TypeOf((*MockMessenger)(nil).Inbound)) +} + +// SendAsync mocks base method. +func (m *MockMessenger) SendAsync(arg0 context.Context, arg1 Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendAsync", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendAsync indicates an expected call of SendAsync. +func (mr *MockMessengerMockRecorder) SendAsync(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendAsync", reflect.TypeOf((*MockMessenger)(nil).SendAsync), arg0, arg1) +} + +// StartReceiver mocks base method. +func (m *MockMessenger) StartReceiver() (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartReceiver") + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StartReceiver indicates an expected call of StartReceiver. +func (mr *MockMessengerMockRecorder) StartReceiver() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartReceiver", reflect.TypeOf((*MockMessenger)(nil).StartReceiver)) +} + +// StopReceiver mocks base method. +func (m *MockMessenger) StopReceiver() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopReceiver") + ret0, _ := ret[0].(error) + return ret0 +} + +// StopReceiver indicates an expected call of StopReceiver. +func (mr *MockMessengerMockRecorder) StopReceiver() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopReceiver", reflect.TypeOf((*MockMessenger)(nil).StopReceiver)) +} diff --git a/internal/messaging/mock_service_registry.go b/internal/messaging/mock_service_registry.go new file mode 100644 index 00000000..24e8f0d2 --- /dev/null +++ b/internal/messaging/mock_service_registry.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/chain4travel/camino-messenger-bot/internal/messaging (interfaces: ServiceRegistry) +// +// Generated by this command: +// +// mockgen -package=messaging -destination=internal/messaging/mock_service_registry.go github.com/chain4travel/camino-messenger-bot/internal/messaging ServiceRegistry +// + +// Package messaging is a generated GoMock package. +package messaging + +import ( + reflect "reflect" + + config "github.com/chain4travel/camino-messenger-bot/config" + client "github.com/chain4travel/camino-messenger-bot/internal/rpc/client" + gomock "go.uber.org/mock/gomock" +) + +// MockServiceRegistry is a mock of ServiceRegistry interface. +type MockServiceRegistry struct { + ctrl *gomock.Controller + recorder *MockServiceRegistryMockRecorder +} + +// MockServiceRegistryMockRecorder is the mock recorder for MockServiceRegistry. +type MockServiceRegistryMockRecorder struct { + mock *MockServiceRegistry +} + +// NewMockServiceRegistry creates a new mock instance. +func NewMockServiceRegistry(ctrl *gomock.Controller) *MockServiceRegistry { + mock := &MockServiceRegistry{ctrl: ctrl} + mock.recorder = &MockServiceRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockServiceRegistry) EXPECT() *MockServiceRegistryMockRecorder { + return m.recorder +} + +// GetService mocks base method. +func (m *MockServiceRegistry) GetService(arg0 MessageType) (Service, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetService", arg0) + ret0, _ := ret[0].(Service) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetService indicates an expected call of GetService. +func (mr *MockServiceRegistryMockRecorder) GetService(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockServiceRegistry)(nil).GetService), arg0) +} + +// RegisterServices mocks base method. +func (m *MockServiceRegistry) RegisterServices(arg0 config.SupportedRequestTypesFlag, arg1 *client.RPCClient) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterServices", arg0, arg1) +} + +// RegisterServices indicates an expected call of RegisterServices. +func (mr *MockServiceRegistryMockRecorder) RegisterServices(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterServices", reflect.TypeOf((*MockServiceRegistry)(nil).RegisterServices), arg0, arg1) +} diff --git a/internal/messaging/noop_response_handler.go b/internal/messaging/noop_response_handler.go index 481bbe50..4967dfde 100644 --- a/internal/messaging/noop_response_handler.go +++ b/internal/messaging/noop_response_handler.go @@ -1,5 +1,5 @@ /* - * Copyright (C) 2022-2023, Chain4Travel AG. All rights reserved. + * Copyright (C) 2024, Chain4Travel AG. All rights reserved. * See the file LICENSE for licensing terms. */ @@ -11,8 +11,7 @@ import ( var _ ResponseHandler = (*NoopResponseHandler)(nil) -type NoopResponseHandler struct { -} +type NoopResponseHandler struct{} -func (n NoopResponseHandler) HandleResponse(context.Context, MessageType, *RequestContent, *ResponseContent) { +func (NoopResponseHandler) HandleResponse(context.Context, MessageType, *RequestContent, *ResponseContent) { } diff --git a/internal/messaging/processor.go b/internal/messaging/processor.go index da180dae..f8080668 100644 --- a/internal/messaging/processor.go +++ b/internal/messaging/processor.go @@ -4,23 +4,19 @@ import ( "context" "errors" "fmt" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "sync" "time" "github.com/chain4travel/camino-messenger-bot/config" "github.com/chain4travel/camino-messenger-bot/internal/metadata" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "google.golang.org/grpc" grpc_metadata "google.golang.org/grpc/metadata" ) -type InvalidMessageError struct { - error -} - var ( _ Processor = (*processor)(nil) @@ -29,20 +25,21 @@ var ( ErrOnlyRequestMessagesAllowed = errors.New("only request messages allowed") ErrUnsupportedRequestType = errors.New("unsupported request type") ErrMissingRecipient = errors.New("missing recipient") + ErrExceededResponseTimeout = errors.New("response exceeded configured timeout") ) type MsgHandler interface { - Request(ctx context.Context, msg Message) (Message, error) - Respond(msg Message) error - Forward(msg Message) + Request(ctx context.Context, msg *Message) (*Message, error) + Respond(msg *Message) error + Forward(msg *Message) } type Processor interface { metadata.Checkpoint MsgHandler SetUserID(userID string) Start(ctx context.Context) - ProcessInbound(message Message) error - ProcessOutbound(ctx context.Context, message Message) (Message, error) + ProcessInbound(message *Message) error + ProcessOutbound(ctx context.Context, message *Message) (*Message, error) } type processor struct { @@ -54,8 +51,8 @@ type processor struct { timeout time.Duration // timeout after which a request is considered failed mu sync.Mutex - responseChannels map[string]chan Message - serviceRegistry *ServiceRegistry + responseChannels map[string]chan *Message + serviceRegistry ServiceRegistry responseHandler ResponseHandler } @@ -63,18 +60,18 @@ func (p *processor) SetUserID(userID string) { p.userID = userID } -func (p *processor) Checkpoint() string { +func (*processor) Checkpoint() string { return "processor" } -func NewProcessor(messenger Messenger, logger *zap.SugaredLogger, cfg config.ProcessorConfig, registry *ServiceRegistry, responseHandler ResponseHandler) Processor { +func NewProcessor(messenger Messenger, logger *zap.SugaredLogger, cfg config.ProcessorConfig, registry ServiceRegistry, responseHandler ResponseHandler) Processor { return &processor{ cfg: cfg, messenger: messenger, logger: logger, tracer: otel.GetTracerProvider().Tracer(""), timeout: time.Duration(cfg.Timeout) * time.Millisecond, // for now applies to all request types - responseChannels: make(map[string]chan Message), + responseChannels: make(map[string]chan *Message), serviceRegistry: registry, responseHandler: responseHandler, } @@ -86,7 +83,7 @@ func (p *processor) Start(ctx context.Context) { case msgEvent := <-p.messenger.Inbound(): p.logger.Debug("Processing msg event of type: ", msgEvent.Type) go func() { - err := p.ProcessInbound(msgEvent) + err := p.ProcessInbound(&msgEvent) if err != nil { p.logger.Warnf("could not process message: %v", err) } @@ -98,7 +95,7 @@ func (p *processor) Start(ctx context.Context) { } } -func (p *processor) ProcessInbound(msg Message) error { +func (p *processor) ProcessInbound(msg *Message) error { if p.userID == "" { return ErrUserIDNotSet } @@ -110,26 +107,25 @@ func (p *processor) ProcessInbound(msg Message) error { p.Forward(msg) return nil default: - return InvalidMessageError{ErrUnknownMessageCategory} + return ErrUnknownMessageCategory } } else { return nil // ignore own outbound messages } } -func (p *processor) ProcessOutbound(ctx context.Context, msg Message) (Message, error) { +func (p *processor) ProcessOutbound(ctx context.Context, msg *Message) (*Message, error) { msg.Metadata.Sender = p.userID if msg.Type.Category() == Request { // only request messages (received by are processed return p.Request(ctx, msg) // forward request msg to matrix - } else { - p.logger.Debugf("Ignoring any non-request message from sender other than: %s ", p.userID) - return Message{}, ErrOnlyRequestMessagesAllowed // ignore msg } + p.logger.Debugf("Ignoring any non-request message from sender other than: %s ", p.userID) + return nil, ErrOnlyRequestMessagesAllowed // ignore msg } -func (p *processor) Request(ctx context.Context, msg Message) (Message, error) { +func (p *processor) Request(ctx context.Context, msg *Message) (*Message, error) { p.logger.Debug("Sending outbound request message") - responseChan := make(chan Message) + responseChan := make(chan *Message) p.mu.Lock() p.responseChannels[msg.Metadata.RequestID] = responseChan p.mu.Unlock() @@ -143,14 +139,14 @@ func (p *processor) Request(ctx context.Context, msg Message) (Message, error) { defer cancel() if msg.Metadata.Recipient == "" { // TODO: add address validation - return Message{}, ErrMissingRecipient + return nil, ErrMissingRecipient } - msg.Metadata.Cheques = nil //TODO issue and attach cheques + msg.Metadata.Cheques = nil // TODO issue and attach cheques ctx, span := p.tracer.Start(ctx, "processor.Request", trace.WithAttributes(attribute.String("type", string(msg.Type)))) defer span.End() - err := p.messenger.SendAsync(ctx, msg) + err := p.messenger.SendAsync(ctx, *msg) if err != nil { - return Message{}, err + return nil, err } ctx, responseSpan := p.tracer.Start(ctx, "processor.AwaitResponse", trace.WithSpanKind(trace.SpanKindConsumer), trace.WithAttributes(attribute.String("type", string(msg.Type)))) defer responseSpan.End() @@ -161,14 +157,13 @@ func (p *processor) Request(ctx context.Context, msg Message) (Message, error) { p.responseHandler.HandleResponse(ctx, msg.Type, &msg.Content.RequestContent, &response.Content.ResponseContent) return response, nil } - //p.logger.Debugf("Ignoring response message with request id: %s, expecting: %s", response.Metadata.RequestID, msg.Metadata.RequestID) case <-ctx.Done(): - return Message{}, fmt.Errorf("response exceeded configured timeout of %v seconds for request: %s", p.timeout, msg.Metadata.RequestID) + return nil, fmt.Errorf("%w of %v seconds for request: %s", ErrExceededResponseTimeout, p.timeout, msg.Metadata.RequestID) } } } -func (p *processor) Respond(msg Message) error { +func (p *processor) Respond(msg *Message) error { traceID, err := trace.TraceIDFromHex(msg.Metadata.RequestID) if err != nil { p.logger.Warnf("failed to parse traceID from hex [requestID:%s]: %v", msg.Metadata.RequestID, err) @@ -179,7 +174,7 @@ func (p *processor) Respond(msg Message) error { var service Service var supported bool if service, supported = p.serviceRegistry.GetService(msg.Type); !supported { - return fmt.Errorf("%v: %s", ErrUnsupportedRequestType, msg.Type) + return fmt.Errorf("%w: %s", ErrUnsupportedRequestType, msg.Type) } md := &msg.Metadata @@ -194,7 +189,7 @@ func (p *processor) Respond(msg Message) error { response, msgType, err := service.Call(ctx, &msg.Content.RequestContent, grpc.Header(&header)) cspan.End() if err != nil { - return err //TODO handle error and return a response message + return err // TODO handle error and return a response message } err = md.FromGrpcMD(header) @@ -202,24 +197,24 @@ func (p *processor) Respond(msg Message) error { p.logger.Infof("error extracting metadata for request: %s", md.RequestID) } - p.responseHandler.HandleResponse(ctx, msgType, &msg.Content.RequestContent, &response) + p.responseHandler.HandleResponse(ctx, msgType, &msg.Content.RequestContent, response) responseMsg := Message{ Type: msgType, Content: MessageContent{ - ResponseContent: response, + ResponseContent: *response, }, Metadata: *md, } return p.messenger.SendAsync(ctx, responseMsg) } -func (p *processor) Forward(msg Message) { +func (p *processor) Forward(msg *Message) { p.logger.Debugf("Forwarding outbound response message: %s", msg.Metadata.RequestID) p.mu.Lock() + defer p.mu.Unlock() responseChan, ok := p.responseChannels[msg.Metadata.RequestID] if ok { responseChan <- msg close(responseChan) } - p.mu.Unlock() } diff --git a/internal/messaging/processor_test.go b/internal/messaging/processor_test.go new file mode 100644 index 00000000..c5d173c7 --- /dev/null +++ b/internal/messaging/processor_test.go @@ -0,0 +1,360 @@ +/* + * Copyright (C) 2024, Chain4Travel AG. All rights reserved. + * See the file LICENSE for licensing terms. + */ + +package messaging + +import ( + "context" + "errors" + "testing" + "time" + + "google.golang.org/grpc" + + "go.uber.org/mock/gomock" + + "github.com/chain4travel/camino-messenger-bot/internal/metadata" + + "github.com/stretchr/testify/require" + + "github.com/chain4travel/camino-messenger-bot/config" + "go.uber.org/zap" +) + +var ( + userID = "userID" + anotherUserID = "anotherUserID" + requestID = "requestID" + errSomeError = errors.New("some error") +) + +func TestProcessInbound(t *testing.T) { + responseMessage := Message{Type: ActivityProductListResponse, Metadata: metadata.Metadata{RequestID: requestID, Sender: anotherUserID}} + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockServiceRegistry := NewMockServiceRegistry(mockCtrl) + mockActivityProductListServiceClient := NewMockActivityProductListServiceClient(mockCtrl) + mockMessenger := NewMockMessenger(mockCtrl) + + type fields struct { + cfg config.ProcessorConfig + messenger Messenger + serviceRegistry ServiceRegistry + responseHandler ResponseHandler + } + type args struct { + msg *Message + } + tests := map[string]struct { + fields fields + args args + prepare func(p *processor) + err error + assert func(t *testing.T, p *processor) + }{ + "err: user id not set": { + fields: fields{ + cfg: config.ProcessorConfig{}, + }, + err: ErrUserIDNotSet, + }, + "err: invalid message type": { + fields: fields{ + cfg: config.ProcessorConfig{}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + }, + args: args{ + msg: &Message{Type: "invalid", Metadata: metadata.Metadata{Sender: anotherUserID}}, + }, + err: ErrUnknownMessageCategory, + }, + "err: unsupported request message": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).Return(nil, false) + }, + args: args{ + msg: &Message{Type: ActivitySearchRequest, Metadata: metadata.Metadata{Sender: anotherUserID}}, + }, + err: ErrUnsupportedRequestType, + }, + "ignore own outbound messages": { + fields: fields{ + cfg: config.ProcessorConfig{}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + }, + args: args{ + msg: &Message{Metadata: metadata.Metadata{Sender: userID}}, + }, + err: nil, // no error, msg will be just ignored + }, + "err: process request message failed": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockActivityProductListServiceClient.EXPECT().ActivityProductList(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil, nil) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).Times(1).Return(activityProductListService{client: mockActivityProductListServiceClient}, true) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(1).Return(errSomeError) + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest, Metadata: metadata.Metadata{Sender: anotherUserID}}, + }, + err: errSomeError, + }, + "success: process request message": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockActivityProductListServiceClient.EXPECT().ActivityProductList(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(nil, nil) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).Times(1).Return(activityProductListService{client: mockActivityProductListServiceClient}, true) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest, Metadata: metadata.Metadata{Sender: anotherUserID}}, + }, + }, + "success: process response message": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + prepare: func(p *processor) { + p.responseChannels[requestID] = make(chan *Message, 1) + p.SetUserID(userID) + }, + args: args{ + msg: &responseMessage, + }, + assert: func(t *testing.T, p *processor) { + msgReceived := <-p.responseChannels[requestID] + require.Equal(t, responseMessage, *msgReceived) + }, + }, + } + for tc, tt := range tests { + t.Run(tc, func(t *testing.T) { + p := NewProcessor(tt.fields.messenger, zap.NewNop().Sugar(), tt.fields.cfg, tt.fields.serviceRegistry, tt.fields.responseHandler) + if tt.prepare != nil { + tt.prepare(p.(*processor)) + } + err := p.ProcessInbound(tt.args.msg) + require.ErrorIs(t, err, tt.err) + + if tt.assert != nil { + tt.assert(t, p.(*processor)) + } + }) + } +} + +func TestProcessOutbound(t *testing.T) { + productListResponse := &Message{Type: ActivityProductListResponse, Metadata: metadata.Metadata{RequestID: requestID}} + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockServiceRegistry := NewMockServiceRegistry(mockCtrl) + mockMessenger := NewMockMessenger(mockCtrl) + + type fields struct { + cfg config.ProcessorConfig + messenger Messenger + serviceRegistry ServiceRegistry + responseHandler ResponseHandler + } + type args struct { + msg *Message + } + tests := map[string]struct { + fields fields + args args + want *Message + err error + prepare func(p *processor) + writeResponseToChannel func(p *processor) + }{ + "err: non-request outbound message": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + args: args{ + msg: &Message{Type: ActivityProductListResponse}, + }, + err: ErrOnlyRequestMessagesAllowed, + }, + "err: missing recipient": { + fields: fields{ + cfg: config.ProcessorConfig{}, + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + }, + err: ErrMissingRecipient, + }, + "err: awaiting-response-timeout exceeded": { + fields: fields{ + cfg: config.ProcessorConfig{Timeout: 10}, // 10ms + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest, Metadata: metadata.Metadata{Recipient: anotherUserID}}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + err: ErrExceededResponseTimeout, + }, + "err: while sending request": { + fields: fields{ + cfg: config.ProcessorConfig{Timeout: 100}, // 10ms + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest, Metadata: metadata.Metadata{Recipient: anotherUserID}}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(1).Return(errSomeError) + }, + err: errSomeError, + }, + "success: response before timeout": { + fields: fields{ + cfg: config.ProcessorConfig{Timeout: 500}, // long enough timeout for response to be received + serviceRegistry: mockServiceRegistry, + responseHandler: NoopResponseHandler{}, + messenger: mockMessenger, + }, + args: args{ + msg: &Message{Type: ActivityProductListRequest, Metadata: metadata.Metadata{Recipient: anotherUserID, RequestID: requestID}}, + }, + prepare: func(p *processor) { + p.SetUserID(userID) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(1).Return(nil) + }, + writeResponseToChannel: func(p *processor) { + done := func() bool { + p.mu.Lock() + defer p.mu.Unlock() + if _, ok := p.responseChannels[requestID]; ok { + p.responseChannels[requestID] <- productListResponse + return true + } + return false + } + for { + // wait until the response channel is created + if done() { + break + } + } + }, + want: productListResponse, + }, + } + + for tc, tt := range tests { + t.Run(tc, func(t *testing.T) { + p := NewProcessor(tt.fields.messenger, zap.NewNop().Sugar(), tt.fields.cfg, tt.fields.serviceRegistry, tt.fields.responseHandler) + if tt.prepare != nil { + tt.prepare(p.(*processor)) + } + if tt.writeResponseToChannel != nil { + go tt.writeResponseToChannel(p.(*processor)) + } + got, err := p.ProcessOutbound(context.Background(), tt.args.msg) + + require.ErrorIs(t, err, tt.err) + require.Equal(t, tt.want, got) + }) + } +} + +type dummyService struct{} + +func (d dummyService) Call(context.Context, *RequestContent, ...grpc.CallOption) (*ResponseContent, MessageType, error) { + return &ResponseContent{}, "", nil +} + +func TestStart(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockServiceRegistry := NewMockServiceRegistry(mockCtrl) + mockMessenger := NewMockMessenger(mockCtrl) + mockServiceRegistry.EXPECT().GetService(gomock.Any()).AnyTimes().Return(dummyService{}, true) + mockMessenger.EXPECT().SendAsync(gomock.Any(), gomock.Any()).Times(2).Return(nil) + + t.Run("start processor and accept messages", func(t *testing.T) { + cfg := config.ProcessorConfig{} + serviceRegistry := mockServiceRegistry + responseHandler := NoopResponseHandler{} + messenger := mockMessenger + + ch := make(chan Message, 5) + // incoming messages + { + // msg without sender + ch <- Message{Metadata: metadata.Metadata{}} + // msg with sender == userID + ch <- Message{Metadata: metadata.Metadata{Sender: userID}} + // msg with sender == userID but without valid msgType + ch <- Message{Metadata: metadata.Metadata{Sender: anotherUserID}} + // msg with sender == userID and valid msgType + ch <- Message{ + Type: ActivityProductListRequest, + Metadata: metadata.Metadata{Sender: anotherUserID}, + } + // 2nd msg with sender == userID and valid msgType + ch <- Message{ + Type: ActivitySearchRequest, + Metadata: metadata.Metadata{Sender: anotherUserID}, + } + } + // mocks + mockMessenger.EXPECT().Inbound().AnyTimes().Return(ch) + + ctx, cancel := context.WithCancel(context.Background()) + p := NewProcessor(messenger, zap.NewNop().Sugar(), cfg, serviceRegistry, responseHandler) + p.SetUserID(userID) + go p.Start(ctx) + + time.Sleep(1 * time.Second) + cancel() + }) +} diff --git a/internal/messaging/response_handler.go b/internal/messaging/response_handler.go index 4fcd8084..f43d3a5e 100644 --- a/internal/messaging/response_handler.go +++ b/internal/messaging/response_handler.go @@ -6,10 +6,13 @@ package messaging import ( - typesv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/types/v1alpha" "context" "errors" "fmt" + "strconv" + + typesv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/types/v1alpha" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/hypersdk/chain" "github.com/ava-labs/hypersdk/codec" @@ -17,7 +20,6 @@ import ( "github.com/chain4travel/caminotravelvm/actions" "github.com/chain4travel/caminotravelvm/consts" "go.uber.org/zap" - "strconv" ) var _ ResponseHandler = (*TvmResponseHandler)(nil) @@ -33,67 +35,81 @@ type TvmResponseHandler struct { func (h *TvmResponseHandler) HandleResponse(ctx context.Context, msgType MessageType, request *RequestContent, response *ResponseContent) { switch msgType { case MintRequest: // distributor will post-process a mint request to buy the returned NFT - if response.MintResponse.Header == nil { - response.MintResponse.Header = &typesv1alpha.ResponseHeader{} - } - if response.MintTransactionId == "" { - addErrorToResponseHeader(response, "missing mint transaction id") + if h.handleMintRequest(ctx, response) { return } - mintID, err := ids.FromString(response.MintTransactionId) - if err != nil { - addErrorToResponseHeader(response, fmt.Sprintf("error parsing mint transaction id: %v", err)) + case MintResponse: // provider will act upon receiving a mint response by minting an NFT + if h.handleMintResponse(ctx, response, request) { return } + } +} - success, txID, err := h.tvmClient.SendTxAndWait(ctx, transferNFTAction(h.tvmClient.Address(), mintID)) - if err != nil { - errMessage := fmt.Sprintf("error buying NFT: %v", err) - if errors.Is(err, context.DeadlineExceeded) { - errMessage = fmt.Sprintf("%v: %v", tvm.ErrAwaitTxConfirmationTimeout, h.tvmClient.Timeout) - } - addErrorToResponseHeader(response, errMessage) - return - } - if !success { - addErrorToResponseHeader(response, "buying NFT failed") - return +func (h *TvmResponseHandler) handleMintResponse(ctx context.Context, response *ResponseContent, request *RequestContent) bool { + owner := h.tvmClient.Address() + if response.MintResponse.Header == nil { + response.MintResponse.Header = &typesv1alpha.ResponseHeader{} + } + buyer, err := codec.ParseAddressBech32(consts.HRP, request.MintRequest.BuyerAddress) + if err != nil { + addErrorToResponseHeader(response, fmt.Sprintf("error parsing buyer address: %v", err)) + return true + } + price, err := strconv.Atoi(response.MintResponse.Price.Value) + if err != nil { + addErrorToResponseHeader(response, fmt.Sprintf("error parsing price value: %v", err)) + return true + } + success, txID, err := h.tvmClient.SendTxAndWait(ctx, createNFTAction(owner, buyer, uint64(response.MintResponse.BuyableUntil.Seconds), uint64(price), response.MintResponse.MintId)) + if err != nil { + errMessage := fmt.Sprintf("error minting NFT: %v", err) + if errors.Is(err, context.DeadlineExceeded) { + errMessage = fmt.Sprintf("%v: %v", tvm.ErrAwaitTxConfirmationTimeout, h.tvmClient.Timeout) } + addErrorToResponseHeader(response, errMessage) + return true + } + if !success { + addErrorToResponseHeader(response, "minting NFT tx failed") + return true + } + h.logger.Infof("NFT minted with txID: %s\n", txID) + response.MintResponse.Header.Status = typesv1alpha.StatusType_STATUS_TYPE_SUCCESS + response.MintTransactionId = txID.String() + return false +} - h.logger.Infof("Bought NFT (txID=%s) with ID: %s\n", txID, mintID) - response.BuyTransactionId = txID.String() - case MintResponse: // provider will act upon receiving a mint response by minting an NFT - owner := h.tvmClient.Address() - if response.MintResponse.Header == nil { - response.MintResponse.Header = &typesv1alpha.ResponseHeader{} - } - buyer, err := codec.ParseAddressBech32(consts.HRP, request.MintRequest.BuyerAddress) - if err != nil { - addErrorToResponseHeader(response, fmt.Sprintf("error parsing buyer address: %v", err)) - return - } - price, err := strconv.Atoi(response.MintResponse.Price.Value) - if err != nil { - addErrorToResponseHeader(response, fmt.Sprintf("error parsing price value: %v", err)) - return - } - success, txID, err := h.tvmClient.SendTxAndWait(ctx, createNFTAction(owner, buyer, uint64(response.MintResponse.BuyableUntil.Seconds), uint64(price), response.MintResponse.MintId)) - if err != nil { - errMessage := fmt.Sprintf("error minting NFT: %v", err) - if errors.Is(err, context.DeadlineExceeded) { - errMessage = fmt.Sprintf("%v: %v", tvm.ErrAwaitTxConfirmationTimeout, h.tvmClient.Timeout) - } - addErrorToResponseHeader(response, errMessage) - return - } - if !success { - addErrorToResponseHeader(response, "minting NFT tx failed") - return +func (h *TvmResponseHandler) handleMintRequest(ctx context.Context, response *ResponseContent) bool { + if response.MintResponse.Header == nil { + response.MintResponse.Header = &typesv1alpha.ResponseHeader{} + } + if response.MintTransactionId == "" { + addErrorToResponseHeader(response, "missing mint transaction id") + return true + } + mintID, err := ids.FromString(response.MintTransactionId) + if err != nil { + addErrorToResponseHeader(response, fmt.Sprintf("error parsing mint transaction id: %v", err)) + return true + } + + success, txID, err := h.tvmClient.SendTxAndWait(ctx, transferNFTAction(h.tvmClient.Address(), mintID)) + if err != nil { + errMessage := fmt.Sprintf("error buying NFT: %v", err) + if errors.Is(err, context.DeadlineExceeded) { + errMessage = fmt.Sprintf("%v: %v", tvm.ErrAwaitTxConfirmationTimeout, h.tvmClient.Timeout) } - h.logger.Infof("NFT minted with txID: %s\n", txID) - response.MintResponse.Header.Status = typesv1alpha.StatusType_STATUS_TYPE_SUCCESS - response.MintTransactionId = txID.String() + addErrorToResponseHeader(response, errMessage) + return true } + if !success { + addErrorToResponseHeader(response, "buying NFT failed") + return true + } + + h.logger.Infof("Bought NFT (txID=%s) with ID: %s\n", txID, mintID) + response.BuyTransactionId = txID.String() + return false } func addErrorToResponseHeader(response *ResponseContent, errMessage string) { @@ -107,6 +123,7 @@ func addErrorToResponseHeader(response *ResponseContent, errMessage string) { func NewResponseHandler(tvmClient *tvm.Client, logger *zap.SugaredLogger) *TvmResponseHandler { return &TvmResponseHandler{tvmClient: tvmClient, logger: logger} } + func createNFTAction(owner, buyer codec.Address, purchaseExpiration, price uint64, metadata string) chain.Action { return &actions.CreateNFT{ Owner: owner, @@ -119,6 +136,7 @@ func createNFTAction(owner, buyer codec.Address, purchaseExpiration, price uint6 Metadata: []byte(metadata), } } + func transferNFTAction(newOwner codec.Address, nftID ids.ID) chain.Action { return &actions.TransferNFT{ To: newOwner, diff --git a/internal/messaging/service.go b/internal/messaging/service.go index 4da43583..04abaf8a 100644 --- a/internal/messaging/service.go +++ b/internal/messaging/service.go @@ -6,15 +6,16 @@ package messaging import ( + "context" + "fmt" + "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/activity/v1alpha/activityv1alphagrpc" "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/book/v1alpha/bookv1alphagrpc" "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/transport/v1alpha/transportv1alphagrpc" networkv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/network/v1alpha" partnerv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/partner/v1alpha" pingv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/ping/v1alpha" - "context" - "errors" - "fmt" + "github.com/chain4travel/camino-messenger-bot/internal/metadata" "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/accommodation/v1alpha/accommodationv1alphagrpc" @@ -22,181 +23,143 @@ import ( ) var ( - _ Service = (*activityProductListService)(nil) - _ Service = (*activityService)(nil) - _ Service = (*accommodationProductInfoService)(nil) - _ Service = (*accommodationProductListService)(nil) - _ Service = (*accommodationService)(nil) - _ Service = (*mintService)(nil) - _ Service = (*validationService)(nil) - _ Service = (*networkService)(nil) - _ Service = (*partnerService)(nil) - _ Service = (*pingService)(nil) - _ Service = (*transportService)(nil) - ErrUnknownMessageType = errors.New("unknown message type") + _ Service = (*activityProductListService)(nil) + _ Service = (*activityService)(nil) + _ Service = (*accommodationProductInfoService)(nil) + _ Service = (*accommodationProductListService)(nil) + _ Service = (*accommodationService)(nil) + _ Service = (*mintService)(nil) + _ Service = (*validationService)(nil) + _ Service = (*networkService)(nil) + _ Service = (*partnerService)(nil) + _ Service = (*pingService)(nil) + _ Service = (*transportService)(nil) ) type Service interface { - Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) + Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) } type activityProductListService struct { - client *activityv1alphagrpc.ActivityProductListServiceClient + client activityv1alphagrpc.ActivityProductListServiceClient } -func (a activityProductListService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.ActivityProductListRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*a.client).ActivityProductList(ctx, &request.ActivityProductListRequest, opts...) +func (a activityProductListService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := a.client.ActivityProductList(ctx, request.ActivityProductListRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.ActivityProductListResponse = *response // otherwise nil pointer dereference + responseContent.ActivityProductListResponse = response // otherwise nil pointer dereference } - return responseContent, ActivityProductListResponse, err + return &responseContent, ActivityProductListResponse, err } type activityService struct { client *activityv1alphagrpc.ActivitySearchServiceClient } -func (s activityService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.ActivitySearchRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*s.client).ActivitySearch(ctx, &request.ActivitySearchRequest, opts...) +func (s activityService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*s.client).ActivitySearch(ctx, request.ActivitySearchRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.ActivitySearchResponse = *response // otherwise nil pointer dereference + responseContent.ActivitySearchResponse = response // otherwise nil pointer dereference } - return responseContent, ActivitySearchResponse, err + return &responseContent, ActivitySearchResponse, err } type accommodationProductInfoService struct { client *accommodationv1alphagrpc.AccommodationProductInfoServiceClient } -func (a accommodationProductInfoService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.AccommodationProductInfoRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*a.client).AccommodationProductInfo(ctx, &request.AccommodationProductInfoRequest, opts...) +func (a accommodationProductInfoService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*a.client).AccommodationProductInfo(ctx, request.AccommodationProductInfoRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.AccommodationProductInfoResponse = *response // otherwise nil pointer dereference + responseContent.AccommodationProductInfoResponse = response // otherwise nil pointer dereference } - return responseContent, AccommodationProductInfoResponse, err + return &responseContent, AccommodationProductInfoResponse, err } type accommodationProductListService struct { client *accommodationv1alphagrpc.AccommodationProductListServiceClient } -func (a accommodationProductListService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.AccommodationProductListRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*a.client).AccommodationProductList(ctx, &request.AccommodationProductListRequest, opts...) +func (a accommodationProductListService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*a.client).AccommodationProductList(ctx, request.AccommodationProductListRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.AccommodationProductListResponse = *response // otherwise nil pointer dereference + responseContent.AccommodationProductListResponse = response // otherwise nil pointer dereference } - return responseContent, AccommodationProductListResponse, err + return &responseContent, AccommodationProductListResponse, err } type accommodationService struct { client *accommodationv1alphagrpc.AccommodationSearchServiceClient } -func (s accommodationService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.AccommodationSearchRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*s.client).AccommodationSearch(ctx, &request.AccommodationSearchRequest, opts...) +func (s accommodationService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*s.client).AccommodationSearch(ctx, request.AccommodationSearchRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.AccommodationSearchResponse = *response // otherwise nil pointer dereference + responseContent.AccommodationSearchResponse = response // otherwise nil pointer dereference } - return responseContent, AccommodationSearchResponse, err + return &responseContent, AccommodationSearchResponse, err } type mintService struct { client *bookv1alphagrpc.MintServiceClient } -func (m mintService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - - if &request.MintRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*m.client).Mint(ctx, &request.MintRequest, opts...) +func (m mintService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*m.client).Mint(ctx, request.MintRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.MintResponse = *response // otherwise nil pointer dereference + responseContent.MintResponse = response // otherwise nil pointer dereference } - return responseContent, MintResponse, err + return &responseContent, MintResponse, err } type validationService struct { client *bookv1alphagrpc.ValidationServiceClient } -func (v validationService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.ValidationRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*v.client).Validation(ctx, &request.ValidationRequest, opts...) +func (v validationService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*v.client).Validation(ctx, request.ValidationRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.ValidationResponse = *response // otherwise nil pointer dereference + responseContent.ValidationResponse = response // otherwise nil pointer dereference } - return responseContent, ValidationResponse, err -} - -type networkService struct { + return &responseContent, ValidationResponse, err } -func (s networkService) Call(_ context.Context, request *RequestContent, _ ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.GetNetworkFeeRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } +type networkService struct{} - return ResponseContent{ - GetNetworkFeeResponse: networkv1alpha.GetNetworkFeeResponse{ - NetworkFee: &networkv1alpha.NetworkFee{Amount: 100000}, //TODO implement +func (s networkService) Call(_ context.Context, _ *RequestContent, _ ...grpc.CallOption) (*ResponseContent, MessageType, error) { + return &ResponseContent{ + GetNetworkFeeResponse: &networkv1alpha.GetNetworkFeeResponse{ + NetworkFee: &networkv1alpha.NetworkFee{Amount: 100000}, // TODO implement }, }, GetNetworkFeeResponse, nil } -type partnerService struct { -} - -func (s partnerService) Call(_ context.Context, request *RequestContent, _ ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.GetPartnerConfigurationRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } +type partnerService struct{} - return ResponseContent{ - GetPartnerConfigurationResponse: partnerv1alpha.GetPartnerConfigurationResponse{ - PartnerConfiguration: nil, //TODO implement +func (s partnerService) Call(_ context.Context, _ *RequestContent, _ ...grpc.CallOption) (*ResponseContent, MessageType, error) { + return &ResponseContent{ + GetPartnerConfigurationResponse: &partnerv1alpha.GetPartnerConfigurationResponse{ + PartnerConfiguration: nil, // TODO implement CurrentBlockHeight: 0, }, }, GetPartnerConfigurationResponse, nil } -type pingService struct { -} - -func (s pingService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.PingRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } +type pingService struct{} +func (s pingService) Call(ctx context.Context, request *RequestContent, _ ...grpc.CallOption) (*ResponseContent, MessageType, error) { md := metadata.Metadata{} err := md.ExtractMetadata(ctx) if err != nil { - return ResponseContent{}, PingResponse, err + return nil, PingResponse, err } - return ResponseContent{PingResponse: pingv1alpha.PingResponse{ + return &ResponseContent{PingResponse: &pingv1alpha.PingResponse{ Header: nil, PingMessage: fmt.Sprintf("Ping response to [%s] with request ID: %s", request.PingMessage, md.RequestID), Timestamp: nil, @@ -207,14 +170,11 @@ type transportService struct { client *transportv1alphagrpc.TransportSearchServiceClient } -func (s transportService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (ResponseContent, MessageType, error) { - if &request.TransportSearchRequest == nil { - return ResponseContent{}, "", ErrUnknownMessageType - } - response, err := (*s.client).TransportSearch(ctx, &request.TransportSearchRequest, opts...) +func (s transportService) Call(ctx context.Context, request *RequestContent, opts ...grpc.CallOption) (*ResponseContent, MessageType, error) { + response, err := (*s.client).TransportSearch(ctx, request.TransportSearchRequest, opts...) responseContent := ResponseContent{} if err == nil { - responseContent.TransportSearchResponse = *response // otherwise nil pointer dereference + responseContent.TransportSearchResponse = response // otherwise nil pointer dereference } - return responseContent, TransportSearchResponse, err + return &responseContent, TransportSearchResponse, err } diff --git a/internal/messaging/service_registry.go b/internal/messaging/service_registry.go index 05fe0a0c..bd0b7c79 100644 --- a/internal/messaging/service_registry.go +++ b/internal/messaging/service_registry.go @@ -6,6 +6,8 @@ package messaging import ( + "sync" + "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/accommodation/v1alpha/accommodationv1alphagrpc" "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/activity/v1alpha/activityv1alphagrpc" "buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/book/v1alpha/bookv1alphagrpc" @@ -13,24 +15,27 @@ import ( "github.com/chain4travel/camino-messenger-bot/config" "github.com/chain4travel/camino-messenger-bot/internal/rpc/client" "go.uber.org/zap" - "sync" ) -type ServiceRegistry struct { +type ServiceRegistry interface { + RegisterServices(requestTypes config.SupportedRequestTypesFlag, rpcClient *client.RPCClient) + GetService(messageType MessageType) (Service, bool) +} +type serviceRegistry struct { logger *zap.SugaredLogger services map[MessageType]Service - lock sync.RWMutex + lock *sync.RWMutex } -func NewServiceRegistry(logger *zap.SugaredLogger) *ServiceRegistry { - return &ServiceRegistry{ +func NewServiceRegistry(logger *zap.SugaredLogger) ServiceRegistry { + return &serviceRegistry{ logger: logger, services: make(map[MessageType]Service), - lock: sync.RWMutex{}, + lock: &sync.RWMutex{}, } } -func (s *ServiceRegistry) RegisterServices(requestTypes config.SupportedRequestTypesFlag, rpcClient *client.RPCClient) { +func (s *serviceRegistry) RegisterServices(requestTypes config.SupportedRequestTypesFlag, rpcClient *client.RPCClient) { s.lock.Lock() defer s.lock.Unlock() @@ -39,7 +44,7 @@ func (s *ServiceRegistry) RegisterServices(requestTypes config.SupportedRequestT switch MessageType(requestType) { case ActivityProductListRequest: c := activityv1alphagrpc.NewActivityProductListServiceClient(rpcClient.ClientConn) - service = activityProductListService{client: &c} + service = activityProductListService{client: c} case ActivitySearchRequest: c := activityv1alphagrpc.NewActivitySearchServiceClient(rpcClient.ClientConn) service = activityService{client: &c} @@ -75,7 +80,7 @@ func (s *ServiceRegistry) RegisterServices(requestTypes config.SupportedRequestT } } -func (s *ServiceRegistry) GetService(messageType MessageType) (Service, bool) { +func (s *serviceRegistry) GetService(messageType MessageType) (Service, bool) { service, ok := s.services[messageType] return service, ok } diff --git a/internal/messaging/types.go b/internal/messaging/types.go index 7fb2a329..28d02db0 100644 --- a/internal/messaging/types.go +++ b/internal/messaging/types.go @@ -1,6 +1,8 @@ package messaging import ( + "errors" + accommodationv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/accommodation/v1alpha" activityv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/activity/v1alpha" bookv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/book/v1alpha" @@ -8,36 +10,42 @@ import ( partnerv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/partner/v1alpha" pingv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/ping/v1alpha" transportv1alpha "buf.build/gen/go/chain4travel/camino-messenger-protocol/protocolbuffers/go/cmp/services/transport/v1alpha" + "github.com/chain4travel/camino-messenger-bot/internal/metadata" - "github.com/golang/protobuf/proto" + + "google.golang.org/protobuf/proto" ) +var ErrUnknownMessageType = errors.New("unknown message type") + type RequestContent struct { - activityv1alpha.ActivityProductListRequest - activityv1alpha.ActivitySearchRequest - accommodationv1alpha.AccommodationProductInfoRequest - accommodationv1alpha.AccommodationProductListRequest - accommodationv1alpha.AccommodationSearchRequest - networkv1alpha.GetNetworkFeeRequest - partnerv1alpha.GetPartnerConfigurationRequest - bookv1alpha.MintRequest - bookv1alpha.ValidationRequest - pingv1alpha.PingRequest - transportv1alpha.TransportSearchRequest + *activityv1alpha.ActivityProductListRequest + *activityv1alpha.ActivitySearchRequest + *accommodationv1alpha.AccommodationProductInfoRequest + *accommodationv1alpha.AccommodationProductListRequest + *accommodationv1alpha.AccommodationSearchRequest + *networkv1alpha.GetNetworkFeeRequest + *partnerv1alpha.GetPartnerConfigurationRequest + *bookv1alpha.MintRequest + *bookv1alpha.ValidationRequest + *pingv1alpha.PingRequest + *transportv1alpha.TransportSearchRequest } + type ResponseContent struct { - activityv1alpha.ActivityProductListResponse - activityv1alpha.ActivitySearchResponse - accommodationv1alpha.AccommodationProductInfoResponse - accommodationv1alpha.AccommodationProductListResponse - accommodationv1alpha.AccommodationSearchResponse - networkv1alpha.GetNetworkFeeResponse - partnerv1alpha.GetPartnerConfigurationResponse - bookv1alpha.MintResponse - bookv1alpha.ValidationResponse - pingv1alpha.PingResponse - transportv1alpha.TransportSearchResponse + *activityv1alpha.ActivityProductListResponse + *activityv1alpha.ActivitySearchResponse + *accommodationv1alpha.AccommodationProductInfoResponse + *accommodationv1alpha.AccommodationProductListResponse + *accommodationv1alpha.AccommodationSearchResponse + *networkv1alpha.GetNetworkFeeResponse + *partnerv1alpha.GetPartnerConfigurationResponse + *bookv1alpha.MintResponse + *bookv1alpha.ValidationResponse + *pingv1alpha.PingResponse + *transportv1alpha.TransportSearchResponse } + type MessageContent struct { RequestContent ResponseContent @@ -50,8 +58,10 @@ type Message struct { Metadata metadata.Metadata `json:"metadata"` } -type MessageCategory byte -type MessageType string +type ( + MessageCategory byte + MessageType string +) const ( // message categories @@ -114,52 +124,51 @@ func (mt MessageType) Category() MessageCategory { } func (m *Message) MarshalContent() ([]byte, error) { - switch m.Type { case ActivityProductListRequest: - return proto.Marshal(&m.Content.ActivityProductListRequest) + return proto.Marshal(m.Content.ActivityProductListRequest) case ActivityProductListResponse: - return proto.Marshal(&m.Content.ActivityProductListResponse) + return proto.Marshal(m.Content.ActivityProductListResponse) case ActivitySearchRequest: - return proto.Marshal(&m.Content.ActivitySearchRequest) + return proto.Marshal(m.Content.ActivitySearchRequest) case ActivitySearchResponse: - return proto.Marshal(&m.Content.ActivitySearchResponse) + return proto.Marshal(m.Content.ActivitySearchResponse) case AccommodationProductInfoRequest: - return proto.Marshal(&m.Content.AccommodationProductInfoRequest) + return proto.Marshal(m.Content.AccommodationProductInfoRequest) case AccommodationProductInfoResponse: - return proto.Marshal(&m.Content.AccommodationProductInfoResponse) + return proto.Marshal(m.Content.AccommodationProductInfoResponse) case AccommodationProductListRequest: - return proto.Marshal(&m.Content.AccommodationProductListRequest) + return proto.Marshal(m.Content.AccommodationProductListRequest) case AccommodationProductListResponse: - return proto.Marshal(&m.Content.AccommodationProductListResponse) + return proto.Marshal(m.Content.AccommodationProductListResponse) case AccommodationSearchRequest: - return proto.Marshal(&m.Content.AccommodationSearchRequest) + return proto.Marshal(m.Content.AccommodationSearchRequest) case AccommodationSearchResponse: - return proto.Marshal(&m.Content.AccommodationSearchResponse) + return proto.Marshal(m.Content.AccommodationSearchResponse) case GetNetworkFeeRequest: - return proto.Marshal(&m.Content.GetNetworkFeeRequest) + return proto.Marshal(m.Content.GetNetworkFeeRequest) case GetNetworkFeeResponse: - return proto.Marshal(&m.Content.GetNetworkFeeResponse) + return proto.Marshal(m.Content.GetNetworkFeeResponse) case GetPartnerConfigurationRequest: - return proto.Marshal(&m.Content.GetPartnerConfigurationRequest) + return proto.Marshal(m.Content.GetPartnerConfigurationRequest) case GetPartnerConfigurationResponse: - return proto.Marshal(&m.Content.GetPartnerConfigurationResponse) + return proto.Marshal(m.Content.GetPartnerConfigurationResponse) case MintRequest: - return proto.Marshal(&m.Content.MintRequest) + return proto.Marshal(m.Content.MintRequest) case MintResponse: - return proto.Marshal(&m.Content.MintResponse) + return proto.Marshal(m.Content.MintResponse) case ValidationRequest: - return proto.Marshal(&m.Content.ValidationRequest) + return proto.Marshal(m.Content.ValidationRequest) case ValidationResponse: - return proto.Marshal(&m.Content.ValidationResponse) + return proto.Marshal(m.Content.ValidationResponse) case PingRequest: - return proto.Marshal(&m.Content.PingRequest) + return proto.Marshal(m.Content.PingRequest) case PingResponse: - return proto.Marshal(&m.Content.PingResponse) + return proto.Marshal(m.Content.PingResponse) case TransportSearchRequest: - return proto.Marshal(&m.Content.TransportSearchRequest) + return proto.Marshal(m.Content.TransportSearchRequest) case TransportSearchResponse: - return proto.Marshal(&m.Content.TransportSearchResponse) + return proto.Marshal(m.Content.TransportSearchResponse) default: return nil, ErrUnknownMessageType } diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go index 3312b93e..e25ceb56 100644 --- a/internal/metadata/metadata.go +++ b/internal/metadata/metadata.go @@ -51,14 +51,14 @@ func (m *Metadata) FromGrpcMD(mdPairs metadata.MD) error { if cheques, found := mdPairs["cheques"]; found { chequesJSON := strings.Join(cheques, "") if err := json.Unmarshal([]byte(chequesJSON), &m.Cheques); err != nil { - return fmt.Errorf("error unmarshalling cheques: %v", err) + return fmt.Errorf("error unmarshalling cheques: %w", err) } } if timestamps, found := mdPairs["timestamps"]; found { timestampsJSON := strings.Join(timestamps, "") if err := json.Unmarshal([]byte(timestampsJSON), &m.Timestamps); err != nil { - return fmt.Errorf("error unmarshalling timestamps: %v", err) + return fmt.Errorf("error unmarshalling timestamps: %w", err) } } if providerOperator, found := mdPairs["provider_operator"]; found { @@ -66,6 +66,7 @@ func (m *Metadata) FromGrpcMD(mdPairs metadata.MD) error { } return nil } + func (m *Metadata) ToGrpcMD() metadata.MD { md := metadata.New(map[string]string{ "request_id": m.RequestID, @@ -83,6 +84,7 @@ func (m *Metadata) ToGrpcMD() metadata.MD { }) return md } + func (m *Metadata) Stamp(checkpoint string) { if m.Timestamps == nil { m.Timestamps = make(map[string]int64) @@ -90,6 +92,7 @@ func (m *Metadata) Stamp(checkpoint string) { idx := len(m.Timestamps) // for analysis' sake, we want to know the order of the checkpoints m.Timestamps[fmt.Sprintf("%d-%s", idx, checkpoint)] = time.Now().UnixMilli() } + func (m *Metadata) StampOn(checkpoint string, t int64) { if m.Timestamps == nil { m.Timestamps = make(map[string]int64) diff --git a/internal/rpc/client/client.go b/internal/rpc/client/client.go index c64cd794..f879ee1b 100644 --- a/internal/rpc/client/client.go +++ b/internal/rpc/client/client.go @@ -2,9 +2,10 @@ package client import ( "fmt" - "log" "sync" + "google.golang.org/grpc/credentials/insecure" + "github.com/chain4travel/camino-messenger-bot/config" "github.com/chain4travel/camino-messenger-bot/internal/metadata" utils "github.com/chain4travel/camino-messenger-bot/utils/tls" @@ -38,11 +39,11 @@ func (rc *RPCClient) Start() error { var opts []grpc.DialOption if rc.cfg.Unencrypted { - opts = append(opts, grpc.WithInsecure()) + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { tlsCreds, err := utils.LoadCATLSCredentials(rc.cfg.CACertFile) if err != nil { - log.Fatalf("could not load TLS keys: %s", err) + return fmt.Errorf("could not load TLS keys: %w", err) } opts = append(opts, grpc.WithTransportCredentials(tlsCreds)) } diff --git a/internal/rpc/server/server.go b/internal/rpc/server/server.go index b4439830..ba8b0d14 100644 --- a/internal/rpc/server/server.go +++ b/internal/rpc/server/server.go @@ -2,9 +2,7 @@ package server import ( "context" - "errors" "fmt" - "log" "net" "github.com/chain4travel/camino-messenger-bot/internal/tracing" @@ -48,13 +46,11 @@ var ( _ bookv1alphagrpc.ValidationServiceServer = (*server)(nil) _ pingv1alphagrpc.PingServiceServer = (*server)(nil) _ transportv1alphagrpc.TransportSearchServiceServer = (*server)(nil) - - errMissingRecipient = errors.New("missing recipient") ) type Server interface { metadata.Checkpoint - Start() + Start() error Stop() } type server struct { @@ -63,14 +59,14 @@ type server struct { logger *zap.SugaredLogger tracer tracing.Tracer processor messaging.Processor - serviceRegistry *messaging.ServiceRegistry + serviceRegistry messaging.ServiceRegistry } -func (s *server) Checkpoint() string { +func (*server) Checkpoint() string { return "request-gateway" } -func NewServer(cfg *config.RPCServerConfig, logger *zap.SugaredLogger, tracer tracing.Tracer, processor messaging.Processor, serviceRegistry *messaging.ServiceRegistry) *server { +func NewServer(cfg *config.RPCServerConfig, logger *zap.SugaredLogger, tracer tracing.Tracer, processor messaging.Processor, serviceRegistry messaging.ServiceRegistry) Server { var opts []grpc.ServerOption if cfg.Unencrypted { logger.Warn("Running gRPC server without TLS!") @@ -102,12 +98,12 @@ func createGrpcServerAndRegisterServices(server *server, opts ...grpc.ServerOpti return grpcServer } -func (s *server) Start() { +func (s *server) Start() error { lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.cfg.Port)) if err != nil { - log.Fatalf("failed to listen: %v", err) + return fmt.Errorf("failed to listen: %w", err) } - s.grpcServer.Serve(lis) + return s.grpcServer.Serve(lis) } func (s *server) Stop() { @@ -116,77 +112,77 @@ func (s *server) Stop() { } func (s *server) AccommodationProductInfo(ctx context.Context, request *accommodationv1alpha.AccommodationProductInfoRequest) (*accommodationv1alpha.AccommodationProductInfoResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.AccommodationProductInfoRequest, &messaging.RequestContent{AccommodationProductInfoRequest: *request}) - return &response.AccommodationProductInfoResponse, err + response, err := s.processExternalRequest(ctx, messaging.AccommodationProductInfoRequest, &messaging.RequestContent{AccommodationProductInfoRequest: request}) + return response.AccommodationProductInfoResponse, err } func (s *server) AccommodationProductList(ctx context.Context, request *accommodationv1alpha.AccommodationProductListRequest) (*accommodationv1alpha.AccommodationProductListResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.AccommodationProductListRequest, &messaging.RequestContent{AccommodationProductListRequest: *request}) - return &response.AccommodationProductListResponse, err + response, err := s.processExternalRequest(ctx, messaging.AccommodationProductListRequest, &messaging.RequestContent{AccommodationProductListRequest: request}) + return response.AccommodationProductListResponse, err } func (s *server) AccommodationSearch(ctx context.Context, request *accommodationv1alpha.AccommodationSearchRequest) (*accommodationv1alpha.AccommodationSearchResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.AccommodationSearchRequest, &messaging.RequestContent{AccommodationSearchRequest: *request}) - return &response.AccommodationSearchResponse, err //TODO set specific errors according to https://grpc.github.io/grpc/core/md_doc_statuscodes.html ? + response, err := s.processExternalRequest(ctx, messaging.AccommodationSearchRequest, &messaging.RequestContent{AccommodationSearchRequest: request}) + return response.AccommodationSearchResponse, err // TODO set specific errors according to https://grpc.github.io/grpc/core/md_doc_statuscodes.html ? } func (s *server) Ping(ctx context.Context, request *pingv1alpha.PingRequest) (*pingv1alpha.PingResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.PingRequest, &messaging.RequestContent{PingRequest: *request}) - return &response.PingResponse, err + response, err := s.processExternalRequest(ctx, messaging.PingRequest, &messaging.RequestContent{PingRequest: request}) + return response.PingResponse, err } func (s *server) GetNetworkFee(ctx context.Context, request *networkv1alpha.GetNetworkFeeRequest) (*networkv1alpha.GetNetworkFeeResponse, error) { - response, err := s.processInternalRequest(ctx, messaging.GetNetworkFeeRequest, &messaging.RequestContent{GetNetworkFeeRequest: *request}) - return &response.GetNetworkFeeResponse, err + response, err := s.processInternalRequest(ctx, messaging.GetNetworkFeeRequest, &messaging.RequestContent{GetNetworkFeeRequest: request}) + return response.GetNetworkFeeResponse, err } func (s *server) GetPartnerConfiguration(ctx context.Context, request *partnerv1alpha.GetPartnerConfigurationRequest) (*partnerv1alpha.GetPartnerConfigurationResponse, error) { - response, err := s.processInternalRequest(ctx, messaging.GetPartnerConfigurationRequest, &messaging.RequestContent{GetPartnerConfigurationRequest: *request}) - return &response.GetPartnerConfigurationResponse, err + response, err := s.processInternalRequest(ctx, messaging.GetPartnerConfigurationRequest, &messaging.RequestContent{GetPartnerConfigurationRequest: request}) + return response.GetPartnerConfigurationResponse, err } func (s *server) ActivityProductList(ctx context.Context, request *activityv1alpha.ActivityProductListRequest) (*activityv1alpha.ActivityProductListResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.ActivityProductListRequest, &messaging.RequestContent{ActivityProductListRequest: *request}) - return &response.ActivityProductListResponse, err + response, err := s.processExternalRequest(ctx, messaging.ActivityProductListRequest, &messaging.RequestContent{ActivityProductListRequest: request}) + return response.ActivityProductListResponse, err } func (s *server) ActivitySearch(ctx context.Context, request *activityv1alpha.ActivitySearchRequest) (*activityv1alpha.ActivitySearchResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.ActivitySearchRequest, &messaging.RequestContent{ActivitySearchRequest: *request}) - return &response.ActivitySearchResponse, err + response, err := s.processExternalRequest(ctx, messaging.ActivitySearchRequest, &messaging.RequestContent{ActivitySearchRequest: request}) + return response.ActivitySearchResponse, err } func (s *server) Mint(ctx context.Context, request *bookv1alpha.MintRequest) (*bookv1alpha.MintResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.MintRequest, &messaging.RequestContent{MintRequest: *request}) - return &response.MintResponse, err + response, err := s.processExternalRequest(ctx, messaging.MintRequest, &messaging.RequestContent{MintRequest: request}) + return response.MintResponse, err } func (s *server) Validation(ctx context.Context, request *bookv1alpha.ValidationRequest) (*bookv1alpha.ValidationResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.ValidationRequest, &messaging.RequestContent{ValidationRequest: *request}) - return &response.ValidationResponse, err + response, err := s.processExternalRequest(ctx, messaging.ValidationRequest, &messaging.RequestContent{ValidationRequest: request}) + return response.ValidationResponse, err } func (s *server) TransportSearch(ctx context.Context, request *transportv1alpha.TransportSearchRequest) (*transportv1alpha.TransportSearchResponse, error) { - response, err := s.processExternalRequest(ctx, messaging.TransportSearchRequest, &messaging.RequestContent{TransportSearchRequest: *request}) - return &response.TransportSearchResponse, err + response, err := s.processExternalRequest(ctx, messaging.TransportSearchRequest, &messaging.RequestContent{TransportSearchRequest: request}) + return response.TransportSearchResponse, err } -func (s *server) processInternalRequest(ctx context.Context, requestType messaging.MessageType, request *messaging.RequestContent) (messaging.ResponseContent, error) { +func (s *server) processInternalRequest(ctx context.Context, requestType messaging.MessageType, request *messaging.RequestContent) (*messaging.ResponseContent, error) { ctx, span := s.tracer.Start(ctx, "server.processInternalRequest", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() service, registered := s.serviceRegistry.GetService(requestType) if !registered { - return messaging.ResponseContent{}, fmt.Errorf("%v: %s", messaging.ErrUnsupportedRequestType, requestType) + return nil, fmt.Errorf("%w: %s", messaging.ErrUnsupportedRequestType, requestType) } response, _, err := service.Call(ctx, request) return response, err } -func (s *server) processExternalRequest(ctx context.Context, requestType messaging.MessageType, request *messaging.RequestContent) (messaging.ResponseContent, error) { +func (s *server) processExternalRequest(ctx context.Context, requestType messaging.MessageType, request *messaging.RequestContent) (*messaging.ResponseContent, error) { ctx, span := s.tracer.Start(ctx, "server.processExternalRequest", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - err, md := s.processMetadata(ctx, s.tracer.TraceIDForSpan(span)) + md, err := s.processMetadata(ctx, s.tracer.TraceIDForSpan(span)) if err != nil { - return messaging.ResponseContent{}, fmt.Errorf("error processing metadata: %v", err) + return nil, fmt.Errorf("error processing metadata: %w", err) } m := &messaging.Message{ @@ -196,17 +192,20 @@ func (s *server) processExternalRequest(ctx context.Context, requestType messagi }, Metadata: md, } - response, err := s.processor.ProcessOutbound(ctx, *m) + response, err := s.processor.ProcessOutbound(ctx, m) + if err != nil { + return &messaging.ResponseContent{}, fmt.Errorf("error processing outbound request: %w", err) + } response.Metadata.Stamp(fmt.Sprintf("%s-%s", s.Checkpoint(), "processed")) - grpc.SendHeader(ctx, response.Metadata.ToGrpcMD()) - return response.Content.ResponseContent, err //TODO set specific errors according to https://grpc.github.io/grpc/core/md_doc_statuscodes.html ? + err = grpc.SendHeader(ctx, response.Metadata.ToGrpcMD()) + return &response.Content.ResponseContent, err // TODO set specific errors according to https://grpc.github.io/grpc/core/md_doc_statuscodes.html ? } -func (s *server) processMetadata(ctx context.Context, id trace.TraceID) (error, metadata.Metadata) { +func (s *server) processMetadata(ctx context.Context, id trace.TraceID) (metadata.Metadata, error) { md := metadata.Metadata{ RequestID: id.String(), } md.Stamp(fmt.Sprintf("%s-%s", s.Checkpoint(), "received")) err := md.ExtractMetadata(ctx) - return err, md + return md, err } diff --git a/internal/tracing/exporter.go b/internal/tracing/exporter.go index 5d22a06c..5dd180ef 100644 --- a/internal/tracing/exporter.go +++ b/internal/tracing/exporter.go @@ -8,20 +8,22 @@ package tracing import ( "context" "fmt" + "time" + "github.com/chain4travel/camino-messenger-bot/config" utils "github.com/chain4travel/camino-messenger-bot/utils/tls" - "time" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/sdk/trace" ) -const exportTimeout = 10 * time.Second -const exporterInstantiationTimeout = 5 * time.Second +const ( + exportTimeout = 10 * time.Second + exporterInstantiationTimeout = 5 * time.Second +) func newExporter(cfg *config.TracingConfig) (trace.SpanExporter, error) { - var client otlptrace.Client opts := []otlptracegrpc.Option{ otlptracegrpc.WithEndpoint(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)), @@ -32,7 +34,7 @@ func newExporter(cfg *config.TracingConfig) (trace.SpanExporter, error) { } else { creds, err := utils.LoadTLSCredentials(cfg.CertFile, cfg.KeyFile) if err != nil { - return nil, fmt.Errorf("could not load TLS keys: %s", err) + return nil, fmt.Errorf("could not load TLS keys: %w", err) } opts = append(opts, otlptracegrpc.WithTLSCredentials(creds)) } diff --git a/internal/tracing/nooptracer.go b/internal/tracing/nooptracer.go index 7eec09d2..4a32fa4c 100644 --- a/internal/tracing/nooptracer.go +++ b/internal/tracing/nooptracer.go @@ -31,8 +31,8 @@ func (n *noopTracer) Shutdown() error { } // TraceIDForSpan returns a random trace ID in tha case of noopTracer. A non-empty trace ID is required for the span to be exported. -func (n *noopTracer) TraceIDForSpan(trace.Span) trace.TraceID { +func (n *noopTracer) TraceIDForSpan(_ trace.Span) trace.TraceID { traceID := trace.TraceID{} - rand.Read(traceID[:]) + _, _ = rand.Read(traceID[:]) return traceID } diff --git a/internal/tracing/tracer.go b/internal/tracing/tracer.go index a97436a4..910b08fe 100644 --- a/internal/tracing/tracer.go +++ b/internal/tracing/tracer.go @@ -34,6 +34,7 @@ type tracer struct { func (t *tracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return t.tp.Tracer("").Start(ctx, spanName, opts...) } + func (t *tracer) Shutdown() error { ctx, cancel := context.WithTimeout(context.Background(), tracerProviderShutdownTimeout) defer cancel() diff --git a/internal/tvm/client.go b/internal/tvm/client.go index 5af69863..bd930ff4 100644 --- a/internal/tvm/client.go +++ b/internal/tvm/client.go @@ -9,12 +9,12 @@ import ( "context" "errors" "fmt" - "github.com/ava-labs/hypersdk/codec" "time" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/hypersdk/chain" + "github.com/ava-labs/hypersdk/codec" "github.com/ava-labs/hypersdk/pubsub" "github.com/ava-labs/hypersdk/rpc" "github.com/chain4travel/camino-messenger-bot/config" diff --git a/scripts/build_test.sh b/scripts/build_test.sh new file mode 100755 index 00000000..cb58bad3 --- /dev/null +++ b/scripts/build_test.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# Directory above this script +cd "$( dirname "${BASH_SOURCE[0]}" )"; cd ..; pwd + +# build the Go application by calling build.sh and check if failed +if ! ./scripts/build.sh; then + exit 1 +fi + +go test -shuffle=on -race -timeout="${TIMEOUT:-120s}" -coverprofile="coverage.out" -covermode="atomic" $(go list ./...) diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100755 index 00000000..6f827f2b --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -euo pipefail + +# Function to check if a command exists +golangci_lint_installed() { + golangci-lint --version >/dev/null 2>&1 +} + +# Function to install golangci-lint on Ubuntu +install_golangci_lint() { + echo "Installing golangci-lint..." + go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.1 +} + +# Check if golangci-lint is installed +if golangci_lint_installed ; then + echo "golangci-lint is already installed." +else + echo "golangci-lint is not installed." + install_golangci_lint +fi + +# Run golangci-lint +echo "Running golangci-lint..." +golangci-lint run --config .golangci.yml diff --git a/scripts/mock.gen.sh b/scripts/mock.gen.sh new file mode 100755 index 00000000..b9fc4015 --- /dev/null +++ b/scripts/mock.gen.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -e + +if ! [[ "$0" =~ scripts/mock.gen.sh ]]; then + echo "must be run from repository root" + exit 255 +fi + +if ! command -v mockgen &> /dev/null +then + echo "mockgen not found, installing..." + go install -v go.uber.org/mock/mockgen@v0.4.0 +fi + +# tuples of (source interface import path, comma-separated interface names, output file path) +input="scripts/mocks.mockgen.txt" +while IFS= read -r line +do + IFS='=' read src_import_path interface_name output_path <<< "${line}" + package_name=$(basename $(dirname $output_path)) + [[ $src_import_path == \#* ]] && continue + echo "Generating ${output_path}..." + mockgen -package=${package_name} -destination=${output_path} ${src_import_path} ${interface_name} +done < "$input" + +echo "SUCCESS" diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt new file mode 100644 index 00000000..4f607a7a --- /dev/null +++ b/scripts/mocks.mockgen.txt @@ -0,0 +1,5 @@ +github.com/chain4travel/camino-messenger-bot/internal/compression=Decompressor=internal/compression/mock_decompress.go +github.com/chain4travel/camino-messenger-bot/internal/matrix=Client=internal/matrix/mock_room_handler.go +buf.build/gen/go/chain4travel/camino-messenger-protocol/grpc/go/cmp/services/activity/v1alpha/activityv1alphagrpc=ActivityProductListServiceClient=internal/messaging/mock_list_grpc.pb.go +github.com/chain4travel/camino-messenger-bot/internal/messaging=ServiceRegistry=internal/messaging/mock_service_registry.go +github.com/chain4travel/camino-messenger-bot/internal/messaging=Messenger=internal/messaging/mock_messenger.go diff --git a/utils/tls/tls.go b/utils/tls/tls.go index abbaf14d..5eed4d6b 100644 --- a/utils/tls/tls.go +++ b/utils/tls/tls.go @@ -15,6 +15,7 @@ func LoadTLSCredentials(serverCertFile, serverKeyFile string) (credentials.Trans return nil, err } config := &tls.Config{ + MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{serverCert}, ClientAuth: tls.NoClientCert, } @@ -31,7 +32,8 @@ func LoadCATLSCredentials(caCertFile string) (credentials.TransportCredentials, return nil, fmt.Errorf("failed to add CA's certificate") } config := &tls.Config{ - RootCAs: certPool, + MinVersion: tls.VersionTLS12, + RootCAs: certPool, } return credentials.NewTLS(config), nil }