Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests using mocks to Go client #62

13 changes: 13 additions & 0 deletions go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,16 @@ You can run a command similar to the following::
go run . --host=<cloud.hostname> --port=443 --query="SELECT * FROM \"Samples\".\"samples.dremio.com\".\"NYC-taxi-trips\"" --tls --pat=<mypat> --project_id=<myprojectid>
```
Here we're querying for a dataset called `NYC-taxi-trips`, in a source called `Samples`, in the `samples.dremio.com` folder.

## Tests

To run the tests, you'll need a flight client mock class. This class is generated using mockgen. To aid in this process,
we created a script that generates the mock class and runs all tests. You can run the script using the following command:
```bash
./run_tests.sh
```

If the mock class is already generated you can alternatively use the following command:
```go
go test -v
```
63 changes: 32 additions & 31 deletions go/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
package main

import (
"arrow-flight-client-example/interfaces"
"context"
"crypto/tls"
"fmt"
"github.com/apache/arrow-go/v18/arrow/flight"
flightgen "github.com/apache/arrow-go/v18/arrow/flight/gen/flight"
"github.com/apache/arrow-go/v18/arrow/memory"
"log"
"net"

"github.com/apache/arrow-go/v18/arrow/flight"
flightgen "github.com/apache/arrow-go/v18/arrow/flight/gen/flight"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/docopt/docopt-go"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand All @@ -48,25 +49,21 @@ Options:
--certs=<path> Path to trusted certificates for encrypted connection.
--project_id=<project_id> Dremio project ID`

func WrapRecordReader(stream flight.FlightService_DoGetClient) (interfaces.RecordReader, error) {
return flight.NewRecordReader(stream)
}

func main() {
args, err := docopt.ParseDoc(usage)
var config struct {
Host string
Port string
Pat string
User string
Pass string
Query string
TLS bool `docopt:"--tls"`
Certs string
ProjectID string `docopt:"--project_id"`
}
if err != nil {
log.Fatalf("error parsing arguments: %v", err)
}

var config FlightConfig
if err := args.Bind(&config); err != nil {
log.Fatalf("error binding arguments: %v", err)
}

var creds credentials.TransportCredentials
if config.TLS {
log.Println("[INFO] Enabling TLS Connection.")
Expand Down Expand Up @@ -96,47 +93,53 @@ func main() {
client, err := flight.NewClientWithMiddleware(
net.JoinHostPort(config.Host, config.Port),
nil,
[]flight.ClientMiddleware{
flight.NewClientCookieMiddleware(),
},
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()},
grpc.WithTransportCredentials(creds),
)
if err != nil {
log.Fatal(err)
}
defer client.Close()

if err := run(config, client, WrapRecordReader); err != nil {
log.Fatal(err)
}
}

func run(config FlightConfig, client flight.Client,
readerCreator func(flight.FlightService_DoGetClient) (interfaces.RecordReader, error),
) error {

// Two WLM settings can be provided upon initial authentication with the dremio
// server flight endpoint:
// - routing-tag
// - routing-queue
ctx := metadata.NewOutgoingContext(context.TODO(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have this somewhere in the docs?
In golang in general the context should be passed outside. For now it's fine to keep it here, but it would be great to add docs link where this is described.

metadata.Pairs("routing-tag", "test-routing-tag", "routing-queue", "Low Cost User Queries"))

var err error
if config.Pat != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", fmt.Sprintf("Bearer %s", config.Pat))
log.Println("[INFO] Using PAT.")

// If project_id is provided, set it in session options
if config.ProjectID != "" {
log.Println("[INFO] Project ID added to sessions options.")
err = setSessionOptions(ctx, client, config.ProjectID)
if err != nil {
log.Fatalf("Failed to set session options: %v", err)
return fmt.Errorf("failed to set session options: %v", err)
}

// Close the session once the query is done
defer client.CloseSession(ctx, &flight.CloseSessionRequest{})
}
} else {
if ctx, err = client.AuthenticateBasicToken(ctx, config.User, config.Pass); err != nil {
log.Fatal(err)
return fmt.Errorf("failed to authenticate user: %v", err)
}
log.Println("[INFO] Authentication was successful.")
}

if config.Query == "" {
return
return nil
}

// Once successful, the context object now contains the credentials, use it for subsequent calls.
Expand All @@ -152,32 +155,32 @@ func main() {
// Retrieve the schema of the result set
sc, err := client.GetSchema(ctx, desc)
if err != nil {
log.Fatal(err)
return fmt.Errorf("failed to get schema: %v", err)
}
log.Println("[INFO] GetSchema was successful.")

schema, err := flight.DeserializeSchema(sc.GetSchema(), memory.DefaultAllocator)
if err != nil {
log.Fatal(err)
return fmt.Errorf("failed to deserialize schema: %v", err)
}
log.Println("[INFO] Schema:", schema)

// Get the FlightInfo message to retrieve the ticket corresponding to the query result set
info, err := client.GetFlightInfo(ctx, desc)
if err != nil {
log.Fatal(err)
return fmt.Errorf("failed to get flight info: %v", err)
}
log.Println("[INFO] GetFlightInfo was successful.")

// retrieve the result set as a stream of Arrow record batches.
stream, err := client.DoGet(ctx, info.Endpoint[0].Ticket)
if err != nil {
log.Fatal(err)
return fmt.Errorf("failed to get flight stream: %v", err)
}

rdr, err := flight.NewRecordReader(stream)
rdr, err := readerCreator(stream)
if err != nil {
log.Fatal(err)
return fmt.Errorf("failed to create record reader: %v", err)
}
defer rdr.Release()

Expand All @@ -187,25 +190,23 @@ func main() {
defer rec.Release()
log.Println(rec)
}
return nil
}

func setSessionOptions(ctx context.Context, client flight.Client, projectID string) error {
projectIdSessionOption, err := flight.NewSessionOptionValue(projectID)
if err != nil {
return fmt.Errorf("failed to create session option: %v", err)
}

sessionOptionsRequest := flight.SetSessionOptionsRequest{
SessionOptions: map[string]*flight.SessionOptionValue{
"project_id": &projectIdSessionOption,
},
}

_, err = client.SetSessionOptions(ctx, &sessionOptionsRequest)
if err != nil {
return fmt.Errorf("failed to set session options: %v", err)
valterfrancisco-dremio marked this conversation as resolved.
Show resolved Hide resolved
}

log.Printf("[INFO] Session options set with project_id: %s", projectID)
return nil
}
Loading