diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index b301722667f..11cccf295c1 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -336,53 +336,76 @@ func (ins *Insert) processGenerate(vcursor VCursor, bindVars map[string]*querypb // If generation is needed, generate the requested number of values (as one call). if count != 0 { - rss, _, err := vcursor.ResolveDestinations(ins.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) - if err != nil { - return 0, err - } - // TODO: place where to decide routing maybe for snowflake - if len(rss) != 1 { - return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto sequence generation can happen through single shard only, it is getting routed to %d shards", len(rss)) - } - bindVars := map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(count)} - qr, err := vcursor.ExecuteStandalone(ins.Generate.Query, bindVars, rss[0]) - if err != nil { - return 0, err - } - // If no rows are returned, it's an internal error, and the code - // must panic, which will be caught and reported. - insertID, err = evalengine.ToInt64(qr.Rows[0][0]) - if err != nil { - return 0, err + if ins.Generate.Type == vindexes.TypeSnowflake { + // We will pick any shard + rss, _, err := vcursor.ResolveDestinations(ins.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) + if err != nil { + return 0, err + } + if len(rss) != 1 { + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto snowflake generation can happen with at least one shard this keyspace: %s", ins.Generate.Keyspace.Name) + } + bindVars := map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(count)} + qr, err := vcursor.ExecuteStandalone(ins.Generate.Query, bindVars, rss[0]) + if err != nil { + return 0, err + } + // If no rows are returned, it's an internal error, and the code + // must panic, which will be caught and reported. + insertID, err = evalengine.ToInt64(qr.Rows[0][0]) + if err != nil { + return 0, err + } + } else { + rss, _, err := vcursor.ResolveDestinations(ins.Generate.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}}) + if err != nil { + return 0, err + } + if len(rss) != 1 { + return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "auto sequence generation can happen through single shard only, it is getting routed to %d shards", len(rss)) + } + bindVars := map[string]*querypb.BindVariable{"n": sqltypes.Int64BindVariable(count)} + qr, err := vcursor.ExecuteStandalone(ins.Generate.Query, bindVars, rss[0]) + if err != nil { + return 0, err + } + // If no rows are returned, it's an internal error, and the code + // must panic, which will be caught and reported. + insertID, err = evalengine.ToInt64(qr.Rows[0][0]) + if err != nil { + return 0, err + } } } - // Fill the holes where no value was supplied. - // For Snowflake + // Fill the holes where no value was supplied depending on the type of sequence: snowflake or sequence + cur := insertID + // for Snowflake if ins.Generate.Type == vindexes.TypeSnowflake { - cur := insertID + var totalInc int64 ts := (cur >> int64(SequenceLength+MachineIDLength)) + SnowflakeStartTime.UTC().UnixNano()/1e6 sequence := cur & int64(MaxSequence) machineID := (cur & (int64(MaxMachineID) << SequenceLength)) >> SequenceLength for i, v := range resolved { - fmt.Println(fmt.Sprintf("Generating Snowflake, %s id %d", ins.GetTableName(), cur)) if shouldGenerate(v) { bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.Int64BindVariable(cur) // calculate next id and advance ts and sequence - totalInc := sequence + 1 - ts := ts + totalInc/MaxSequence + totalInc = sequence + 1 + ts = ts + totalInc/MaxSequence sequence = totalInc % MaxSequence // TODO: generate next id properly for snowflake df := elapsedTime(ts, SnowflakeStartTime) - cur = int64((uint64(df) << uint64(timestampMoveLength)) | (uint64(machineID) << uint64(machineIDMoveLength)) | uint64(sequence)) + cur = (df << timestampMoveLength) | (machineID << machineIDMoveLength) | sequence + fmt.Println( + ((df << timestampMoveLength) | (machineID << machineIDMoveLength) | sequence), + ) } else { bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.ValueBindVariable(v) } } return insertID, nil } - // For Sequence - cur := insertID + // for Sequence for i, v := range resolved { if shouldGenerate(v) { bindVars[SeqVarName+strconv.Itoa(i)] = sqltypes.Int64BindVariable(cur) diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 0a5c4a7fb01..8c40f0ae291 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -29,6 +29,116 @@ import ( vschemapb "vitess.io/vitess/go/vt/proto/vschema" ) +func TestInsertShardedSnoflakeGenerate(t *testing.T) { + invschema := &vschemapb.SrvVSchema{ + Keyspaces: map[string]*vschemapb.Keyspace{ + "sharded": { + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{ + "hash": { + Type: "hash", + }, + }, + Tables: map[string]*vschemapb.Table{ + "t1": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Name: "hash", + Columns: []string{"id"}, + }}, + }, + }, + }, + }, + } + vs := vindexes.BuildVSchema(invschema) + ks := vs.Keyspaces["sharded"] + + ins := NewInsert( + InsertSharded, + ks.Keyspace, + []sqltypes.PlanValue{{ + // colVindex columns: id + Values: []sqltypes.PlanValue{{ + // 5 rows. + Values: []sqltypes.PlanValue{{ + Value: sqltypes.NewInt64(1), + }, { + Value: sqltypes.NewInt64(2), + }, { + Value: sqltypes.NewInt64(3), + }, { + Value: sqltypes.NewInt64(4), + }, { + Value: sqltypes.NewInt64(5), + }}, + }}, + }}, + ks.Tables["t1"], + "prefix", + []string{" mid1", " mid2", " mid3", " mid4", " mid5"}, + " suffix", + ) + + ins.Generate = &Generate{ + Keyspace: &vindexes.Keyspace{ + Name: "ks2", + Sharded: true, + }, + Type: "snowflake", + Query: "dummy_snowflake_generate", + Values: sqltypes.PlanValue{ + Values: []sqltypes.PlanValue{ + {Value: sqltypes.NewInt64(1)}, + {Value: sqltypes.NULL}, + {Value: sqltypes.NewInt64(2)}, + {Value: sqltypes.NULL}, + {Value: sqltypes.NULL}, + }, + }, + } + + vc := newDMLTestVCursor("-20", "20-") + vc.shardForKsid = []string{"20-", "-20", "20-", "-20", "-20"} + + // Snowflake ids + // | 2124054243676528637 | 1732771994712 | 2024-11-28 05:33:14.7120 | 4093 | 1 | + // | 2124054243676528638 | 1732771994712 | 2024-11-28 05:33:14.7120 | 4094 | 1 | + // | 2124054243680718848 | 1732771994713 | 2024-11-28 05:33:14.7130 | 0 | 1 | + vc.results = []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "nextval", + "int64", + ), + "2124054243676528637", + ), + {InsertID: 1}, + } + + result, err := ins.Execute(vc, map[string]*querypb.BindVariable{}, false) + if err != nil { + t.Fatal(err) + } + vc.ExpectLog(t, []string{ + `ResolveDestinations ks2 [] Destinations:DestinationAnyShard()`, + `ExecuteStandalone dummy_snowflake_generate n: type:INT64 value:"3" ks2 -20`, + // Based on shardForKsid, values returned will be 20-, -20, 20-. + `ResolveDestinations sharded [value:"0" value:"1" value:"2" value:"3" value:"4"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(4eb190c9a2fa169c),DestinationKeyspaceID(d2fd8867d50d2dfe),DestinationKeyspaceID(70bb023c810ca87a)`, + // Row 2 will go to -20, rows 1 & 3 will go to 20- + `ExecuteMultiShard ` + + `sharded.20-: prefix mid1, mid3 suffix ` + + `{__seq0: type:INT64 value:"1" __seq1: type:INT64 value:"2124054243676528637" __seq2: type:INT64 value:"2" __seq3: type:INT64 value:"2124054243676528638" __seq4: type:INT64 value:"2124054243680718848" ` + + `_id_0: type:INT64 value:"1" _id_1: type:INT64 value:"2" _id_2: type:INT64 value:"3" _id_3: type:INT64 value:"4" _id_4: type:INT64 value:"5"} ` + + `sharded.-20: prefix mid2, mid4, mid5 suffix ` + + `{__seq0: type:INT64 value:"1" __seq1: type:INT64 value:"2124054243676528637" __seq2: type:INT64 value:"2" __seq3: type:INT64 value:"2124054243676528638" __seq4: type:INT64 value:"2124054243680718848" ` + + `_id_0: type:INT64 value:"1" _id_1: type:INT64 value:"2" _id_2: type:INT64 value:"3" _id_3: type:INT64 value:"4" _id_4: type:INT64 value:"5"} ` + + `true false`, + }) + + // The insert id returned by ExecuteMultiShard should be overwritten by processGenerate. + expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 2124054243676528637}) +} + func TestInsertUnshardedSnowflakeGenerate(t *testing.T) { ins := NewQueryInsert( InsertUnsharded, diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index f95e3376b09..88b4f4388fb 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -255,9 +255,7 @@ func buildTables(ks *vschemapb.Keyspace, vschema *VSchema, ksvschema *KeyspaceSc } t.Type = table.Type case TypeSnowflake: - if keyspace.Sharded && table.Pinned == "" { - return fmt.Errorf("snowflake table has to be in an unsharded keyspace or must be pinned: %s", tname) - } + // Snowflake should be ok to use with multiple shards t.Type = table.Type default: return fmt.Errorf("unidentified table type %s", table.Type) diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 58b7af51cb1..ce79ea23204 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -537,7 +537,6 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { return nil, err } tableName := qre.plan.TableName() - // check if snowflake if inc < 1 { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid increment for sequence or snowflake %s: %d", tableName, inc) @@ -599,7 +598,6 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { } ret = t.SequenceInfo.NextVal t.SequenceInfo.NextVal += inc - } else if qre.plan.Table.SnowflakeInfo != nil { t := qre.plan.Table t.SnowflakeInfo.Lock() @@ -628,12 +626,11 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { } } - // Generate new id here, return it and update last val with overflow - nextID, err := t.SnowflakeInfo.NextNID(inc, currentMillis()) + // generate first Snowflake id from requested range + ret, err = t.SnowflakeInfo.NextNID(inc, currentMillis()) if err != nil { return nil, vterrors.Wrapf(err, "error generating snowflake with NextNID(%d) %s", inc, tableName) } - ret = int64(nextID) } return &sqltypes.Result{ Fields: sequenceFields, diff --git a/go/vt/vttablet/tabletserver/schema/engine.go b/go/vt/vttablet/tabletserver/schema/engine.go index fd407b54e5e..71b7e46ae84 100644 --- a/go/vt/vttablet/tabletserver/schema/engine.go +++ b/go/vt/vttablet/tabletserver/schema/engine.go @@ -253,13 +253,7 @@ func (se *Engine) MakeNonMaster() { t.SequenceInfo.LastVal = 0 t.SequenceInfo.Unlock() } - if t.SnowflakeInfo != nil { - t.SnowflakeInfo.Lock() - // We don't care about this, since each tablet has its own machine ID. - // t.SnowflakeInfo.NextVal = 0 - // t.SnowflakeInfo.LastVal = 0 - t.SnowflakeInfo.Unlock() - } + // We don't care about Snowflake, since each tablet has its own machine ID. } } diff --git a/go/vt/vttablet/tabletserver/schema/load_table.go b/go/vt/vttablet/tabletserver/schema/load_table.go index 507893d64c6..b2abe1aebb5 100644 --- a/go/vt/vttablet/tabletserver/schema/load_table.go +++ b/go/vt/vttablet/tabletserver/schema/load_table.go @@ -34,7 +34,6 @@ func LoadTable(conn *connpool.DBConn, tableName string, comment string) (*Table, if err := fetchColumns(ta, conn, sqlTableName); err != nil { return nil, err } - fmt.Println("fff comment", comment, "tableName", tableName) switch { case strings.Contains(comment, "vitess_sequence"): ta.Type = Sequence @@ -42,7 +41,6 @@ func LoadTable(conn *connpool.DBConn, tableName string, comment string) (*Table, case strings.Contains(comment, "vitess_snowflake"): ta.Type = Snowflake ta.SnowflakeInfo = &SnowflakeInfo{} - fmt.Println("loaded snowflake table: ", tableName) case strings.Contains(comment, "vitess_message"): if err := loadMessageInfo(ta, comment); err != nil { return nil, err diff --git a/go/vt/vttablet/tabletserver/schema/schema.go b/go/vt/vttablet/tabletserver/schema/schema.go index a5d1a597e67..635333ebb20 100644 --- a/go/vt/vttablet/tabletserver/schema/schema.go +++ b/go/vt/vttablet/tabletserver/schema/schema.go @@ -89,13 +89,12 @@ const ( ) var ( - // default starttime + // default Snowflake start time SnowflakeStartTime = time.Date(2008, 11, 10, 23, 0, 0, 0, time.UTC) ) // SnowflakeInfo contains info specific to sequence tabels. // It must be locked before accessing the values inside. -// If CurVal==LastVal, we have to cache new values. // When the schema is first loaded, the values are all 0, // which will trigger caching on first use. type SnowflakeInfo struct { @@ -113,15 +112,20 @@ func elapsedTime(noms int64, s time.Time) int64 { } func (s *SnowflakeInfo) NextNID(inc int64, currentTimestamp int64) (int64, error) { - fmt.Println("----------") // need to pass timestamo in order to make it more testable // currentTimestamp := currentMillis() var firstSequence, firstTimestamp int64 if s.LastTimestamp < currentTimestamp { + // calculate timestamp and sequence for first id firstTimestamp = currentTimestamp firstSequence = 0 - s.LastTimestamp = currentTimestamp - s.Sequence = 0 + // // calculate timestamp and sequence for last id + // s.LastTimestamp = currentTimestamp + // s.Sequence = 0 + // calculate timestamp and sequence for last id + lastInc := inc - 1 + s.LastTimestamp = currentTimestamp + lastInc/MaxSequence // add overflow to timestamp as ms + s.Sequence = lastInc % MaxSequence // set last sequence } else { if s.LastTimestamp > currentTimestamp { fmt.Println("current timestamp is less than last timestamp, so we are overflowing again") @@ -129,21 +133,22 @@ func (s *SnowflakeInfo) NextNID(inc int64, currentTimestamp int64) (int64, error } else { fmt.Println("Same timestamp", currentTimestamp) } - // calculate first id values + // calculate timestamp and sequence for first id firstInc := s.Sequence + 1 firstTimestamp = currentTimestamp + firstInc/MaxSequence // add overflow to timestamp as ms firstSequence = firstInc % MaxSequence // set first sequence - // calculate last id values + // calculate timestamp and sequence for last id lastInc := s.Sequence + inc s.LastTimestamp = currentTimestamp + lastInc/MaxSequence // add overflow to timestamp as ms s.Sequence = lastInc % MaxSequence // set last sequence } + fmt.Println("firstSequence", firstSequence, "firstTimestamp", firstTimestamp) fmt.Println("lastSequence", s.Sequence, "lastTimestamp", s.LastTimestamp) firstDF := elapsedTime(firstTimestamp, SnowflakeStartTime) - firstId := (uint64(firstDF) << uint64(timestampMoveLength)) | (uint64(s.MachineID) << uint64(machineIDMoveLength)) | uint64(firstSequence) - return int64(firstId), nil + firstId := (firstDF << timestampMoveLength) | (s.MachineID << machineIDMoveLength) | firstSequence + return firstId, nil } // SetMachineID specify the machine ID. It will panic when machined > max limit for 2^10-1. @@ -156,20 +161,6 @@ func (s *SnowflakeInfo) SetMachineID(m int64) error { return nil } -// // ParseID parse snowflake it to SID struct. -// func ParseSnowflakeID(id uint64) SnowflakeID { -// t := id >> uint64(SequenceLength+MachineIDLength) -// sequence := id & uint64(MaxSequence) -// mID := (id & (uint64(MaxMachineID) << SequenceLength)) >> SequenceLength - -// return SnowflakeID{ -// ID: id, -// Sequence: sequence, -// MachineID: mID, -// Timestamp: t, -// } -// } - // MessageInfo contains info specific to message tables. type MessageInfo struct { // Fields stores the field info to be diff --git a/go/vt/vttablet/tabletserver/schema/schema_test.go b/go/vt/vttablet/tabletserver/schema/schema_test.go index 3a7d48178e1..19ec06829d1 100644 --- a/go/vt/vttablet/tabletserver/schema/schema_test.go +++ b/go/vt/vttablet/tabletserver/schema/schema_test.go @@ -13,12 +13,11 @@ func compareSnowflake(t *testing.T, id, wantTimestamp int64, wantSequence int64, gotMachineID := (id & (int64(MaxMachineID) << SequenceLength)) >> SequenceLength fmt.Println("got ", gotTimestamp, gotSequence, gotMachineID) assert.Equal(t, wantSequence, gotSequence) - // assert.Equal(t, wantMachineID, gotMachineID) - // this is a flaky test assert.Equal(t, wantTimestamp, gotTimestamp) + assert.Equal(t, wantMachineID, gotMachineID) } -func TestNextNID(t *testing.T) { +func TestNextNIDSameTimestamp(t *testing.T) { snow := &SnowflakeInfo{} snow.SetMachineID(1) ts := int64(1732711077200) @@ -27,7 +26,6 @@ func TestNextNID(t *testing.T) { if err != nil { t.Fatalf("qre.Execute() = %v, want nil", err) } - // assert.Equal(t, gotId, snow.LastVal) compareSnowflake(t, gotId, ts, 0, 1) // test multiple values within same ms (flaky) @@ -50,6 +48,34 @@ func TestNextNID(t *testing.T) { t.Fatalf("qre.Execute() = %v, want nil", err) } compareSnowflake(t, gotId, ts+1, 910, 1) +} + +func TestNextIDOne(t *testing.T) { + snow := &SnowflakeInfo{} + snow.SetMachineID(1) + ts := int64(1732711077200) + + gotId, err := snow.NextNID(1, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + // last sequence should be 0 + assert.Equal(t, int64(0), snow.Sequence) + compareSnowflake(t, gotId, ts, 0, 1) + +} + +func TestNextIDTwo(t *testing.T) { + snow := &SnowflakeInfo{} + snow.SetMachineID(1) + ts := int64(1732711077200) + + gotId, err := snow.NextNID(2, ts) + if err != nil { + t.Fatalf("qre.Execute() = %v, want nil", err) + } + // last sequence should be 1 + assert.Equal(t, int64(1), snow.Sequence) + compareSnowflake(t, gotId, ts, 0, 1) - assert.Equal(t, 1, 2) }