From cfceb512e9e25e766030530bef9eefcd87eddad7 Mon Sep 17 00:00:00 2001
From: Raymond Cypher
Date: Wed, 20 Sep 2023 11:21:01 -0600
Subject: [PATCH] PECO-1054 Expose Arrow batches to users, part two

Updated the FetchableItems interface to return an instance of OutputType instead of a slice of OutputType.

Created interfaces SparkArrowBatch and SparkArrowRecord and implementations of each. This changes the previous 1:1 ratio of batch instances to arrow records: a SparkArrowBatch can now contain multiple arrow records.

Created the BatchIterator interface and implementation, and switched arrowRowScanner to use BatchIterator instead of BatchLoader.

Created the RowValues interface and implementation as a container for the currently loaded values for a set of rows.

Updated the behaviour of the fetchable items cloudURL and localBatch to de-serialize the arrow records as part of fetching, rather than carrying the raw bytes around for later de-serialization. Also eliminated the cloud fetch code that de-serialized the arrow batch and then re-serialized each record individually to create one batch instance per record.

Removed chunkedByteReader and replaced it with io.MultiReader.

Normalized the use of row numbers so that there is no need to track the index of the row within the current batch.

Signed-off-by: Raymond Cypher
---
 internal/fetcher/fetcher.go | 7 +-
 internal/fetcher/fetcher_test.go | 8 +-
 internal/rows/arrowbased/arrowRows.go | 253 ++++++++------
 internal/rows/arrowbased/arrowRows_test.go | 323 ++++++++++++------
 internal/rows/arrowbased/batchloader.go | 269 +++++++++------
 internal/rows/arrowbased/batchloader_test.go | 116 ++++---
 internal/rows/arrowbased/chunkedByteReader.go | 98 ------
 .../rows/arrowbased/chunkedByteReader_test.go | 61 ----
 internal/rows/arrowbased/columnValues.go | 65 ++++
 internal/rows/arrowbased/errors.go | 4 +-
 internal/rows/columnbased/columnRows.go | 3 +-
 internal/rows/errors.go | 2 +
 internal/rows/rows.go | 47 ++-
 internal/rows/rows_test.go | 46 +--
 internal/rows/rowscanner/rowScanner.go | 2 +-
 15 files changed, 719 insertions(+), 585 deletions(-)
 delete mode 100644 internal/rows/arrowbased/chunkedByteReader.go
 delete mode 100644 internal/rows/arrowbased/chunkedByteReader_test.go

diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go index cce84ad..2f53754 100644 --- a/internal/fetcher/fetcher.go +++ b/internal/fetcher/fetcher.go @@ -9,7 +9,7 @@ import ( ) type FetchableItems[OutputType any] interface { - Fetch(ctx context.Context) ([]OutputType, error) + Fetch(ctx context.Context) (OutputType, error) } type Fetcher[OutputType any] interface { @@ -151,10 +151,7 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in return } else { f.logger().Debug().Msgf("concurrent fetcher worker %d item loaded", workerIndex) - for i := range result { - r := result[i] - f.outChan <- r - } + f.outChan <- result } } else { f.logger().Debug().Msgf("concurrent fetcher ending %d", workerIndex) diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go index 76be59a..dbe6ced 100644 --- a/internal/fetcher/fetcher_test.go +++ b/internal/fetcher/fetcher_test.go @@ -30,13 +30,13 @@ func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) { return outputs, nil } -var _ FetchableItems[*mockOutput] = (*mockFetchableItem)(nil) +var _ FetchableItems[[]*mockOutput] = (*mockFetchableItem)(nil) func TestConcurrentFetcher(t *testing.T) { t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) { ctx := context.Background() - inputChan := 
make(chan FetchableItems[*mockOutput], 10) + inputChan := make(chan FetchableItems[[]*mockOutput], 10) for i := 0; i < 10; i++ { item := mockFetchableItem{item: i, wait: 1 * time.Second} inputChan <- &item @@ -57,7 +57,7 @@ func TestConcurrentFetcher(t *testing.T) { var results []*mockOutput for result := range outChan { - results = append(results, result) + results = append(results, result...) } // Check if the fetcher returned the expected results @@ -87,7 +87,7 @@ func TestConcurrentFetcher(t *testing.T) { defer cancel() // Create an input channel - inputChan := make(chan FetchableItems[*mockOutput], 3) + inputChan := make(chan FetchableItems[[]*mockOutput], 3) for i := 0; i < 3; i++ { item := mockFetchableItem{item: i, wait: 1 * time.Second} inputChan <- &item diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 6504db5..19a7659 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql/driver" + "io" "time" "github.com/apache/arrow/go/v12/arrow" @@ -18,18 +19,17 @@ import ( "github.com/pkg/errors" ) -type recordReader interface { - NewRecordFromBytes(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) -} - -type valueContainerMaker interface { - makeColumnValuesContainers(ars *arrowRowScanner) error +// Abstraction for a set of arrow records +type SparkArrowBatch interface { + rowscanner.Delimiter + Next() (SparkArrowRecord, error) + Close() } -type sparkArrowBatch struct { +// Abstraction for an arrow record +type SparkArrowRecord interface { rowscanner.Delimiter - arrowRecordBytes []byte - hasSchema bool + arrow.Record } type timeStampFn func(arrow.Timestamp) time.Time @@ -43,7 +43,6 @@ type colInfo struct { // arrowRowScanner handles extracting values from arrow records type arrowRowScanner struct { rowscanner.Delimiter - recordReader valueContainerMaker // configuration of different arrow options for retrieving results @@ -58,11 +57,10 @@ type arrowRowScanner struct { // database types for the columns colInfo []colInfo - // a TRowSet contains multiple arrow batches - currentBatch *sparkArrowBatch + currentBatch SparkArrowBatch - // Values for each column - columnValues []columnValues + // Currently loaded field values for a set of rows + rowValues RowValues // function to convert arrow timestamp when using native arrow format toTimestampFn timeStampFn @@ -76,7 +74,7 @@ type arrowRowScanner struct { resultFormat cli_service.TSparkRowSetType - BatchLoader + batchIterator BatchIterator } // Make sure arrowRowScanner fulfills the RowScanner interface @@ -117,15 +115,19 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp var bl BatchLoader var err2 dbsqlerr.DBError if len(rowSet.ResultLinks) > 0 { - bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, cfg) + bl, err2 = NewCloudBatchLoader(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) } else { - bl, err2 = NewLocalBatchLoader(context.Background(), rowSet.ArrowBatches, cfg) + bl, err2 = NewLocalBatchLoader(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } - if err2 != nil { return nil, err2 } + bi, err := NewBatchIterator(bl) + if err != nil { + return nil, err2 + } + var location *time.Location = time.UTC if cfg != nil { if cfg.Location != nil { @@ -134,10 +136,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp } rs := 
&arrowRowScanner{ - Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)), - recordReader: sparkRecordReader{ - ctx: ctx, - }, + Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)), valueContainerMaker: &arrowValueContainerMaker{}, ArrowConfig: arrowConfig, arrowSchemaBytes: schemaBytes, @@ -147,7 +146,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp DBSQLLogger: logger, location: location, resultFormat: *resultSetMetadata.ResultFormat, - BatchLoader: bl, + batchIterator: bi, } return rs, nil @@ -155,11 +154,16 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp // Close is called when the Rows instance is closed. func (ars *arrowRowScanner) Close() { - // release any retained arrow arrays - for i := range ars.columnValues { - if ars.columnValues[i] != nil { - ars.columnValues[i].Release() - } + if ars.rowValues != nil { + ars.rowValues.Close() + } + + if ars.batchIterator != nil { + ars.batchIterator.Close() + } + + if ars.currentBatch != nil { + ars.currentBatch.Close() } } @@ -189,23 +193,19 @@ var intervalTypes map[cli_service.TTypeId]struct{} = map[cli_service.TTypeId]str // a buffer held in dest. func (ars *arrowRowScanner) ScanRow( destination []driver.Value, - rowIndex int64) dbsqlerr.DBError { + rowNumber int64) dbsqlerr.DBError { // load the error batch for the specified row, if necessary - err := ars.loadBatchFor(rowIndex) + err := ars.loadBatchFor(rowNumber) if err != nil { return err } - var rowInBatchIndex int = int(rowIndex - ars.currentBatch.Start()) - // if no location is provided default to UTC if ars.location == nil { ars.location = time.UTC } - nCols := len(ars.columnValues) - // loop over the destination slice filling in values for i := range destination { // clear the destination @@ -213,7 +213,7 @@ func (ars *arrowRowScanner) ScanRow( // if there is a corresponding column and the value for the specified row // is not null we put the value in the destination - if i < nCols && !ars.columnValues[i].IsNull(rowInBatchIndex) { + if !ars.rowValues.IsNull(i, rowNumber) { col := ars.colInfo[i] dbType := col.dbType @@ -227,7 +227,7 @@ func (ars *arrowRowScanner) ScanRow( // get the value from the column values holder var err1 error - destination[i], err1 = ars.columnValues[i].Value(rowInBatchIndex) + destination[i], err1 = ars.rowValues.Value(i, rowNumber) if err1 != nil { err = dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsColumnValue(col.name), err1) } @@ -243,64 +243,78 @@ func isIntervalType(typeId cli_service.TTypeId) bool { } // loadBatchFor loads the batch containing the specified row if necessary -func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError { +func (ars *arrowRowScanner) loadBatchFor(rowNumber int64) dbsqlerr.DBError { - if ars == nil || ars.BatchLoader == nil { + if ars == nil || ars.batchIterator == nil { return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil) } + // if the batch already loaded we can just return - if ars.currentBatch != nil && ars.currentBatch.Contains(rowIndex) && ars.columnValues != nil { + if ars.rowValues != nil && ars.rowValues.Contains(rowNumber) { return nil } - batch, err := ars.GetBatchFor(rowIndex) + // check for things like going backwards, rowNumber < 0, etc. 
+ err := ars.validateRowNumber(rowNumber) if err != nil { return err } - // set up the column values containers - if ars.columnValues == nil { - err := ars.makeColumnValuesContainers(ars) + // Find the batch containing the row number, if necessary + for ars.currentBatch == nil || !ars.currentBatch.Contains(rowNumber) { + batch, err := ars.batchIterator.Next() if err != nil { - return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsMakeColumnValueContainers, err) + return err } - } - var r arrow.Record - if ars.resultFormat == cli_service.TSparkRowSetType_ARROW_BASED_SET { - r, err = ars.NewRecordFromBytes(ars.arrowSchemaBytes, *batch) - } else if ars.resultFormat == cli_service.TSparkRowSetType_URL_BASED_SET { - r, err = ars.NewRecordFromBytes(nil, *batch) - } - if err != nil { - ars.Err(err).Msg(errArrowRowsUnableToReadBatch) - return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsUnableToReadBatch, err) + ars.currentBatch = batch } - defer r.Release() + // Get the next arrow record from the current batch + sar, err2 := ars.currentBatch.Next() + if err2 != nil { + ars.Err(err2).Msg(errArrowRowsUnableToReadBatch) + return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsUnableToReadBatch, err2) + } - // for each column we want to create an arrow array specific to the data type - for i, col := range r.Columns() { - col.Retain() - defer col.Release() + defer sar.Release() - colData := col.Data() - colData.Retain() - defer colData.Release() + // set up the column values containers + if ars.rowValues == nil { + err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(sar.Start(), sar.Count())) + if err != nil { + return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsMakeColumnValueContainers, err) + } + } - colValsHolder := ars.columnValues[i] + // for each column we want to create an arrow array specific to the data type + for i, col := range sar.Columns() { + func() { + col.Retain() + defer col.Release() - // release the arrow array already held - colValsHolder.Release() + colData := col.Data() + colData.Retain() + defer colData.Release() - err := colValsHolder.SetValueArray(colData) - if err != nil { - ars.Error().Msg(err.Error()) - } + err := ars.rowValues.SetColumnValues(i, colData) + if err != nil { + ars.Error().Msg(err.Error()) + } + }() } - ars.currentBatch = batch + // Update the delimiter in rowValues to reflect the currently loaded set of rows + ars.rowValues.SetDelimiter(rowscanner.NewDelimiter(sar.Start(), sar.Count())) + return nil +} +// Check that the row number falls within the range of this row scanner and that +// it is not moving backwards. 
+func (ars *arrowRowScanner) validateRowNumber(rowNumber int64) dbsqlerr.DBError { + if rowNumber < 0 || rowNumber > ars.End() || (ars.currentBatch != nil && ars.currentBatch.Direction(rowNumber) == rowscanner.DirBack) { + return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsInvalidRowNumber(rowNumber), nil) + } return nil } @@ -513,7 +527,7 @@ func tGetResultSetMetadataRespToArrowSchema(resultSetMetadata *cli_service.TGetR return nil, nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsSerializeSchema, err) } } else { - br := &chunkedByteReader{chunks: [][]byte{schemaBytes}} + br := bytes.NewReader(schemaBytes) rdr, err := ipc.NewReader(br) if err != nil { return nil, nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsUnableToReadBatch, err) @@ -526,49 +540,14 @@ func tGetResultSetMetadataRespToArrowSchema(resultSetMetadata *cli_service.TGetR return schemaBytes, arrowSchema, nil } -type sparkRecordReader struct { - ctx context.Context -} - -// Make sure sparkRecordReader fulfills the recordReader interface -var _ recordReader = (*sparkRecordReader)(nil) - -func (srr sparkRecordReader) NewRecordFromBytes(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - // The arrow batches returned from the thrift server are actually a serialized arrow Record - // an arrow batch should consist of a Schema and at least one Record. - // Use a chunked byte reader to concatenate the schema bytes and the record bytes without - // having to allocate/copy slices. - - var br *chunkedByteReader - if arrowSchemaBytes == nil { - br = &chunkedByteReader{chunks: [][]byte{sparkArrowBatch.arrowRecordBytes}} - } else { - br = &chunkedByteReader{chunks: [][]byte{arrowSchemaBytes, sparkArrowBatch.arrowRecordBytes}} - } - rdr, err := ipc.NewReader(br) - if err != nil { - return nil, dbsqlerrint.NewDriverError(srr.ctx, errArrowRowsUnableToReadBatch, err) - } - defer rdr.Release() - - r, err := rdr.Read() - if err != nil { - return nil, dbsqlerrint.NewDriverError(srr.ctx, errArrowRowsUnableToReadBatch, err) - } - - r.Retain() - - return r, nil -} - type arrowValueContainerMaker struct{} var _ valueContainerMaker = (*arrowValueContainerMaker)(nil) // makeColumnValuesContainers creates appropriately typed column values holders for each column -func (vcm *arrowValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner) error { - if ars.columnValues == nil { - ars.columnValues = make([]columnValues, len(ars.colInfo)) +func (vcm *arrowValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error { + if ars.rowValues == nil { + columnValueHolders := make([]columnValues, len(ars.colInfo)) for i, field := range ars.arrowSchema.Fields() { holder, err := vcm.makeColumnValueContainer(field.Type, ars.location, ars.toTimestampFn, &ars.colInfo[i]) if err != nil { @@ -576,9 +555,12 @@ func (vcm *arrowValueContainerMaker) makeColumnValuesContainers(ars *arrowRowSca return err } - ars.columnValues[i] = holder + columnValueHolders[i] = holder } + + ars.rowValues = NewRowValues(d, columnValueHolders) } + return nil } @@ -688,3 +670,54 @@ func (vcm *arrowValueContainerMaker) makeColumnValueContainer(t arrow.DataType, return nil, errors.Errorf(errArrowRowsUnhandledArrowType(t.String())) } } + +// Container for a set of arrow records +type sparkArrowBatch struct { + // Delimiter indicating the range of rows covered by the arrow records + rowscanner.Delimiter + arrowRecords []SparkArrowRecord +} + +var _ SparkArrowBatch = (*sparkArrowBatch)(nil) + +func (b 
*sparkArrowBatch) Next() (SparkArrowRecord, error) { + if len(b.arrowRecords) > 0 { + r := b.arrowRecords[0] + // remove the record from the slice as iteration is only forwards + b.arrowRecords = b.arrowRecords[1:] + return r, nil + } + + // no more records + return nil, io.EOF +} + +func (b *sparkArrowBatch) Close() { + // Release any arrow records + for i := range b.arrowRecords { + b.arrowRecords[i].Release() + } + b.arrowRecords = nil +} + +// Composite of an arrow record and a delimiter indicating +// the rows corresponding to the record. +type sparkArrowRecord struct { + rowscanner.Delimiter + arrow.Record +} + +var _ SparkArrowRecord = (*sparkArrowRecord)(nil) + +func (sar *sparkArrowRecord) Release() { + if sar.Record != nil { + sar.Record.Release() + sar.Record = nil + } +} + +func (sar *sparkArrowRecord) Retain() { + if sar.Record != nil { + sar.Record.Retain() + } +} diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index ac11b8e..2b2b754 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -16,6 +16,7 @@ import ( "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/config" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" + "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" dbsqllog "github.com/databricks/databricks-sql-go/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -238,60 +239,63 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - err := ars.makeColumnValuesContainers(ars) + err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) var ok bool - _, ok = ars.columnValues[0].(*columnValuesTyped[*array.Boolean, bool]) + rowValues, ok := ars.rowValues.(*rowValues) assert.True(t, ok) - _, ok = ars.columnValues[1].(*columnValuesTyped[*array.Int8, int8]) + _, ok = rowValues.columnValueHolders[0].(*columnValuesTyped[*array.Boolean, bool]) assert.True(t, ok) - _, ok = ars.columnValues[2].(*columnValuesTyped[*array.Int16, int16]) + _, ok = rowValues.columnValueHolders[1].(*columnValuesTyped[*array.Int8, int8]) assert.True(t, ok) - _, ok = ars.columnValues[3].(*columnValuesTyped[*array.Int32, int32]) + _, ok = rowValues.columnValueHolders[2].(*columnValuesTyped[*array.Int16, int16]) assert.True(t, ok) - _, ok = ars.columnValues[4].(*columnValuesTyped[*array.Int64, int64]) + _, ok = rowValues.columnValueHolders[3].(*columnValuesTyped[*array.Int32, int32]) assert.True(t, ok) - _, ok = ars.columnValues[5].(*columnValuesTyped[*array.Float32, float32]) + _, ok = rowValues.columnValueHolders[4].(*columnValuesTyped[*array.Int64, int64]) assert.True(t, ok) - _, ok = ars.columnValues[6].(*columnValuesTyped[*array.Float64, float64]) + _, ok = rowValues.columnValueHolders[5].(*columnValuesTyped[*array.Float32, float32]) assert.True(t, ok) - _, ok = ars.columnValues[7].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[6].(*columnValuesTyped[*array.Float64, float64]) assert.True(t, ok) - _, ok = ars.columnValues[8].(*timestampStringValueContainer) + _, ok = rowValues.columnValueHolders[7].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) - _, ok = ars.columnValues[9].(*columnValuesTyped[*array.Binary, []byte]) + _, ok = rowValues.columnValueHolders[8].(*timestampStringValueContainer) assert.True(t, ok) - _, ok = 
ars.columnValues[10].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[9].(*columnValuesTyped[*array.Binary, []byte]) assert.True(t, ok) - _, ok = ars.columnValues[11].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[10].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) - _, ok = ars.columnValues[12].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[11].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) - _, ok = ars.columnValues[13].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[12].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) - _, ok = ars.columnValues[14].(*dateValueContainer) + _, ok = rowValues.columnValueHolders[13].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) - _, ok = ars.columnValues[15].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[14].(*dateValueContainer) assert.True(t, ok) - _, ok = ars.columnValues[16].(*columnValuesTyped[*array.String, string]) + _, ok = rowValues.columnValueHolders[15].(*columnValuesTyped[*array.String, string]) + assert.True(t, ok) + + _, ok = rowValues.columnValueHolders[16].(*columnValuesTyped[*array.String, string]) assert.True(t, ok) }) @@ -311,17 +315,20 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - err := ars.makeColumnValuesContainers(ars) + err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) var ok bool + rowValues, ok := ars.rowValues.(*rowValues) + assert.True(t, ok) + // timestamp - _, ok = ars.columnValues[8].(*timestampValueContainer) + _, ok = rowValues.columnValueHolders[8].(*timestampValueContainer) assert.True(t, ok) // decimal - _, ok = ars.columnValues[13].(*decimal128Container) + _, ok = rowValues.columnValueHolders[13].(*decimal128Container) assert.True(t, ok) }) @@ -410,13 +417,13 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - err := ars.makeColumnValuesContainers(ars) + err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 0)) require.Nil(t, err) dest := make([]driver.Value, 1) err = ars.ScanRow(dest, 0) require.NotNil(t, err) - assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsInvalidRowIndex(0))) + assert.True(t, strings.Contains(err.Error(), "databricks: driver error: "+errArrowRowsInvalidRowNumber(0))) }) t.Run("Close releases column values", func(t *testing.T) { @@ -442,7 +449,7 @@ func TestArrowRowScanner(t *testing.T) { ars := d.(*arrowRowScanner) var releaseCount int fc := &fakeColumnValues{fnRelease: func() { releaseCount++ }} - ars.columnValues = []columnValues{fc, fc, fc} + ars.rowValues = NewRowValues(rowscanner.NewDelimiter(0, 1), []columnValues{fc, fc, fc}) d.Close() assert.Equal(t, 3, releaseCount) }) @@ -479,31 +486,55 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - assert.Nil(t, ars.columnValues) + assert.Nil(t, ars.rowValues) - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - return fakeRecord{}, nil - }} + b1 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{ + &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 2), Record: &fakeRecord{}}, + &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(2, 3), Record: 
&fakeRecord{}}}} + b2 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}} + b3 := &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}} + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{b1, b2, b3}, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) var callCount int - ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner) dbsqlerr.DBError { + ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { callCount += 1 - ars.columnValues = make([]columnValues, len(ars.arrowSchema.Fields())) + columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) for i := range ars.arrowSchema.Fields() { - ars.columnValues[i] = &fakeColumnValues{} + columnValueHolders[i] = &fakeColumnValues{} } + ars.rowValues = NewRowValues(rowscanner.NewDelimiter(0, 0), make([]columnValues, len(ars.arrowSchema.Fields()))) return nil }} err := ars.loadBatchFor(0) assert.Nil(t, err) - assert.NotNil(t, ars.columnValues) + assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) + assert.Equal(t, 1, callCount) + assert.Equal(t, 1, fbl.callCount) + + err = ars.loadBatchFor(1) + assert.Nil(t, err) + assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) + assert.Equal(t, 1, callCount) + assert.Equal(t, 1, fbl.callCount) + + err = ars.loadBatchFor(2) + assert.Nil(t, err) + assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) + assert.Equal(t, 1, fbl.callCount) err = ars.loadBatchFor(5) assert.Nil(t, err) - assert.NotNil(t, ars.columnValues) + assert.Equal(t, len(metadataResp.Schema.Columns), ars.rowValues.NColumns()) assert.Equal(t, 1, callCount) + assert.Equal(t, 2, fbl.callCount) }) @@ -525,19 +556,25 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - var callCount int - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - callCount += 1 - return fakeRecord{}, nil - }} + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) err := ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, callCount) + assert.Equal(t, 1, fbl.callCount) err = ars.loadBatchFor(0) assert.Nil(t, err) - assert.Equal(t, 1, callCount) + assert.Equal(t, 1, fbl.callCount) }) t.Run("loadBatch index out of bounds", func(t *testing.T) { @@ -558,19 +595,25 
@@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - var callCount int - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - callCount += 1 - return fakeRecord{}, nil - }} + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) err := ars.loadBatchFor(-1) assert.NotNil(t, err) - assert.ErrorContains(t, err, errArrowRowsInvalidRowIndex(-1)) + assert.ErrorContains(t, err, errArrowRowsInvalidRowNumber(-1)) err = ars.loadBatchFor(17) assert.NotNil(t, err) - assert.ErrorContains(t, err, errArrowRowsInvalidRowIndex(17)) + assert.ErrorContains(t, err, errArrowRowsInvalidRowNumber(17)) }) t.Run("loadBatch container failure", func(t *testing.T) { @@ -591,12 +634,24 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - return fakeRecord{}, nil - }} - ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner) dbsqlerr.DBError { - return dbsqlerrint.NewDriverError(context.TODO(), "error making containers", nil) - }} + + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) + + ars.valueContainerMaker = &fakeValueContainerMaker{ + fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { + return dbsqlerrint.NewDriverError(context.TODO(), "error making containers", nil) + }, + } err := ars.loadBatchFor(0) assert.NotNil(t, err) @@ -622,9 +677,19 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) var ars *arrowRowScanner = d.(*arrowRowScanner) - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - return fakeRecord{}, dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil) - }} + + fbl := &fakeBatchLoader{ + Delimiter: 
rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + err: dbsqlerrint.NewDriverError(context.TODO(), "error reading record", nil), + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) err := ars.loadBatchFor(0) assert.NotNil(t, err) @@ -650,45 +715,49 @@ func TestArrowRowScanner(t *testing.T) { var ars *arrowRowScanner = d.(*arrowRowScanner) - var callCount int - var lastReadBatch *sparkArrowBatch - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - callCount += 1 - lastReadBatch = &sparkArrowBatch - return fakeRecord{}, nil - }} + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) for _, i := range []int64{0, 1, 2, 3, 4} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, lastReadBatch) - assert.Equal(t, 1, callCount) - assert.Equal(t, int64(0), lastReadBatch.Start()) + assert.NotNil(t, fbl.lastReadBatch) + assert.Equal(t, 1, fbl.callCount) + assert.Equal(t, int64(0), fbl.lastReadBatch.Start()) } for _, i := range []int64{5, 6, 7} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, lastReadBatch) - assert.Equal(t, 2, callCount) - assert.Equal(t, int64(5), lastReadBatch.Start()) + assert.NotNil(t, fbl.lastReadBatch) + assert.Equal(t, 2, fbl.callCount) + assert.Equal(t, int64(5), fbl.lastReadBatch.Start()) } for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} { err := ars.loadBatchFor(i) assert.Nil(t, err) - assert.NotNil(t, lastReadBatch) - assert.Equal(t, 3, callCount) - assert.Equal(t, int64(8), lastReadBatch.Start()) + assert.NotNil(t, fbl.lastReadBatch) + assert.Equal(t, 3, fbl.callCount) + assert.Equal(t, int64(8), fbl.lastReadBatch.Start()) } err := ars.loadBatchFor(-1) assert.NotNil(t, err) - assert.EqualError(t, err, "databricks: driver error: "+errArrowRowsInvalidRowIndex(-1)) + assert.EqualError(t, err, "databricks: driver error: "+errArrowRowsInvalidRowNumber(-1)) err = ars.loadBatchFor(15) assert.NotNil(t, err) - assert.EqualError(t, err, "databricks: driver error: "+errArrowRowsInvalidRowIndex(15)) + assert.EqualError(t, err, "databricks: driver error: "+errArrowRowsInvalidRowNumber(15)) }) t.Run("Error on retrieving not implemented native arrow types", func(t *testing.T) { @@ -798,14 
+867,25 @@ func TestArrowRowScanner(t *testing.T) { ars.UseArrowNativeComplexTypes = true ars.UseArrowNativeDecimal = true ars.UseArrowNativeIntervalTypes = true - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - return fakeRecord{}, nil - }} - ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner) dbsqlerr.DBError { - ars.columnValues = make([]columnValues, len(ars.arrowSchema.Fields())) + + fbl := &fakeBatchLoader{ + Delimiter: rowscanner.NewDelimiter(0, 15), + batches: []SparkArrowBatch{ + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(0, 5), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 5), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(5, 3), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(5, 3), Record: &fakeRecord{}}}}, + &sparkArrowBatch{Delimiter: rowscanner.NewDelimiter(8, 7), arrowRecords: []SparkArrowRecord{&sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(8, 7), Record: &fakeRecord{}}}}, + }, + } + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) + + ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { + columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) for i := range ars.arrowSchema.Fields() { - ars.columnValues[i] = &fakeColumnValues{} + columnValueHolders[i] = &fakeColumnValues{} } + ars.rowValues = NewRowValues(rowscanner.NewDelimiter(0, 0), columnValueHolders) return nil }} @@ -968,12 +1048,17 @@ func TestArrowRowScanner(t *testing.T) { ars := d.(*arrowRowScanner) assert.Equal(t, int64(53940), ars.NRows()) - var loadBatchCallCount int - rr := ars.recordReader - ars.recordReader = fakeRecordReader{fnNewRecord: func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - loadBatchCallCount += 1 - return rr.NewRecordFromBytes(arrowSchemaBytes, sparkArrowBatch) - }} + bi, ok := ars.batchIterator.(*batchIterator) + assert.True(t, ok) + bl := bi.batchLoader + fbl := &batchLoaderWrapper{ + Delimiter: rowscanner.NewDelimiter(bl.Start(), bl.Count()), + bl: bl, + } + + var e dbsqlerr.DBError + ars.batchIterator, e = NewBatchIterator(fbl) + assert.Nil(t, e) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) for i := int64(0); i < ars.NRows(); i = i + 1 { @@ -993,7 +1078,7 @@ func TestArrowRowScanner(t *testing.T) { } } - assert.Equal(t, 54, loadBatchCallCount) + assert.Equal(t, 54, fbl.callCount) }) t.Run("Retrieve values - native arrow schema", func(t *testing.T) { @@ -1039,6 +1124,9 @@ func TestArrowRowScanner(t *testing.T) { ars := d.(*arrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) + // err = ars.ScanRow(dest, 0) + // assert.Nil(t, err) + err = ars.ScanRow(dest, 1) assert.Nil(t, err) @@ -1471,15 +1559,58 @@ func (cv *fakeColumnValues) SetValueArray(colData arrow.ArrayData) error { return nil } -type fakeRecordReader struct { - fnNewRecord func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) +// type fakeRecordReader struct { +// fnNewRecord func(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) +// } + +// 
func (frr fakeRecordReader) NewRecordFromBytes(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { +// if frr.fnNewRecord != nil { +// return frr.fnNewRecord(arrowSchemaBytes, sparkArrowBatch) +// } +// return nil, nil +// } + +type fakeBatchLoader struct { + rowscanner.Delimiter + batches []SparkArrowBatch + callCount int + err dbsqlerr.DBError + lastReadBatch SparkArrowBatch } -func (frr fakeRecordReader) NewRecordFromBytes(arrowSchemaBytes []byte, sparkArrowBatch sparkArrowBatch) (arrow.Record, dbsqlerr.DBError) { - if frr.fnNewRecord != nil { - return frr.fnNewRecord(arrowSchemaBytes, sparkArrowBatch) +var _ BatchLoader = (*fakeBatchLoader)(nil) + +func (fbl *fakeBatchLoader) Close() {} +func (fbl *fakeBatchLoader) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { + fbl.callCount += 1 + if fbl.err != nil { + return nil, fbl.err } - return nil, nil + for i := range fbl.batches { + if fbl.batches[i].Contains(recordNum) { + fbl.lastReadBatch = fbl.batches[i] + return fbl.batches[i], nil + } + + } + return nil, dbsqlerrint.NewDriverError(context.Background(), errArrowRowsInvalidRowNumber(recordNum), nil) +} + +type batchLoaderWrapper struct { + rowscanner.Delimiter + bl BatchLoader + callCount int + lastLoadedBatch SparkArrowBatch +} + +var _ BatchLoader = (*batchLoaderWrapper)(nil) + +func (fbl *batchLoaderWrapper) Close() { fbl.bl.Close() } +func (fbl *batchLoaderWrapper) GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) { + fbl.callCount += 1 + batch, err := fbl.bl.GetBatchFor(recordNum) + fbl.lastLoadedBatch = batch + return batch, err } type fakeRecord struct { @@ -1782,14 +1913,14 @@ func getAllTypesSchema() *cli_service.TTableSchema { } type fakeValueContainerMaker struct { - fnMakeColumnValuesContainers func(ars *arrowRowScanner) dbsqlerr.DBError + fnMakeColumnValuesContainers func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError } var _ valueContainerMaker = (*fakeValueContainerMaker)(nil) -func (vcm *fakeValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner) error { +func (vcm *fakeValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error { if vcm.fnMakeColumnValuesContainers != nil { - return vcm.fnMakeColumnValuesContainers(ars) + return vcm.fnMakeColumnValuesContainers(ars, d) } return nil diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index b534ce9..97acf46 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -1,7 +1,6 @@ package arrowbased import ( - "bufio" "bytes" "context" "io" @@ -14,7 +13,6 @@ import ( "net/http" - "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/ipc" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -22,25 +20,47 @@ import ( "github.com/databricks/databricks-sql-go/internal/fetcher" ) +type BatchIterator interface { + Next() (SparkArrowBatch, dbsqlerr.DBError) + Close() +} + type BatchLoader interface { - GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) + rowscanner.Delimiter + GetBatchFor(recordNum int64) (SparkArrowBatch, dbsqlerr.DBError) + Close() +} + +func NewBatchIterator(batchLoader BatchLoader) (BatchIterator, dbsqlerr.DBError) { + bi := &batchIterator{ + batchLoader: batchLoader, + } + + return bi, nil } -func NewCloudBatchLoader(ctx context.Context, files 
[]*cli_service.TSparkArrowResultLink, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { +func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() } - inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], len(files)) + inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(files)) + var rowCount int64 for i := range files { + f := files[i] li := &cloudURL{ - TSparkArrowResultLink: files[i], - minTimeToExpiry: cfg.MinTimeToExpiry, - useLz4Compression: cfg.UseLz4Compression, + // TSparkArrowResultLink: f, + Delimiter: rowscanner.NewDelimiter(f.StartRowOffset, f.RowCount), + fileLink: f.FileLink, + expiryTime: f.ExpiryTime, + minTimeToExpiry: cfg.MinTimeToExpiry, + compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, } inputChan <- li + + rowCount += f.RowCount } // make sure to close input channel or fetcher will block waiting for more inputs @@ -48,211 +68,244 @@ func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowRe f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) cbl := &batchLoader[*cloudURL]{ - Fetcher: f, - ctx: ctx, + Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), + fetcher: f, + ctx: ctx, } return cbl, nil } -func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { +func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { if cfg == nil { cfg = config.WithDefaults() } - var startRow int64 - inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], len(batches)) + var startRow int64 = startRowOffset + var rowCount int64 + inputChan := make(chan fetcher.FetchableItems[SparkArrowBatch], len(batches)) for i := range batches { b := batches[i] if b != nil { li := &localBatch{ - TSparkArrowBatch: b, - startRow: startRow, - useLz4Compression: cfg.UseLz4Compression, + Delimiter: rowscanner.NewDelimiter(startRow, b.RowCount), + batchBytes: b.Batch, + arrowSchemaBytes: arrowSchemaBytes, + compressibleBatch: compressibleBatch{useLz4Compression: cfg.UseLz4Compression}, } inputChan <- li startRow = startRow + b.RowCount + rowCount += b.RowCount } } close(inputChan) f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) cbl := &batchLoader[*localBatch]{ - Fetcher: f, - ctx: ctx, + Delimiter: rowscanner.NewDelimiter(startRowOffset, rowCount), + fetcher: f, + ctx: ctx, } return cbl, nil } type batchLoader[T interface { - Fetch(ctx context.Context) ([]*sparkArrowBatch, error) + Fetch(ctx context.Context) (SparkArrowBatch, error) }] struct { - fetcher.Fetcher[*sparkArrowBatch] - arrowBatches []*sparkArrowBatch + rowscanner.Delimiter + fetcher fetcher.Fetcher[SparkArrowBatch] + arrowBatches []SparkArrowBatch ctx context.Context } var _ BatchLoader = (*batchLoader[*localBatch])(nil) -func (cbl *batchLoader[T]) GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) { +func (cbl *batchLoader[T]) GetBatchFor(rowNumber int64) (SparkArrowBatch, dbsqlerr.DBError) { for i := range cbl.arrowBatches { - if cbl.arrowBatches[i].Start() <= recordNum && cbl.arrowBatches[i].End() 
>= recordNum { + if cbl.arrowBatches[i].Contains(rowNumber) { return cbl.arrowBatches[i], nil } } - batchChan, _, err := cbl.Start() + batchChan, _, err := cbl.fetcher.Start() + var emptyBatch SparkArrowBatch if err != nil { - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) + return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) } for { batch, ok := <-batchChan if !ok { - err := cbl.Err() + err := cbl.fetcher.Err() if err != nil { - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) + return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) } break } cbl.arrowBatches = append(cbl.arrowBatches, batch) - if batch.Contains(recordNum) { + if batch.Contains(rowNumber) { return batch, nil } } - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) + return emptyBatch, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowNumber(rowNumber), err) } -type cloudURL struct { - *cli_service.TSparkArrowResultLink - minTimeToExpiry time.Duration +func (cbl *batchLoader[T]) Close() { + for i := range cbl.arrowBatches { + cbl.arrowBatches[i].Close() + } +} + +type compressibleBatch struct { useLz4Compression bool } -func (cu *cloudURL) Fetch(ctx context.Context) ([]*sparkArrowBatch, error) { - if isLinkExpired(cu.ExpiryTime, cu.minTimeToExpiry) { - return nil, errors.New(dbsqlerr.ErrLinkExpired) +func (cb compressibleBatch) getReader(r io.Reader) io.Reader { + if cb.useLz4Compression { + return lz4.NewReader(r) + } + return r +} + +type cloudURL struct { + compressibleBatch + rowscanner.Delimiter + fileLink string + expiryTime int64 + minTimeToExpiry time.Duration +} + +func (cu *cloudURL) Fetch(ctx context.Context) (SparkArrowBatch, error) { + var sab SparkArrowBatch + + if isLinkExpired(cu.expiryTime, cu.minTimeToExpiry) { + return sab, errors.New(dbsqlerr.ErrLinkExpired) } - req, err := http.NewRequestWithContext(ctx, "GET", cu.FileLink, nil) + req, err := http.NewRequestWithContext(ctx, "GET", cu.fileLink, nil) if err != nil { - return nil, err + return sab, err } client := http.DefaultClient res, err := client.Do(req) if err != nil { - return nil, err + return sab, err } if res.StatusCode != http.StatusOK { - return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) + return sab, dbsqlerrint.NewDriverError(ctx, errArrowRowsCloudFetchDownloadFailure, err) } defer res.Body.Close() - var arrowSchema *arrow.Schema - var arrowBatches []*sparkArrowBatch - - rdr, err := getArrowReader(res.Body, cu.useLz4Compression) + r := cu.compressibleBatch.getReader(res.Body) + records, err := getArrowRecords(r, cu.Start()) if err != nil { return nil, err } - startRow := cu.StartRowOffset + arrowBatch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(cu.Start(), cu.Count()), + arrowRecords: records, + } - for rdr.Next() { - r := rdr.Record() - r.Retain() - if arrowSchema == nil { - arrowSchema = r.Schema() - } + return &arrowBatch, nil +} - var output bytes.Buffer - w := ipc.NewWriter(&output, ipc.WithSchema(r.Schema())) +func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { + bufferSecs := int64(linkExpiryBuffer.Seconds()) + return expiryTime-bufferSecs < time.Now().Unix() +} - err := w.Write(r) - if err != nil { - panic(err) - } - err = w.Close() - if err != nil { - panic(err) - } +var _ fetcher.FetchableItems[SparkArrowBatch] = 
(*cloudURL)(nil) - recordBytes := output.Bytes() +type localBatch struct { + compressibleBatch + rowscanner.Delimiter + batchBytes []byte + arrowSchemaBytes []byte +} - arrowBatches = append(arrowBatches, &sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(startRow, r.NumRows()), - arrowRecordBytes: recordBytes, - hasSchema: true, - }) +var _ fetcher.FetchableItems[SparkArrowBatch] = (*localBatch)(nil) - startRow = startRow + r.NumRows() +func (lb *localBatch) Fetch(ctx context.Context) (SparkArrowBatch, error) { + r := lb.compressibleBatch.getReader(bytes.NewReader(lb.batchBytes)) + r = io.MultiReader(bytes.NewReader(lb.arrowSchemaBytes), r) - r.Release() + records, err := getArrowRecords(r, lb.Start()) + if err != nil { + return &sparkArrowBatch{}, err } - if rdr.Err() != nil { - panic(rdr.Err()) + lb.batchBytes = nil + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(lb.Start(), lb.Count()), + arrowRecords: records, } - rdr.Release() - return arrowBatches, nil + return &batch, nil } -func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { - bufferSecs := int64(linkExpiryBuffer.Seconds()) - return expiryTime-bufferSecs < time.Now().Unix() -} - -func getArrowReader(rd io.Reader, useLz4Compression bool) (*ipc.Reader, error) { - if useLz4Compression { - return ipc.NewReader(lz4.NewReader(rd)) +func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, error) { + ipcReader, err := ipc.NewReader(r) + if err != nil { + return nil, err } - return ipc.NewReader(bufio.NewReader(rd)) -} -func getArrowBatch(useLz4Compression bool, src []byte) ([]byte, error) { - if useLz4Compression { - srcBuffer := bytes.NewBuffer(src) - dstBuffer := bytes.NewBuffer(nil) + defer ipcReader.Release() + + startRow := startRowOffset + var records []SparkArrowRecord + for ipcReader.Next() { + r := ipcReader.Record() + r.Retain() - r := lz4.NewReader(srcBuffer) - _, err := io.Copy(dstBuffer, r) - if err != nil { - return nil, err + sar := sparkArrowRecord{ + Delimiter: rowscanner.NewDelimiter(startRow, r.NumRows()), + Record: r, } - return dstBuffer.Bytes(), nil + records = append(records, &sar) + + startRow += r.NumRows() } - return src, nil -} -var _ fetcher.FetchableItems[*sparkArrowBatch] = (*cloudURL)(nil) + if ipcReader.Err() != nil { + for i := range records { + records[i].Release() + } + return nil, ipcReader.Err() + } -type localBatch struct { - *cli_service.TSparkArrowBatch - startRow int64 - useLz4Compression bool + return records, nil +} + +type batchIterator struct { + nextBatchStart int64 + batchLoader BatchLoader } -var _ fetcher.FetchableItems[*sparkArrowBatch] = (*localBatch)(nil) +var _ BatchIterator = (*batchIterator)(nil) -func (lb *localBatch) Fetch(ctx context.Context) ([]*sparkArrowBatch, error) { - arrowBatchBytes, err := getArrowBatch(lb.useLz4Compression, lb.Batch) - if err != nil { - return nil, err - } - batch := &sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(lb.startRow, lb.RowCount), - arrowRecordBytes: arrowBatchBytes, +func (bi *batchIterator) Next() (SparkArrowBatch, dbsqlerr.DBError) { + if bi != nil && bi.batchLoader != nil { + batch, err := bi.batchLoader.GetBatchFor(bi.nextBatchStart) + if batch != nil && err == nil { + bi.nextBatchStart = batch.Start() + batch.Count() + } + return batch, err } + return nil, nil +} - return []*sparkArrowBatch{batch}, nil +func (bi *batchIterator) Close() { + if bi != nil && bi.batchLoader != nil { + bi.batchLoader.Close() + } } diff --git a/internal/rows/arrowbased/batchloader_test.go 
b/internal/rows/arrowbased/batchloader_test.go index fe412fc..35bad33 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -15,44 +15,14 @@ import ( "github.com/apache/arrow/go/v12/arrow/ipc" "github.com/apache/arrow/go/v12/arrow/memory" dbsqlerr "github.com/databricks/databricks-sql-go/errors" - "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) -func generateMockArrowBytes() []byte { - mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) - defer mem.AssertSize(nil, 0) - - fields := []arrow.Field{ - {Name: "id", Type: arrow.PrimitiveTypes.Int32}, - {Name: "name", Type: arrow.BinaryTypes.String}, - } - schema := arrow.NewSchema(fields, nil) - - builder := array.NewRecordBuilder(mem, schema) - defer builder.Release() - - builder.Field(0).(*array.Int32Builder).AppendValues([]int32{1, 2, 3}, nil) - builder.Field(1).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil) - - record := builder.NewRecord() - defer record.Release() - - var buf bytes.Buffer - w := ipc.NewWriter(&buf, ipc.WithSchema(record.Schema())) - if err := w.Write(record); err != nil { - return nil - } - if err := w.Close(); err != nil { - return nil - } - return buf.Bytes() -} +func TestCloudURLFetch(t *testing.T) { -func TestBatchLoader(t *testing.T) { var handler func(w http.ResponseWriter, r *http.Request) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) @@ -62,24 +32,24 @@ func TestBatchLoader(t *testing.T) { name string response func(w http.ResponseWriter, r *http.Request) linkExpired bool - expectedResponse []*sparkArrowBatch + expectedResponse SparkArrowBatch expectedErr error }{ { name: "cloud-fetch-happy-case", response: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes()) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) if err != nil { panic(err) } }, linkExpired: false, - expectedResponse: []*sparkArrowBatch{ - { - Delimiter: rowscanner.NewDelimiter(0, 3), - arrowRecordBytes: generateMockArrowBytes(), - hasSchema: true, + expectedResponse: &sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(0, 3), + arrowRecords: []SparkArrowRecord{ + &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(0, 3), Record: generateArrowRecord()}, + &sparkArrowRecord{Delimiter: rowscanner.NewDelimiter(3, 3), Record: generateArrowRecord()}, }, }, expectedErr: nil, @@ -88,7 +58,7 @@ func TestBatchLoader(t *testing.T) { name: "cloud-fetch-expired_link", response: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - _, err := w.Write(generateMockArrowBytes()) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) if err != nil { panic(err) } @@ -117,26 +87,80 @@ func TestBatchLoader(t *testing.T) { if tc.linkExpired { expiryTime = expiryTime.Add(-1 * time.Second) } else { - expiryTime = expiryTime.Add(1 * time.Second) + expiryTime = expiryTime.Add(10 * time.Second) } cu := &cloudURL{ - TSparkArrowResultLink: &cli_service.TSparkArrowResultLink{ - FileLink: server.URL, - ExpiryTime: expiryTime.Unix(), - }, + Delimiter: rowscanner.NewDelimiter(0, 3), + fileLink: server.URL, + expiryTime: expiryTime.Unix(), } ctx := context.Background() resp, err := cu.Fetch(ctx) - if 
!reflect.DeepEqual(resp, tc.expectedResponse) { - t.Errorf("expected (%v), got (%v)", tc.expectedResponse, resp) + if tc.expectedResponse != nil { + assert.NotNil(t, resp) + esab, ok := tc.expectedResponse.(*sparkArrowBatch) + assert.True(t, ok) + asab, ok2 := resp.(*sparkArrowBatch) + assert.True(t, ok2) + if !reflect.DeepEqual(esab.Delimiter, asab.Delimiter) { + t.Errorf("expected (%v), got (%v)", esab.Delimiter, asab.Delimiter) + } + assert.Equal(t, len(esab.arrowRecords), len(asab.arrowRecords)) + for i := range esab.arrowRecords { + er := esab.arrowRecords[i] + ar := asab.arrowRecords[i] + + eb := generateMockArrowBytes(er) + ab := generateMockArrowBytes(ar) + assert.Equal(t, eb, ab) + } } + if !errors.Is(err, tc.expectedErr) { assert.EqualErrorf(t, err, fmt.Sprintf("%v", tc.expectedErr), "expected (%v), got (%v)", tc.expectedErr, err) } }) } } + +func generateArrowRecord() arrow.Record { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + + fields := []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32}, + {Name: "name", Type: arrow.BinaryTypes.String}, + } + schema := arrow.NewSchema(fields, nil) + + builder := array.NewRecordBuilder(mem, schema) + defer builder.Release() + + builder.Field(0).(*array.Int32Builder).AppendValues([]int32{1, 2, 3}, nil) + builder.Field(1).(*array.StringBuilder).AppendValues([]string{"one", "two", "three"}, nil) + + record := builder.NewRecord() + + return record +} + +func generateMockArrowBytes(record arrow.Record) []byte { + + defer record.Release() + + var buf bytes.Buffer + w := ipc.NewWriter(&buf, ipc.WithSchema(record.Schema())) + if err := w.Write(record); err != nil { + return nil + } + if err := w.Write(record); err != nil { + return nil + } + if err := w.Close(); err != nil { + return nil + } + return buf.Bytes() +} diff --git a/internal/rows/arrowbased/chunkedByteReader.go b/internal/rows/arrowbased/chunkedByteReader.go deleted file mode 100644 index e896ee8..0000000 --- a/internal/rows/arrowbased/chunkedByteReader.go +++ /dev/null @@ -1,98 +0,0 @@ -package arrowbased - -import ( - "fmt" - "io" - - dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors" - "github.com/pkg/errors" -) - -var errChunkedByteReaderInvalidState1 = "databricks: chunkedByteReader invalid state chunkIndex:%d, byteIndex:%d" -var errChunkedByteReaderOverreadOfNonterminalChunk = "databricks: chunkedByteReader invalid state chunks:%d chunkIndex:%d len:%d byteIndex%d" - -// chunkedByteReader implements the io.Reader interface on a collection -// of byte arrays. -// The TSparkArrowBatch instances returned in TFetchResultsResp contain -// a byte array containing an ipc formatted MessageRecordBatch message. -// The ipc reader expects a bytestream containing a MessageSchema message -// followed by a MessageRecordBatch message. -// chunkedByteReader is used to avoid allocating new byte arrays for each -// TSparkRecordBatch and copying the schema and record bytes. -type chunkedByteReader struct { - // byte slices to be read as a single slice - chunks [][]byte - // index of the chunk being read from - chunkIndex int - // index in the current chunk - byteIndex int -} - -// Read reads up to len(p) bytes into p. It returns the number of bytes -// read (0 <= n <= len(p)) and any error encountered. -// -// When Read encounters an error or end-of-file condition after -// successfully reading n > 0 bytes, it returns the number of -// bytes read and a non-nil error. 
If len(p) is zero Read will -// return 0, nil -// -// Callers should always process the n > 0 bytes returned before -// considering the error err. Doing so correctly handles I/O errors -// that happen after reading some bytes. -func (c *chunkedByteReader) Read(p []byte) (bytesRead int, err error) { - err = c.isValid() - - for err == nil && bytesRead < len(p) && !c.isEOF() { - chunk := c.chunks[c.chunkIndex] - chunkLen := len(chunk) - source := chunk[c.byteIndex:] - - n := copy(p[bytesRead:], source) - bytesRead += n - - c.byteIndex += n - - if c.byteIndex >= chunkLen { - c.byteIndex = 0 - c.chunkIndex += 1 - } - } - - if err != nil { - err = dbsqlerr.WrapErr(err, "datbricks: read failure in chunked byte reader") - } else if c.isEOF() { - err = io.EOF - } - - return bytesRead, err -} - -// isEOF returns true if the chunkedByteReader is in -// an end-of-file condition. -func (c *chunkedByteReader) isEOF() bool { - return c.chunkIndex >= len(c.chunks) -} - -// reset returns the chunkedByteReader to its initial state -func (c *chunkedByteReader) reset() { - c.byteIndex = 0 - c.chunkIndex = 0 -} - -// verify that the chunkedByteReader is in a valid state -func (c *chunkedByteReader) isValid() error { - if c == nil { - return errors.New("call to Read on nil chunkedByteReader") - } - if c.byteIndex < 0 || c.chunkIndex < 0 { - return errors.New(fmt.Sprintf(errChunkedByteReaderInvalidState1, c.chunkIndex, c.byteIndex)) - } - - if c.chunkIndex < len(c.chunks)-1 { - chunkLen := len(c.chunks[c.chunkIndex]) - if 0 < chunkLen && c.byteIndex >= chunkLen { - return errors.New(fmt.Sprintf(errChunkedByteReaderOverreadOfNonterminalChunk, len(c.chunks), c.chunkIndex, len(c.chunks[c.chunkIndex]), c.byteIndex)) - } - } - return nil -} diff --git a/internal/rows/arrowbased/chunkedByteReader_test.go b/internal/rows/arrowbased/chunkedByteReader_test.go deleted file mode 100644 index 1cda921..0000000 --- a/internal/rows/arrowbased/chunkedByteReader_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package arrowbased - -import ( - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestChunkedByteReader(t *testing.T) { - c := chunkedByteReader{} - nbytes, err := c.Read(nil) - assert.Equal(t, io.EOF, err) - assert.Zero(t, nbytes) - - chunkSets := [][][]byte{ - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}, - {{1, 2, 3, 4, 5, 6}, {7, 8}, {9}, {10, 11, 12}}, - } - - for i := range chunkSets { - - c.chunks = chunkSets[i] - buf := make([]byte, 10) - testReadingChunks(t, c, buf, []int{10, 2}) - - c.reset() - buf = make([]byte, 3) - testReadingChunks(t, c, buf, []int{3, 3, 3, 3}) - - c.reset() - buf = make([]byte, 2) - testReadingChunks(t, c, buf, []int{2, 2, 2, 2, 2, 2}) - - c.reset() - buf = make([]byte, 20) - testReadingChunks(t, c, buf, []int{12}) - - c.reset() - buf = make([]byte, 1) - testReadingChunks(t, c, buf, []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - - c.reset() - buf = make([]byte, 5) - testReadingChunks(t, c, buf, []int{5, 5, 2}) - } -} - -func testReadingChunks(t *testing.T, c chunkedByteReader, target []byte, readSizes []int) { - - for i, expectedSize := range readSizes { - n, err := c.Read(target) - assert.Equal(t, expectedSize, n) - if i == len(readSizes)-1 { - assert.Equal(t, io.EOF, err) - } else { - assert.Nil(t, err) - } - } - -} diff --git a/internal/rows/arrowbased/columnValues.go b/internal/rows/arrowbased/columnValues.go index 34287ec..0b6fc7d 100644 --- a/internal/rows/arrowbased/columnValues.go +++ b/internal/rows/arrowbased/columnValues.go @@ -12,6 +12,71 @@ import ( 
"github.com/pkg/errors" ) +// Abstraction for holding the values for a set of rows +type RowValues interface { + rowscanner.Delimiter + Close() + NColumns() int + SetColumnValues(columnIndex int, values arrow.ArrayData) error + IsNull(columnIndex int, rowNumber int64) bool + Value(columnIndex int, rowNumber int64) (any, error) + SetDelimiter(d rowscanner.Delimiter) +} + +func NewRowValues(d rowscanner.Delimiter, holders []columnValues) RowValues { + return &rowValues{Delimiter: d, columnValueHolders: holders} +} + +type rowValues struct { + rowscanner.Delimiter + columnValueHolders []columnValues +} + +var _ RowValues = (*rowValues)(nil) + +func (rv *rowValues) Close() { + // release any retained arrow arrays + for i := range rv.columnValueHolders { + if rv.columnValueHolders[i] != nil { + rv.columnValueHolders[i].Release() + } + } +} + +func (rv *rowValues) SetColumnValues(columnIndex int, values arrow.ArrayData) error { + var err error + if columnIndex < len(rv.columnValueHolders) && rv.columnValueHolders[columnIndex] != nil { + rv.columnValueHolders[columnIndex].Release() + err = rv.columnValueHolders[columnIndex].SetValueArray(values) + } + return err +} + +func (rv *rowValues) IsNull(columnIndex int, rowNumber int64) bool { + var b bool = true + if columnIndex < len(rv.columnValueHolders) { + b = rv.columnValueHolders[columnIndex].IsNull(int(rowNumber - rv.Start())) + } + return b +} + +func (rv *rowValues) Value(columnIndex int, rowNumber int64) (any, error) { + var err error + var value any + if columnIndex < len(rv.columnValueHolders) { + value, err = rv.columnValueHolders[columnIndex].Value(int(rowNumber - rv.Start())) + } + return value, err +} + +func (rv *rowValues) NColumns() int { return len(rv.columnValueHolders) } + +func (rv *rowValues) SetDelimiter(d rowscanner.Delimiter) { rv.Delimiter = d } + +type valueContainerMaker interface { + makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error +} + // columnValues is the interface for accessing the values for a column type columnValues interface { Value(int) (any, error) diff --git a/internal/rows/arrowbased/errors.go b/internal/rows/arrowbased/errors.go index 9726f00..d3d69d3 100644 --- a/internal/rows/arrowbased/errors.go +++ b/internal/rows/arrowbased/errors.go @@ -22,8 +22,8 @@ func errArrowRowsUnsupportedNativeType(t string) string { func errArrowRowsUnsupportedWithHiveSchema(t string) string { return fmt.Sprintf("databricks: arrow native values for %s require arrow schema", t) } -func errArrowRowsInvalidRowIndex(index int64) string { - return fmt.Sprintf("databricks: row index %d is not contained in any arrow batch", index) +func errArrowRowsInvalidRowNumber(index int64) string { + return fmt.Sprintf("databricks: row number %d is not contained in any arrow batch", index) } func errArrowRowsUnableToCreateDecimalType(scale, precision int32) string { return fmt.Sprintf("databricks: unable to create decimal type scale: %d, precision: %d", scale, precision) diff --git a/internal/rows/columnbased/columnRows.go b/internal/rows/columnbased/columnRows.go index 21ff7e4..6d9dbdf 100644 --- a/internal/rows/columnbased/columnRows.go +++ b/internal/rows/columnbased/columnRows.go @@ -75,8 +75,9 @@ func (crs *columnRowScanner) NRows() int64 { // a buffer held in dest. 
func (crs *columnRowScanner) ScanRow( dest []driver.Value, - rowIndex int64) dbsqlerr.DBError { + rowNumber int64) dbsqlerr.DBError { + rowIndex := rowNumber - crs.Start() // populate the destinatino slice for i := range dest { val, err := crs.value(crs.rowSet.Columns[i], crs.schema.Columns[i], rowIndex) diff --git a/internal/rows/errors.go b/internal/rows/errors.go index 4e84d8f..19d2782 100644 --- a/internal/rows/errors.go +++ b/internal/rows/errors.go @@ -7,6 +7,8 @@ var errRowsNilRows = "databricks: nil Rows instance" var errRowsUnknowRowType = "databricks: unknown rows representation" var errRowsCloseFailed = "databricks: Rows instance Close operation failed" var errRowsMetadataFetchFailed = "databricks: Rows instance failed to retrieve result set metadata" +var errRowsOnlyForward = "databricks: Rows instance can only iterate forward over rows" +var errInvalidRowNumberState = "databricks: row number is in an invalid state" func errRowsInvalidColumnIndex(index int) string { return fmt.Sprintf("databricks: invalid column index: %d", index) diff --git a/internal/rows/rows.go b/internal/rows/rows.go index 797b7a0..3555f9a 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" + "io" "math" "reflect" "time" @@ -238,7 +238,7 @@ func (r *rows) Next(dest []driver.Value) error { } // Put values into the destination slice - err = r.ScanRow(dest, r.nextRowIndex()) + err = r.ScanRow(dest, r.nextRowNumber) if err != nil { return err } @@ -248,8 +248,6 @@ func (r *rows) Next(dest []driver.Value) error { return nil } -func (r *rows) nextRowIndex() int64 { return r.nextRowNumber - r.RowScanner.Start() } - // ColumnTypeScanType returns column's native type func (r *rows) ColumnTypeScanType(index int) reflect.Type { err := isValidRows(r) @@ -453,12 +451,28 @@ func (r *rows) fetchResultPage() error { } if r.RowScanner != nil && r.nextRowNumber < r.RowScanner.Start() { - //TODO - return errors.New("can't go backward") + return dbsqlerr_int.NewDriverError(r.ctx, errRowsOnlyForward, nil) } if r.ResultPageIterator == nil { - r.ResultPageIterator = makeResultPageIterator(r) + var d rowscanner.Delimiter + if r.RowScanner != nil { + d = rowscanner.NewDelimiter(r.RowScanner.Start(), r.RowScanner.Count()) + } else { + d = rowscanner.NewDelimiter(0, 0) + } + r.ResultPageIterator = makeResultPageIterator(r, d) + } + + // Close/release the existing row scanner before loading the next result page to + // help keep memory usage down. 
+ if r.RowScanner != nil { + r.RowScanner.Close() + r.RowScanner = nil + } + + if !r.ResultPageIterator.HasNext() { + return io.EOF } fetchResult, err1 := r.ResultPageIterator.Next() @@ -471,9 +485,10 @@ func (r *rows) fetchResultPage() error { return err1 } + // We should be iterating over the rows so the next row number should be in the + // next result page if !r.RowScanner.Contains(r.nextRowNumber) { - // TODO - return errors.New("Invalid row number state") + return dbsqlerr_int.NewDriverError(r.ctx, errInvalidRowNumberState, nil) } return nil @@ -512,7 +527,12 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql err = dbsqlerr_int.NewDriverError(r.ctx, errRowsUnknowRowType, nil) } + if r.RowScanner != nil { + r.RowScanner.Close() + } + r.RowScanner = rs + if fetchResults.HasMoreRows != nil { r.hasMoreRows = *fetchResults.HasMoreRows } else { @@ -533,14 +553,7 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger { return r.logger_ } -func makeResultPageIterator(r *rows) rowscanner.ResultPageIterator { - - var d rowscanner.Delimiter - if r.RowScanner != nil { - d = rowscanner.NewDelimiter(r.RowScanner.Start(), r.RowScanner.Count()) - } else { - d = rowscanner.NewDelimiter(0, 0) - } +func makeResultPageIterator(r *rows, d rowscanner.Delimiter) rowscanner.ResultPageIterator { resultPageIterator := rowscanner.NewResultPageIterator( d, diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index fc27506..4ec1604 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -220,7 +220,6 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 1, - nextRowIndex: i64Zero, nextRowNumber: i64Zero, offset: i64Zero, }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -231,7 +230,6 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 1, - nextRowIndex: int64(4), nextRowNumber: int64(4), offset: i64Zero, }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -242,7 +240,6 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 2, - nextRowIndex: int64(1), nextRowNumber: int64(6), offset: int64(5), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -250,20 +247,18 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { // next row number is two, can't fetch previous result page rowSet.nextRowNumber = 2 err = rowSet.fetchResultPage() - assert.EqualError(t, err, "can't go backward") + assert.ErrorContains(t, err, errRowsOnlyForward) // next row number is past end of next result page rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() - assert.EqualError(t, err, "Invalid row number state") + assert.ErrorContains(t, err, errInvalidRowNumberState) rowSet.nextRowNumber = 12 err = rowSet.fetchResultPage() - rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 3, - nextRowIndex: int64(2), nextRowNumber: int64(12), offset: int64(10), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -273,22 +268,15 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { errMsg := io.EOF.Error() assert.EqualError(t, err, errMsg) - // next row number is before start of results, should fetch all result pages - // going forward and then return EOF + // Once we've hit an EOF state any call to fetchResultPage will return EOF rowSet.nextRowNumber = -1 err = 
rowSet.fetchResultPage() - assert.EqualError(t, err, "can't go backward") + assert.EqualError(t, err, io.EOF.Error()) // jump back to last page rowSet.nextRowNumber = 13 err = rowSet.fetchResultPage() - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 3, - nextRowIndex: int64(3), - nextRowNumber: int64(13), - offset: int64(10), - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) + assert.EqualError(t, err, io.EOF.Error()) } func TestRowsFetchResultPageWithDirectResults(t *testing.T) { @@ -315,7 +303,6 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 1, - nextRowIndex: i64Zero, nextRowNumber: i64Zero, offset: i64Zero, }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -326,7 +313,6 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 1, - nextRowIndex: int64(4), nextRowNumber: int64(4), offset: i64Zero, }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -337,28 +323,26 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 2, - nextRowIndex: int64(1), nextRowNumber: int64(6), offset: int64(5), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) - // next row number is two, should fetch previous result page + // next row number is two, can't fetch previous result page rowSet.nextRowNumber = 2 err = rowSet.fetchResultPage() - assert.EqualError(t, err, "can't go backward") + assert.ErrorContains(t, err, errRowsOnlyForward) // next row number is past end of results, should fetch all result pages // going forward and then return EOF rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() - assert.EqualError(t, err, "Invalid row number state") + assert.ErrorContains(t, err, errInvalidRowNumberState) rowSet.nextRowNumber = 10 err = rowSet.fetchResultPage() rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 3, - nextRowIndex: int64(0), nextRowNumber: int64(10), offset: int64(10), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) @@ -366,19 +350,13 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() errMsg := io.EOF.Error() - assert.EqualError(t, err, errMsg) + assert.ErrorContains(t, err, errMsg) // jump back to last page rowSet.nextRowNumber = 12 err = rowSet.fetchResultPage() + assert.ErrorContains(t, err, errMsg) - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 3, - nextRowIndex: int64(2), - nextRowNumber: int64(12), - offset: int64(10), - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) } var rowTestColNames []string = []string{ @@ -481,7 +459,6 @@ func TestNextNoDirectResults(t *testing.T) { assert.Nil(t, err) assert.Equal(t, row0, row) assert.Equal(t, int64(1), rowSet.nextRowNumber) - assert.Equal(t, int64(1), rowSet.nextRowIndex()) assert.Equal(t, 1, getMetadataCount) assert.Equal(t, 1, fetchResultsCount) } @@ -537,7 +514,6 @@ func TestNextWithDirectResults(t *testing.T) { assert.Nil(t, err) assert.Equal(t, row0, row) assert.Equal(t, int64(1), rowSet.nextRowNumber) - assert.Equal(t, int64(1), rowSet.nextRowIndex()) assert.Equal(t, 2, getMetadataCount) assert.Equal(t, 1, fetchResultsCount) } @@ -760,7 +736,6 @@ func TestFetchResultsWithRetries(t *testing.T) { type rowTestPagingResult struct { getMetadataCount int fetchResultsCount int - 
nextRowIndex int64 nextRowNumber int64 offset int64 errMessage *string @@ -769,7 +744,6 @@ type rowTestPagingResult struct { func (rt rowTestPagingResult) validatePaging(t *testing.T, rowSet *rows, err error, fetchResultsCount, getMetadataCount int) { assert.Equal(t, rt.fetchResultsCount, fetchResultsCount) assert.Equal(t, rt.getMetadataCount, getMetadataCount) - assert.Equal(t, rt.nextRowIndex, rowSet.nextRowIndex()) assert.Equal(t, rt.nextRowNumber, rowSet.nextRowNumber) assert.Equal(t, rt.offset, rowSet.RowScanner.Start()) if rt.errMessage == nil { diff --git a/internal/rows/rowscanner/rowScanner.go b/internal/rows/rowscanner/rowScanner.go index 0e58b5b..4886a9e 100644 --- a/internal/rows/rowscanner/rowScanner.go +++ b/internal/rows/rowscanner/rowScanner.go @@ -20,7 +20,7 @@ type RowScanner interface { // The dest should not be written to outside of ScanRow. Care // should be taken when closing a RowScanner not to modify // a buffer held in dest. - ScanRow(dest []driver.Value, rowIndex int64) dbsqlerr.DBError + ScanRow(dest []driver.Value, rowNumber int64) dbsqlerr.DBError // NRows returns the number of rows in the current result page NRows() int64
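
Reviewer notes (illustrative sketches only, not part of the diff):

Row numbers are now absolute: ScanRow receives the absolute row number and each scanner derives the position inside its current page from its own Delimiter (rowIndex = rowNumber - Start()), as columnRowScanner.ScanRow does above, so callers no longer track a per-batch index. A minimal standalone sketch of that convention follows; the delimiter and examplePage types are local stand-ins invented for the sketch, not types from this patch.

package main

import "fmt"

// delimiter mirrors the Start/Count/Contains behaviour of the
// rowscanner.Delimiter values used in this patch; it is a local stand-in
// so the sketch compiles on its own.
type delimiter struct {
	start, count int64
}

func (d delimiter) Start() int64 { return d.start }
func (d delimiter) Count() int64 { return d.count }
func (d delimiter) Contains(rowNumber int64) bool {
	return rowNumber >= d.start && rowNumber < d.start+d.count
}

// examplePage is a hypothetical page of rows starting at absolute row 5.
type examplePage struct {
	delimiter
	values []string // page-relative storage: values[0] holds absolute row 5
}

// scanRow takes an absolute row number and converts it to a page-relative
// index before looking up the value.
func (p *examplePage) scanRow(rowNumber int64) (string, error) {
	if !p.Contains(rowNumber) {
		return "", fmt.Errorf("row number %d is not contained in this page", rowNumber)
	}
	rowIndex := rowNumber - p.Start()
	return p.values[rowIndex], nil
}

func main() {
	page := &examplePage{
		delimiter: delimiter{start: 5, count: 3},
		values:    []string{"five", "six", "seven"},
	}
	v, _ := page.scanRow(6) // absolute row 6 -> relative index 1
	fmt.Println(v)          // six
}

With chunkedByteReader deleted, its job of presenting separate schema and record-batch byte slices to the arrow ipc reader as one stream can be done with io.MultiReader from the standard library. The sketch below illustrates that approach; readRecords and its parameter names are invented for the sketch rather than taken from this patch.

package arrowexample

import (
	"bytes"
	"io"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

// readRecords stitches a schema message and one or more record-batch
// messages into a single IPC stream via io.MultiReader, without copying
// the byte slices into a new buffer, and deserializes the records.
func readRecords(arrowSchemaBytes []byte, batchBytes [][]byte) ([]arrow.Record, error) {
	readers := make([]io.Reader, 0, len(batchBytes)+1)
	readers = append(readers, bytes.NewReader(arrowSchemaBytes))
	for _, b := range batchBytes {
		readers = append(readers, bytes.NewReader(b))
	}

	rdr, err := ipc.NewReader(io.MultiReader(readers...))
	if err != nil {
		return nil, err
	}
	defer rdr.Release()

	var records []arrow.Record
	for rdr.Next() {
		rec := rdr.Record()
		rec.Retain() // the record is only valid until the next call to Next()
		records = append(records, rec)
	}
	return records, rdr.Err()
}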