diff options
Diffstat (limited to 'weed/query/engine/query_parsing_test.go')
| -rw-r--r-- | weed/query/engine/query_parsing_test.go | 564 |
1 files changed, 564 insertions, 0 deletions
diff --git a/weed/query/engine/query_parsing_test.go b/weed/query/engine/query_parsing_test.go new file mode 100644 index 000000000..ffeaadbc5 --- /dev/null +++ b/weed/query/engine/query_parsing_test.go @@ -0,0 +1,564 @@ +package engine + +import ( + "testing" +) + +func TestParseSQL_COUNT_Functions(t *testing.T) { + tests := []struct { + name string + sql string + wantErr bool + validate func(t *testing.T, stmt Statement) + }{ + { + name: "COUNT(*) basic", + sql: "SELECT COUNT(*) FROM test_table", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt, ok := stmt.(*SelectStatement) + if !ok { + t.Fatalf("Expected *SelectStatement, got %T", stmt) + } + + if len(selectStmt.SelectExprs) != 1 { + t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) + } + + aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr) + if !ok { + t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0]) + } + + funcExpr, ok := aliasedExpr.Expr.(*FuncExpr) + if !ok { + t.Fatalf("Expected *FuncExpr, got %T", aliasedExpr.Expr) + } + + if funcExpr.Name.String() != "COUNT" { + t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String()) + } + + if len(funcExpr.Exprs) != 1 { + t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs)) + } + + starExpr, ok := funcExpr.Exprs[0].(*StarExpr) + if !ok { + t.Errorf("Expected *StarExpr argument, got %T", funcExpr.Exprs[0]) + } + _ = starExpr // Use the variable to avoid unused variable error + }, + }, + { + name: "COUNT(column_name)", + sql: "SELECT COUNT(user_id) FROM users", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt, ok := stmt.(*SelectStatement) + if !ok { + t.Fatalf("Expected *SelectStatement, got %T", stmt) + } + + aliasedExpr := selectStmt.SelectExprs[0].(*AliasedExpr) + funcExpr := aliasedExpr.Expr.(*FuncExpr) + + if funcExpr.Name.String() != "COUNT" { + t.Errorf("Expected function name 'COUNT', got '%s'", funcExpr.Name.String()) + } + + if len(funcExpr.Exprs) != 1 { + t.Fatalf("Expected 1 function argument, got %d", len(funcExpr.Exprs)) + } + + argExpr, ok := funcExpr.Exprs[0].(*AliasedExpr) + if !ok { + t.Errorf("Expected *AliasedExpr argument, got %T", funcExpr.Exprs[0]) + } + + colName, ok := argExpr.Expr.(*ColName) + if !ok { + t.Errorf("Expected *ColName, got %T", argExpr.Expr) + } + + if colName.Name.String() != "user_id" { + t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String()) + } + }, + }, + { + name: "Multiple aggregate functions", + sql: "SELECT COUNT(*), SUM(amount), AVG(score) FROM transactions", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt, ok := stmt.(*SelectStatement) + if !ok { + t.Fatalf("Expected *SelectStatement, got %T", stmt) + } + + if len(selectStmt.SelectExprs) != 3 { + t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) + } + + // Verify COUNT(*) + countExpr := selectStmt.SelectExprs[0].(*AliasedExpr) + countFunc := countExpr.Expr.(*FuncExpr) + if countFunc.Name.String() != "COUNT" { + t.Errorf("Expected first function to be COUNT, got %s", countFunc.Name.String()) + } + + // Verify SUM(amount) + sumExpr := selectStmt.SelectExprs[1].(*AliasedExpr) + sumFunc := sumExpr.Expr.(*FuncExpr) + if sumFunc.Name.String() != "SUM" { + t.Errorf("Expected second function to be SUM, got %s", sumFunc.Name.String()) + } + + // Verify AVG(score) + avgExpr := selectStmt.SelectExprs[2].(*AliasedExpr) + avgFunc := avgExpr.Expr.(*FuncExpr) + if avgFunc.Name.String() != "AVG" { + t.Errorf("Expected third function to be AVG, got %s", avgFunc.Name.String()) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := ParseSQL(tt.sql) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if tt.validate != nil { + tt.validate(t, stmt) + } + }) + } +} + +func TestParseSQL_SELECT_Expressions(t *testing.T) { + tests := []struct { + name string + sql string + wantErr bool + validate func(t *testing.T, stmt Statement) + }{ + { + name: "SELECT * FROM table", + sql: "SELECT * FROM users", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if len(selectStmt.SelectExprs) != 1 { + t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) + } + + _, ok := selectStmt.SelectExprs[0].(*StarExpr) + if !ok { + t.Errorf("Expected *StarExpr, got %T", selectStmt.SelectExprs[0]) + } + }, + }, + { + name: "SELECT column FROM table", + sql: "SELECT user_id FROM users", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if len(selectStmt.SelectExprs) != 1 { + t.Fatalf("Expected 1 select expression, got %d", len(selectStmt.SelectExprs)) + } + + aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr) + if !ok { + t.Fatalf("Expected *AliasedExpr, got %T", selectStmt.SelectExprs[0]) + } + + colName, ok := aliasedExpr.Expr.(*ColName) + if !ok { + t.Fatalf("Expected *ColName, got %T", aliasedExpr.Expr) + } + + if colName.Name.String() != "user_id" { + t.Errorf("Expected column name 'user_id', got '%s'", colName.Name.String()) + } + }, + }, + { + name: "SELECT multiple columns", + sql: "SELECT user_id, name, email FROM users", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if len(selectStmt.SelectExprs) != 3 { + t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) + } + + expectedColumns := []string{"user_id", "name", "email"} + for i, expected := range expectedColumns { + aliasedExpr := selectStmt.SelectExprs[i].(*AliasedExpr) + colName := aliasedExpr.Expr.(*ColName) + if colName.Name.String() != expected { + t.Errorf("Expected column %d to be '%s', got '%s'", i, expected, colName.Name.String()) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := ParseSQL(tt.sql) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if tt.validate != nil { + tt.validate(t, stmt) + } + }) + } +} + +func TestParseSQL_WHERE_Clauses(t *testing.T) { + tests := []struct { + name string + sql string + wantErr bool + validate func(t *testing.T, stmt Statement) + }{ + { + name: "WHERE with simple comparison", + sql: "SELECT * FROM users WHERE age > 18", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Where == nil { + t.Fatal("Expected WHERE clause, got nil") + } + + // Just verify we have a WHERE clause with an expression + if selectStmt.Where.Expr == nil { + t.Error("Expected WHERE expression, got nil") + } + }, + }, + { + name: "WHERE with AND condition", + sql: "SELECT * FROM users WHERE age > 18 AND status = 'active'", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Where == nil { + t.Fatal("Expected WHERE clause, got nil") + } + + // Verify we have an AND expression + andExpr, ok := selectStmt.Where.Expr.(*AndExpr) + if !ok { + t.Errorf("Expected *AndExpr, got %T", selectStmt.Where.Expr) + } + _ = andExpr // Use variable to avoid unused error + }, + }, + { + name: "WHERE with OR condition", + sql: "SELECT * FROM users WHERE age < 18 OR age > 65", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Where == nil { + t.Fatal("Expected WHERE clause, got nil") + } + + // Verify we have an OR expression + orExpr, ok := selectStmt.Where.Expr.(*OrExpr) + if !ok { + t.Errorf("Expected *OrExpr, got %T", selectStmt.Where.Expr) + } + _ = orExpr // Use variable to avoid unused error + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := ParseSQL(tt.sql) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if tt.validate != nil { + tt.validate(t, stmt) + } + }) + } +} + +func TestParseSQL_LIMIT_Clauses(t *testing.T) { + tests := []struct { + name string + sql string + wantErr bool + validate func(t *testing.T, stmt Statement) + }{ + { + name: "LIMIT with number", + sql: "SELECT * FROM users LIMIT 10", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Limit == nil { + t.Fatal("Expected LIMIT clause, got nil") + } + + if selectStmt.Limit.Rowcount == nil { + t.Error("Expected LIMIT rowcount, got nil") + } + + // Verify no OFFSET is set + if selectStmt.Limit.Offset != nil { + t.Error("Expected OFFSET to be nil for LIMIT-only query") + } + + sqlVal, ok := selectStmt.Limit.Rowcount.(*SQLVal) + if !ok { + t.Errorf("Expected *SQLVal, got %T", selectStmt.Limit.Rowcount) + } + + if sqlVal.Type != IntVal { + t.Errorf("Expected IntVal type, got %d", sqlVal.Type) + } + + if string(sqlVal.Val) != "10" { + t.Errorf("Expected limit value '10', got '%s'", string(sqlVal.Val)) + } + }, + }, + { + name: "LIMIT with OFFSET", + sql: "SELECT * FROM users LIMIT 10 OFFSET 5", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Limit == nil { + t.Fatal("Expected LIMIT clause, got nil") + } + + // Verify LIMIT value + if selectStmt.Limit.Rowcount == nil { + t.Error("Expected LIMIT rowcount, got nil") + } + + limitVal, ok := selectStmt.Limit.Rowcount.(*SQLVal) + if !ok { + t.Errorf("Expected *SQLVal for LIMIT, got %T", selectStmt.Limit.Rowcount) + } + + if limitVal.Type != IntVal { + t.Errorf("Expected IntVal type for LIMIT, got %d", limitVal.Type) + } + + if string(limitVal.Val) != "10" { + t.Errorf("Expected limit value '10', got '%s'", string(limitVal.Val)) + } + + // Verify OFFSET value + if selectStmt.Limit.Offset == nil { + t.Fatal("Expected OFFSET clause, got nil") + } + + offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal) + if !ok { + t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset) + } + + if offsetVal.Type != IntVal { + t.Errorf("Expected IntVal type for OFFSET, got %d", offsetVal.Type) + } + + if string(offsetVal.Val) != "5" { + t.Errorf("Expected offset value '5', got '%s'", string(offsetVal.Val)) + } + }, + }, + { + name: "LIMIT with OFFSET zero", + sql: "SELECT * FROM users LIMIT 5 OFFSET 0", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Limit == nil { + t.Fatal("Expected LIMIT clause, got nil") + } + + // Verify OFFSET is 0 + if selectStmt.Limit.Offset == nil { + t.Fatal("Expected OFFSET clause, got nil") + } + + offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal) + if !ok { + t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset) + } + + if string(offsetVal.Val) != "0" { + t.Errorf("Expected offset value '0', got '%s'", string(offsetVal.Val)) + } + }, + }, + { + name: "LIMIT with large OFFSET", + sql: "SELECT * FROM users LIMIT 100 OFFSET 1000", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + selectStmt := stmt.(*SelectStatement) + if selectStmt.Limit == nil { + t.Fatal("Expected LIMIT clause, got nil") + } + + // Verify large OFFSET value + offsetVal, ok := selectStmt.Limit.Offset.(*SQLVal) + if !ok { + t.Errorf("Expected *SQLVal for OFFSET, got %T", selectStmt.Limit.Offset) + } + + if string(offsetVal.Val) != "1000" { + t.Errorf("Expected offset value '1000', got '%s'", string(offsetVal.Val)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := ParseSQL(tt.sql) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if tt.validate != nil { + tt.validate(t, stmt) + } + }) + } +} + +func TestParseSQL_SHOW_Statements(t *testing.T) { + tests := []struct { + name string + sql string + wantErr bool + validate func(t *testing.T, stmt Statement) + }{ + { + name: "SHOW DATABASES", + sql: "SHOW DATABASES", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + showStmt, ok := stmt.(*ShowStatement) + if !ok { + t.Fatalf("Expected *ShowStatement, got %T", stmt) + } + + if showStmt.Type != "databases" { + t.Errorf("Expected type 'databases', got '%s'", showStmt.Type) + } + }, + }, + { + name: "SHOW TABLES", + sql: "SHOW TABLES", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + showStmt, ok := stmt.(*ShowStatement) + if !ok { + t.Fatalf("Expected *ShowStatement, got %T", stmt) + } + + if showStmt.Type != "tables" { + t.Errorf("Expected type 'tables', got '%s'", showStmt.Type) + } + }, + }, + { + name: "SHOW TABLES FROM database", + sql: "SHOW TABLES FROM \"test_db\"", + wantErr: false, + validate: func(t *testing.T, stmt Statement) { + showStmt, ok := stmt.(*ShowStatement) + if !ok { + t.Fatalf("Expected *ShowStatement, got %T", stmt) + } + + if showStmt.Type != "tables" { + t.Errorf("Expected type 'tables', got '%s'", showStmt.Type) + } + + if showStmt.Schema != "test_db" { + t.Errorf("Expected schema 'test_db', got '%s'", showStmt.Schema) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := ParseSQL(tt.sql) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if tt.validate != nil { + tt.validate(t, stmt) + } + }) + } +} |
