aboutsummaryrefslogtreecommitdiff
path: root/weed/query/engine/arithmetic_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/query/engine/arithmetic_test.go')
-rw-r--r--weed/query/engine/arithmetic_test.go275
1 files changed, 275 insertions, 0 deletions
diff --git a/weed/query/engine/arithmetic_test.go b/weed/query/engine/arithmetic_test.go
new file mode 100644
index 000000000..4bf8813c6
--- /dev/null
+++ b/weed/query/engine/arithmetic_test.go
@@ -0,0 +1,275 @@
+package engine
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
+)
+
+func TestArithmeticExpressionParsing(t *testing.T) {
+ tests := []struct {
+ name string
+ expression string
+ expectNil bool
+ leftCol string
+ rightCol string
+ operator string
+ }{
+ {
+ name: "simple addition",
+ expression: "id+user_id",
+ expectNil: false,
+ leftCol: "id",
+ rightCol: "user_id",
+ operator: "+",
+ },
+ {
+ name: "simple subtraction",
+ expression: "col1-col2",
+ expectNil: false,
+ leftCol: "col1",
+ rightCol: "col2",
+ operator: "-",
+ },
+ {
+ name: "multiplication with spaces",
+ expression: "a * b",
+ expectNil: false,
+ leftCol: "a",
+ rightCol: "b",
+ operator: "*",
+ },
+ {
+ name: "string concatenation",
+ expression: "first_name||last_name",
+ expectNil: false,
+ leftCol: "first_name",
+ rightCol: "last_name",
+ operator: "||",
+ },
+ {
+ name: "string concatenation with spaces",
+ expression: "prefix || suffix",
+ expectNil: false,
+ leftCol: "prefix",
+ rightCol: "suffix",
+ operator: "||",
+ },
+ {
+ name: "not arithmetic",
+ expression: "simple_column",
+ expectNil: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Use CockroachDB parser to parse the expression
+ cockroachParser := NewCockroachSQLParser()
+ dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
+ stmt, err := cockroachParser.ParseSQL(dummySelect)
+
+ var result *ArithmeticExpr
+ if err == nil {
+ if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
+ if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
+ if arithmeticExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
+ result = arithmeticExpr
+ }
+ }
+ }
+ }
+
+ if tt.expectNil {
+ if result != nil {
+ t.Errorf("Expected nil for %s, got %v", tt.expression, result)
+ }
+ return
+ }
+
+ if result == nil {
+ t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression)
+ return
+ }
+
+ if result.Operator != tt.operator {
+ t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator)
+ }
+
+ // Check left operand
+ if leftCol, ok := result.Left.(*ColName); ok {
+ if leftCol.Name.String() != tt.leftCol {
+ t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String())
+ }
+ } else {
+ t.Errorf("Expected left operand to be ColName, got %T", result.Left)
+ }
+
+ // Check right operand
+ if rightCol, ok := result.Right.(*ColName); ok {
+ if rightCol.Name.String() != tt.rightCol {
+ t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String())
+ }
+ } else {
+ t.Errorf("Expected right operand to be ColName, got %T", result.Right)
+ }
+ })
+ }
+}
+
+func TestArithmeticExpressionEvaluation(t *testing.T) {
+ engine := NewSQLEngine("")
+
+ // Create test data
+ result := HybridScanResult{
+ Values: map[string]*schema_pb.Value{
+ "id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}},
+ "user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}},
+ "price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}},
+ "qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}},
+ "first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}},
+ "last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}},
+ "prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}},
+ "suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}},
+ },
+ }
+
+ tests := []struct {
+ name string
+ expression string
+ expected interface{}
+ }{
+ {
+ name: "integer addition",
+ expression: "id+user_id",
+ expected: int64(15),
+ },
+ {
+ name: "integer subtraction",
+ expression: "id-user_id",
+ expected: int64(5),
+ },
+ {
+ name: "mixed types multiplication",
+ expression: "price*qty",
+ expected: float64(76.5),
+ },
+ {
+ name: "string concatenation",
+ expression: "first_name||last_name",
+ expected: "JohnDoe",
+ },
+ {
+ name: "string concatenation with spaces",
+ expression: "prefix || suffix",
+ expected: "HelloWorld",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Parse the arithmetic expression using CockroachDB parser
+ cockroachParser := NewCockroachSQLParser()
+ dummySelect := fmt.Sprintf("SELECT %s", tt.expression)
+ stmt, err := cockroachParser.ParseSQL(dummySelect)
+ if err != nil {
+ t.Fatalf("Failed to parse expression %s: %v", tt.expression, err)
+ }
+
+ var arithmeticExpr *ArithmeticExpr
+ if selectStmt, ok := stmt.(*SelectStatement); ok && len(selectStmt.SelectExprs) > 0 {
+ if aliasedExpr, ok := selectStmt.SelectExprs[0].(*AliasedExpr); ok {
+ if arithExpr, ok := aliasedExpr.Expr.(*ArithmeticExpr); ok {
+ arithmeticExpr = arithExpr
+ }
+ }
+ }
+
+ if arithmeticExpr == nil {
+ t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression)
+ }
+
+ // Evaluate the expression
+ value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result)
+ if err != nil {
+ t.Fatalf("Failed to evaluate expression: %v", err)
+ }
+
+ if value == nil {
+ t.Fatalf("Got nil value for expression: %s", tt.expression)
+ }
+
+ // Check the result
+ switch expected := tt.expected.(type) {
+ case int64:
+ if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok {
+ if intVal.Int64Value != expected {
+ t.Errorf("Expected %d, got %d", expected, intVal.Int64Value)
+ }
+ } else {
+ t.Errorf("Expected int64 result, got %T", value.Kind)
+ }
+ case float64:
+ if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok {
+ if doubleVal.DoubleValue != expected {
+ t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue)
+ }
+ } else {
+ t.Errorf("Expected double result, got %T", value.Kind)
+ }
+ case string:
+ if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok {
+ if stringVal.StringValue != expected {
+ t.Errorf("Expected %s, got %s", expected, stringVal.StringValue)
+ }
+ } else {
+ t.Errorf("Expected string result, got %T", value.Kind)
+ }
+ }
+ })
+ }
+}
+
+func TestSelectArithmeticExpression(t *testing.T) {
+ // Test parsing a SELECT with arithmetic and string concatenation expressions
+ stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table")
+ if err != nil {
+ t.Fatalf("Failed to parse SQL: %v", err)
+ }
+
+ selectStmt := stmt.(*SelectStatement)
+ if len(selectStmt.SelectExprs) != 3 {
+ t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
+ }
+
+ // Check first expression (id+user_id)
+ aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr)
+ if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok {
+ if arithmeticExpr1.Operator != "+" {
+ t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator)
+ }
+ } else {
+ t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr)
+ }
+
+ // Check second expression (user_id*2)
+ aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr)
+ if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok {
+ if arithmeticExpr2.Operator != "*" {
+ t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator)
+ }
+ } else {
+ t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr)
+ }
+
+ // Check third expression (first_name||last_name)
+ aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr)
+ if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok {
+ if arithmeticExpr3.Operator != "||" {
+ t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator)
+ }
+ } else {
+ t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr)
+ }
+}