Skip to content

Commit

Permalink
Refactor GoFlightClient to simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
valterfrancisco-dremio committed Dec 26, 2024
1 parent e4e6006 commit 725b002
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 130 deletions.
28 changes: 12 additions & 16 deletions go/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"log"
"net"

"arrow-flight-client-example/implementations"
"arrow-flight-client-example/interfaces"
"github.com/apache/arrow-go/v18/arrow/flight"
flightgen "github.com/apache/arrow-go/v18/arrow/flight/gen/flight"
Expand Down Expand Up @@ -87,7 +86,7 @@ func main() {
creds = insecure.NewCredentials()
}

rawClient, err := flight.NewClientWithMiddleware(
client, err := flight.NewClientWithMiddleware(
net.JoinHostPort(config.Host, config.Port),
nil,
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
Expand All @@ -96,17 +95,14 @@ func main() {
if err != nil {
log.Fatal(err)
}
defer rawClient.Close()
defer client.Close()

// GoFlightClient and the TestableClient interface are simple wrappers for flight.Client to provide testing and modularity.
abstractClient := &implementations.GoFlightClient{Client: rawClient}
if err := run(config, abstractClient, interfaces.WrapRecordReader); err != nil {
if err := run(config, client, interfaces.WrapRecordReader); err != nil {
log.Fatal(err)
}

}

func run(config interfaces.FlightConfig, abstractClient interfaces.TestableClient,
func run(config interfaces.FlightConfig, client flight.Client,
readerCreator func(flight.FlightService_DoGetClient) (interfaces.RecordReader, error),
) error {

Expand All @@ -124,15 +120,15 @@ func run(config interfaces.FlightConfig, abstractClient interfaces.TestableClien

if config.ProjectID != "" {
log.Println("[INFO] Project ID added to sessions options.")
err = setSessionOptions(ctx, abstractClient, config.ProjectID)
err = setSessionOptions(ctx, client, config.ProjectID)
if err != nil {
return fmt.Errorf("failed to set session options: %v", err)
}
// Close the session once the query is done
defer abstractClient.CloseSession(ctx, &flight.CloseSessionRequest{})
defer client.CloseSession(ctx, &flight.CloseSessionRequest{})
}
} else {
if ctx, err = abstractClient.AuthenticateBasicToken(ctx, config.User, config.Pass); err != nil {
if ctx, err = client.AuthenticateBasicToken(ctx, config.User, config.Pass); err != nil {
return fmt.Errorf("failed to authenticate user: %v", err)
}
log.Println("[INFO] Authentication was successful.")
Expand All @@ -153,7 +149,7 @@ func run(config interfaces.FlightConfig, abstractClient interfaces.TestableClien
// ctx = metadata.AppendToOutgoingContext(ctx, "schema", "test.schema")

// Retrieve the schema of the result set
sc, err := abstractClient.GetSchema(ctx, desc)
sc, err := client.GetSchema(ctx, desc)
if err != nil {
return fmt.Errorf("failed to get schema: %v", err)
}
Expand All @@ -166,14 +162,14 @@ func run(config interfaces.FlightConfig, abstractClient interfaces.TestableClien
log.Println("[INFO] Schema:", schema)

// Get the FlightInfo message to retrieve the ticket corresponding to the query result set
info, err := abstractClient.GetFlightInfo(ctx, desc)
info, err := client.GetFlightInfo(ctx, desc)
if err != nil {
return fmt.Errorf("failed to get flight info: %v", err)
}
log.Println("[INFO] GetFlightInfo was successful.")

// retrieve the result set as a stream of Arrow record batches.
stream, err := abstractClient.DoGet(ctx, info.Endpoint[0].Ticket)
stream, err := client.DoGet(ctx, info.Endpoint[0].Ticket)
if err != nil {
return fmt.Errorf("failed to get flight stream: %v", err)
}
Expand All @@ -193,7 +189,7 @@ func run(config interfaces.FlightConfig, abstractClient interfaces.TestableClien
return nil
}

func setSessionOptions(ctx context.Context, client interfaces.TestableClient, projectID string) error {
func setSessionOptions(ctx context.Context, client flight.Client, projectID string) error {
projectIdSessionOption, err := flight.NewSessionOptionValue(projectID)
if err != nil {
return fmt.Errorf("failed to create session option: %v", err)
Expand All @@ -205,7 +201,7 @@ func setSessionOptions(ctx context.Context, client interfaces.TestableClient, pr
}
_, err = client.SetSessionOptions(ctx, &sessionOptionsRequest)
if err != nil {
return fmt.Errorf("set session options: %v", err)
return fmt.Errorf("failed to set session options: %v", err)
}
log.Printf("[INFO] Session options set with project_id: %s", projectID)
return nil
Expand Down
16 changes: 8 additions & 8 deletions go/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestUsernamePassAuth(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

mockClient.EXPECT().
AuthenticateBasicToken(gomock.Any(), "testuser", "testpass").
Expand All @@ -46,7 +46,7 @@ func TestPATAuth(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

mockClient.EXPECT().
SetSessionOptions(gomock.Any(), gomock.Any()).
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestRun(t *testing.T) {

mockStream := implementations.NewMockFlightService_DoGetClient(ctrl)

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

mockClient.EXPECT().
AuthenticateBasicToken(gomock.Any(), "testuser", "testpass").
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestRunWithPAT(t *testing.T) {

mockStream := implementations.NewMockFlightService_DoGetClient(ctrl)

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

mockClient.EXPECT().
GetSchema(gomock.Any(), gomock.Any()).
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestRunWithPATNoProjectID(t *testing.T) {

mockStream := implementations.NewMockFlightService_DoGetClient(ctrl)

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

mockClient.EXPECT().
GetSchema(gomock.Any(), gomock.Any()).
Expand Down Expand Up @@ -318,7 +318,7 @@ func TestInvalidCredentials(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

expectedErr := status.Error(codes.Unauthenticated, "failed to authenticate user: rpc error: code = "+
"Unauthenticated desc = Unable to authenticate user dremio, exception: Login failed: Invalid username or "+
Expand Down Expand Up @@ -357,7 +357,7 @@ func TestInvalidHost(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

expectedErr := status.Error(codes.Unauthenticated, "failed to authenticate user: rpc error: code = "+
"Unavailable desc = name resolver error: produced zero addresses")
Expand Down Expand Up @@ -395,7 +395,7 @@ func TestInvalidPort(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockClient := NewMockFlightClient(ctrl)
mockClient := NewMockClient(ctrl)

expectedErr := status.Error(codes.Unauthenticated, "failed to authenticate user: rpc error: code = "+
"Unavailable desc = connection error: desc = \"transport: Error while dialing: dial tcp: lookup tcp/320o: unknown port\"")
Expand Down
44 changes: 0 additions & 44 deletions go/implementations/go_flight_client.go

This file was deleted.

19 changes: 0 additions & 19 deletions go/interfaces/flight_client.go

This file was deleted.

Loading

0 comments on commit 725b002

Please sign in to comment.