Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

提交pg库表元数据功能代码 #2086

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 85 additions & 87 deletions sqle/pkg/postgresql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,110 +104,108 @@ func (o *DB) ShowSchemaViews(schema string) ([]string, error) {
return getResultSqls(o.Db, query)
}

func (o *DB) ShowCreateTables(database, tableName string, schemas []string) ([]string, error) {
func (o *DB) ShowCreateTables(database, schema, tableName string) ([]string, error) {
tables := make([]string, 0)
for _, schema := range schemas {
tableDDl := fmt.Sprintf("CREATE TABLE %s.%s(", schema, tableName)
if o.IsCaseSensitive {
database = strings.ToLower(database)
schema = strings.ToLower(schema)
tableName = strings.ToLower(tableName)
}
columnsCondition := fmt.Sprintf("table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
columnsCondition = fmt.Sprintf("lower(table_catalog) = '%s' AND lower(table_schema) = '%s' "+
"AND lower(table_name) = '%s'", database, schema, tableName)
}
// 获取列定义,多个英文逗号分割
columns := fmt.Sprintf("SELECT string_agg(column_name || ' ' || "+
"CASE "+
" WHEN data_type IN ('character', 'character varying', 'text') "+
" THEN data_type || '(' || character_maximum_length || ')' "+
" WHEN data_type IN ('numeric', 'decimal') "+
" THEN data_type || '(' || numeric_precision || ',' || numeric_scale || ')' "+
" WHEN data_type IN ('integer', 'smallint', 'bigint') THEN data_type "+
" ELSE data_type "+
" END "+
" || "+
" CASE "+
" WHEN column_default != '' THEN ' DEFAULT ' || column_default ELSE '' END "+
" || "+
" CASE "+
" WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END, ',\n ' ORDER BY ordinal_position) AS columns_sql"+
" FROM information_schema.columns "+
" WHERE %s GROUP BY table_name", columnsCondition)
sqls, err := getResultSqls(o.Db, columns)
if err != nil {
log.Printf("search column definition error:%s\n", err)
return nil, err
}
if len(sqls) == 0 {
tableDDl := fmt.Sprintf("CREATE TABLE %s.%s(", schema, tableName)
if o.IsCaseSensitive {
database = strings.ToLower(database)
schema = strings.ToLower(schema)
tableName = strings.ToLower(tableName)
}
columnsCondition := fmt.Sprintf("table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
columnsCondition = fmt.Sprintf("lower(table_catalog) = '%s' AND lower(table_schema) = '%s' "+
"AND lower(table_name) = '%s'", database, schema, tableName)
}
// 获取列定义,多个英文逗号分割
columns := fmt.Sprintf("SELECT string_agg(column_name || ' ' || "+
"CASE "+
" WHEN data_type IN ('char', 'varchar', 'character', 'character varying', 'text') "+
" THEN data_type || '(' || COALESCE(character_maximum_length, 0) || ')' "+
" WHEN data_type IN ('numeric', 'decimal') "+
" THEN data_type || '(' || COALESCE(numeric_precision, 0) || ',' || COALESCE(numeric_scale, 0) || ')' "+
" WHEN data_type IN ('integer', 'smallint', 'bigint') THEN data_type "+
" ELSE data_type "+
" END "+
" || "+
" CASE "+
" WHEN column_default != '' THEN ' DEFAULT ' || column_default ELSE '' END "+
" || "+
" CASE "+
" WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END, ',\n ' ORDER BY ordinal_position) AS columns_sql"+
" FROM information_schema.columns "+
" WHERE %s GROUP BY table_name", columnsCondition)
sqls, err := getResultSqls(o.Db, columns)
if err != nil {
log.Printf("search column definition error:%s\n", err)
return nil, err
}
if len(sqls) == 0 {
return tables, nil
}
tableDDl += strings.Join(sqls, "")
constraintsCondition := fmt.Sprintf("n.nspname = '%s' AND C.relname = '%s'", schema, tableName)
if o.IsCaseSensitive {
constraintsCondition = fmt.Sprintf("lower(n.nspname) = '%s' "+
"AND lower(C.relname) = '%s'", schema, tableName)
}
// 获取所有约束
constraints := fmt.Sprintf("SELECT 'CONSTRAINT ' || r.conname || ' ' || "+
" pg_catalog.pg_get_constraintdef ( r.OID, TRUE ) AS constraint_definition "+
" FROM pg_catalog.pg_constraint r "+
" JOIN pg_catalog.pg_class C ON C.OID = r.conrelid "+
" JOIN pg_catalog.pg_namespace n ON n.OID = C.relnamespace "+
" WHERE %s", constraintsCondition)
sqls, err = getResultSqls(o.Db, constraints)
if err != nil {
log.Printf("search constraint definition error:%s\n", err)
return nil, err
}
for _, sqlContext := range sqls {
tableDDl += ",\n" + sqlContext
}
tableDDl += ")"
indexesCondition := fmt.Sprintf("schemaname = '%s' and tablename = '%s' ", schema, tableName)
if o.IsCaseSensitive {
indexesCondition = fmt.Sprintf("lower(schemaname) = '%s' and lower(tablename) = '%s'",
schema, tableName)
}
// 获取索引
indexes := fmt.Sprintf("SELECT indexdef AS index_definition FROM pg_indexes "+
" WHERE %s", indexesCondition)
sqls, err = getResultSqls(o.Db, indexes)
if err != nil {
log.Printf("search index definition error:%s\n", err)
return nil, err
}
for _, sqlContent := range sqls {
if strings.Contains(sqlContent, "CREATE UNIQUE INDEX") {
continue
}
tableDDl += strings.Join(sqls, "")
constraintsCondition := fmt.Sprintf("d.datname = '%s' AND n.nspname = '%s' AND C.relname = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
constraintsCondition = fmt.Sprintf("lower(d.datname) = '%s' AND lower(n.nspname) = '%s' "+
"AND lower(C.relname) = '%s'", database, schema, tableName)
}
// 获取所有约束
constraints := fmt.Sprintf("SELECT 'CONSTRAINT ' || r.conname || ' ' || "+
" pg_catalog.pg_get_constraintdef ( r.OID, TRUE ) AS constraint_definition "+
" FROM pg_catalog.pg_constraint r "+
" JOIN pg_catalog.pg_class C ON C.OID = r.conrelid "+
" JOIN pg_catalog.pg_namespace n ON n.OID = C.relnamespace "+
" JOIN pg_catalog.pg_database d ON d.datname = n.nspname "+
" WHERE %s", constraintsCondition)
sqls, err = getResultSqls(o.Db, constraints)
if err != nil {
log.Printf("search constraint definition error:%s\n", err)
return nil, err
}
for _, sqlContext := range sqls {
tableDDl += ",\n" + sqlContext
}
tableDDl += ")"
indexesCondition := fmt.Sprintf("schemaname = '%s' and tablename = '%s' ", schema, tableName)
if o.IsCaseSensitive {
indexesCondition = fmt.Sprintf("lower(schemaname) = '%s' and lower(tablename) = '%s'",
schema, tableName)
}
// 获取索引
indexes := fmt.Sprintf("SELECT indexdef AS index_definition FROM pg_indexes "+
" WHERE %s", indexesCondition)
sqls, err = getResultSqls(o.Db, indexes)
if err != nil {
log.Printf("search index definition error:%s\n", err)
return nil, err
}
for _, sqlContent := range sqls {
if strings.Contains(sqlContent, "CREATE UNIQUE INDEX") {
continue
}
tableDDl += ";\n" + sqlContent
}
tables = append(tables, tableDDl)
tableDDl += ";\n" + sqlContent
}
tables = append(tables, tableDDl)
return tables, nil
}

func (o *DB) ShowCreateViews(database, tableName string) ([]string, error) {
func (o *DB) ShowCreateViews(database, schema, tableName string) ([]string, error) {
query := fmt.Sprintf(
"SELECT 'CREATE OR REPLACE VIEW ' || table_schema || '.' || table_name || ' AS ' || view_definition"+
" AS create_view_statement "+
" FROM information_schema.views WHERE table_catalog = '%s' AND table_name = '%s'",
database, tableName)
" FROM information_schema.views "+
" WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)

if o.IsCaseSensitive {
database = strings.ToLower(database)
tableName = strings.ToLower(tableName)
query = fmt.Sprintf(
"SELECT 'CREATE OR REPLACE VIEW ' || table_schema || '.' || table_name || ' AS ' || view_definition"+
" AS create_view_statement "+
" FROM information_schema.views WHERE lower(table_catalog) = '%s' AND lower(table_name) = '%s'",
database, tableName)
" FROM information_schema.views "+
" WHERE lower(table_catalog) = '%s' AND lower(table_schema) = '%s' AND lower(table_name) = '%s'",
database, schema, tableName)
}
return getResultSqls(o.Db, query)
}
Expand Down
90 changes: 61 additions & 29 deletions sqle/server/auditplan/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1595,20 +1595,6 @@ func (at *PostgreSQLSchemaMetaTask) collectorDo() {
defer db.Close()
db.IsCaseSensitive = db.GetCaseSensitive()

tables, err := db.ShowSchemaTables(at.ap.InstanceDatabase)
if err != nil {
at.logger.Errorf("get schema table fail, error: %s", err)
return
}
var views []string
if at.ap.Params.GetParam("collect_view").Bool() {
views, err = db.ShowSchemaViews(at.ap.InstanceDatabase)
if err != nil {
at.logger.Errorf("get schema view fail, error: %s", err)
return
}
}

schemas, err := db.GetAllUserSchemas()
if err != nil {
at.logger.Errorf("get database=%s schemas error: %s", at.ap.InstanceDatabase, err)
Expand All @@ -1619,23 +1605,69 @@ func (at *PostgreSQLSchemaMetaTask) collectorDo() {
return
}

sqls := make([]string, 0, len(tables)+len(views))
for _, table := range tables {
tableSqls, err := db.ShowCreateTables(at.ap.InstanceDatabase, table, schemas)
if err != nil {
at.logger.Errorf("show create table fail, error: %s", err)
return
}
sqls = append(sqls, tableSqls...)
wg := sync.WaitGroup{}
wg.Add(len(schemas) * 2)
tableMutex := sync.Mutex{}
viewMutex := sync.Mutex{}
sqls := make([]string, 0)
finalTableSqls := make([]string, 0)
finalViewSqls := make([]string, 0)
for _, schema := range schemas {
go func(schema string) {
defer wg.Done()
tables, err := db.ShowSchemaTables(schema)
ColdWaterLW marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
at.logger.Errorf("get schema table fail, error: %s", err)
return
}
for _, table := range tables {
tableSqls, err := db.ShowCreateTables(at.ap.InstanceDatabase, schema, table)
if err != nil {
at.logger.Errorf("show create table fail, error: %s", err)
return
}
tableMutex.Lock()
if len(tableSqls) > 0 {
finalTableSqls = append(finalTableSqls, tableSqls...)
}
tableMutex.Unlock()
}
}(schema)

go func(schema string) {
defer wg.Done()
var views []string
if at.ap.Params.GetParam("collect_view").Bool() {
views, err = db.ShowSchemaViews(schema)
ColdWaterLW marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
at.logger.Errorf("get schema view fail, error: %s", err)
return
}
}
for _, view := range views {
viewSqls, err := db.ShowCreateViews(at.ap.InstanceDatabase, schema, view)
if err != nil {
at.logger.Errorf("show create view fail, error: %s", err)
return
}
viewMutex.Lock()
if len(viewSqls) > 0 {
finalViewSqls = append(finalViewSqls, viewSqls...)
}
viewMutex.Unlock()
}
}(schema)
}
for _, view := range views {
viewSqls, err := db.ShowCreateViews(at.ap.InstanceDatabase, view)
if err != nil {
at.logger.Errorf("show create view fail, error: %s", err)
return
}
sqls = append(sqls, viewSqls...)
wg.Wait()

if len(finalTableSqls) > 0 {
sqls = append(sqls, finalTableSqls...)
}

if len(finalViewSqls) > 0 {
sqls = append(sqls, finalViewSqls...)
}

if len(sqls) > 0 {
err = at.persist.OverrideAuditPlanSQLs(at.ap.ID, convertRawSQLToModelSQLs(sqls))
if err != nil {
Expand Down