diff --git a/sqle/driver/mysql/session/context.go b/sqle/driver/mysql/session/context.go index fea33213b6..4e4df18a3f 100644 --- a/sqle/driver/mysql/session/context.go +++ b/sqle/driver/mysql/session/context.go @@ -713,46 +713,30 @@ func (c *Context) getSelectivityByIndex(indexes []index) (map[string] /*index na return indexSelectivityMap, nil } -func (c *Context) getCachedSelectivity(schema, table, name string) (float64, bool) { - if !c.schemaHasLoad { - // context does not load schema - return -1, false - } - if c.schemas[schema] == nil { - // schema not exist - return -1, false - } - if c.schemas[schema].Tables[table] == nil { - // table not exist +func (c *Context) getSelectivity(schema, table, name string) (float64, bool) { + tableInfo, exist := c.getTable(schema, table) + if !exist { return -1, false } - if c.schemas[schema].Tables[table].Selectivity == nil { + if tableInfo.Selectivity == nil { // selectivity not cached return -1, false } - if selectivity, ok := c.schemas[schema].Tables[table].Selectivity[name]; ok { + if selectivity, ok := tableInfo.Selectivity[name]; ok { return selectivity, true } return -1, false } -func (c *Context) cacheSelectivity(schema, table, name string, selectivity float64) { - if !c.schemaHasLoad { - // context does not load schema - return - } - if c.schemas[schema] == nil { - // schema not exist - return - } - if c.schemas[schema].Tables[table] == nil { - // table not exist +func (c *Context) addSelectivity(schema, table, name string, selectivity float64) { + tableInfo, exist := c.getTable(schema, table) + if !exist { return } - if c.schemas[schema].Tables[table].Selectivity == nil { - c.schemas[schema].Tables[table].Selectivity = make(map[string]float64) + if tableInfo.Selectivity == nil { + tableInfo.Selectivity = make(map[string]float64) } - c.schemas[schema].Tables[table].Selectivity[name] = selectivity + tableInfo.Selectivity[name] = selectivity } func (c *Context) GetSelectivityOfIndex(stmt *ast.TableName, indexNames []string) (map[string]float64, error) { @@ -764,7 +748,7 @@ func (c *Context) GetSelectivityOfIndex(stmt *ast.TableName, indexNames []string cachedIndexSelectivity := make(map[string]float64) indexes := make([]index, 0, len(indexNames)) for _, indexName := range indexNames { - if selectivity, ok := c.getCachedSelectivity(schemaName, tableName, indexName); ok { + if selectivity, ok := c.getSelectivity(schemaName, tableName, indexName); ok { cachedIndexSelectivity[indexName] = selectivity } else { indexes = append(indexes, index{ @@ -780,7 +764,7 @@ func (c *Context) GetSelectivityOfIndex(stmt *ast.TableName, indexNames []string } for indexName, selectivity := range indexSelectivity { - c.cacheSelectivity(schemaName, tableName, indexName, selectivity) + c.addSelectivity(schemaName, tableName, indexName, selectivity) } for indexName, selectivity := range cachedIndexSelectivity { indexSelectivity[indexName] = selectivity @@ -847,7 +831,7 @@ func (c *Context) GetSelectivityOfColumns(stmt *ast.TableName, indexColumns []st cachedIndexSelectivity := make(map[string]float64) columns := make([]column, 0, len(indexColumns)) for _, columnName := range indexColumns { - if selectivity, ok := c.getCachedSelectivity(schemaName, tableName, columnName); ok { + if selectivity, ok := c.getSelectivity(schemaName, tableName, columnName); ok { cachedIndexSelectivity[columnName] = selectivity } else { columns = append(columns, column{ @@ -862,7 +846,7 @@ func (c *Context) GetSelectivityOfColumns(stmt *ast.TableName, indexColumns []st return nil, fmt.Errorf("get selectivity by index error: %v", err) } for indexName, selectivity := range columnSelectivity { - c.cacheSelectivity(schemaName, tableName, indexName, selectivity) + c.addSelectivity(schemaName, tableName, indexName, selectivity) } for indexName, selectivity := range cachedIndexSelectivity { columnSelectivity[indexName] = selectivity