From 5c1ced1059fdd094a180707cf1b4317fb49061ee Mon Sep 17 00:00:00 2001 From: linxiaotao Date: Mon, 9 Oct 2023 10:54:25 +0800 Subject: [PATCH] test: unit test for util ColumeNameVisitor @winfredLIN --- sqle/driver/mysql/util/visitor_test.go | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/sqle/driver/mysql/util/visitor_test.go b/sqle/driver/mysql/util/visitor_test.go index 17ef2fa018..b34266c8e9 100644 --- a/sqle/driver/mysql/util/visitor_test.go +++ b/sqle/driver/mysql/util/visitor_test.go @@ -105,3 +105,35 @@ func TestSelectFieldExtractor(t *testing.T) { }) } } + +func TestColumeNameVisitor(t *testing.T) { + tests := []struct { + input string + columnCount uint + }{ + {"SELECT * FROM t1", 0}, //不包含列 + {"SELECT a,b,c FROM t1 WHERE id > 1", 4}, //使用不等号 + {"SELECT COUNT(*) FROM t1", 0}, //使用函数并不包含列 + {"SELECT a,COUNT(*) FROM t1 GROUP BY a", 2}, //使用函数包含列 + {"SELECT * FROM table1 INNER JOIN table2 ON table1.id = table2.table1_id", 2}, //使用JOIN + {"SELECT * FROM table1 WHERE id IN ( SELECT id FROM table2 WHERE age > 30)", 3}, //使用子查询 + {"SELECT UPPER(name), LENGTH(comments) FROM table1", 2}, //使用函数 + {"SELECT CAST(price AS DECIMAL(10,2))FROM products", 1}, //使用类型转换 + {"SELECT * FROM table1 INNER JOIN table2 ON table1.id = table2.table1_id INNER JOIN table3 ON table2.id = table3.table2_id", 4}, //使用JOIN嵌套 + {"SELECT column1 AS alias1, column2 AS alias2 FROM table1", 2}, //使用列别名 + {"SELECT column1 + column2 AS sum_columns FROM table1", 2}, + {"SELECT t1.column1 AS t1_col1, t2.column2 AS t2_col2 FROM table1 t1 INNER JOIN table2 t2 ON t1.id = t2.t1_id", 4}, //不带AS的表别名 + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + stmt, err := parser.New().ParseOneStmt(tt.input, "", "") + assert.NoError(t, err) + + visitor := &ColumeNameVisitor{} + stmt.Accept(visitor) + + assert.Equal(t, tt.columnCount, uint(len(visitor.ColumeNameList))) + }) + } +}