diff options
Diffstat (limited to 'weed/mq/kafka/schema/protobuf_descriptor.go')
| -rw-r--r-- | weed/mq/kafka/schema/protobuf_descriptor.go | 485 |
1 files changed, 485 insertions, 0 deletions
diff --git a/weed/mq/kafka/schema/protobuf_descriptor.go b/weed/mq/kafka/schema/protobuf_descriptor.go new file mode 100644 index 000000000..a0f584114 --- /dev/null +++ b/weed/mq/kafka/schema/protobuf_descriptor.go @@ -0,0 +1,485 @@ +package schema + +import ( + "fmt" + "sync" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +// ProtobufSchema represents a parsed Protobuf schema with message type information +type ProtobufSchema struct { + FileDescriptorSet *descriptorpb.FileDescriptorSet + MessageDescriptor protoreflect.MessageDescriptor + MessageName string + PackageName string + Dependencies []string +} + +// ProtobufDescriptorParser handles parsing of Confluent Schema Registry Protobuf descriptors +type ProtobufDescriptorParser struct { + mu sync.RWMutex + // Cache for parsed descriptors to avoid re-parsing + descriptorCache map[string]*ProtobufSchema +} + +// NewProtobufDescriptorParser creates a new parser instance +func NewProtobufDescriptorParser() *ProtobufDescriptorParser { + return &ProtobufDescriptorParser{ + descriptorCache: make(map[string]*ProtobufSchema), + } +} + +// ParseBinaryDescriptor parses a Confluent Schema Registry Protobuf binary descriptor +// The input is typically a serialized FileDescriptorSet from the schema registry +func (p *ProtobufDescriptorParser) ParseBinaryDescriptor(binaryData []byte, messageName string) (*ProtobufSchema, error) { + // Check cache first + cacheKey := fmt.Sprintf("%x:%s", binaryData[:min(32, len(binaryData))], messageName) + p.mu.RLock() + if cached, exists := p.descriptorCache[cacheKey]; exists { + p.mu.RUnlock() + // If we have a cached schema but no message descriptor, return the same error + if cached.MessageDescriptor == nil { + return cached, fmt.Errorf("failed to find message descriptor for %s: message descriptor resolution not fully implemented in Phase E1 - found message %s in package %s", messageName, messageName, cached.PackageName) + } + return cached, nil + } + p.mu.RUnlock() + + // Parse the FileDescriptorSet from binary data + var fileDescriptorSet descriptorpb.FileDescriptorSet + if err := proto.Unmarshal(binaryData, &fileDescriptorSet); err != nil { + return nil, fmt.Errorf("failed to unmarshal FileDescriptorSet: %w", err) + } + + // Validate the descriptor set + if err := p.validateDescriptorSet(&fileDescriptorSet); err != nil { + return nil, fmt.Errorf("invalid descriptor set: %w", err) + } + + // If no message name provided, try to find the first available message + if messageName == "" { + messageName = p.findFirstMessageName(&fileDescriptorSet) + if messageName == "" { + return nil, fmt.Errorf("no messages found in FileDescriptorSet") + } + } + + // Find the target message descriptor + messageDesc, packageName, err := p.findMessageDescriptor(&fileDescriptorSet, messageName) + if err != nil { + // For Phase E1, we still cache the FileDescriptorSet even if message resolution fails + // This allows us to test caching behavior and avoid re-parsing the same binary data + schema := &ProtobufSchema{ + FileDescriptorSet: &fileDescriptorSet, + MessageDescriptor: nil, // Not resolved in Phase E1 + MessageName: messageName, + PackageName: packageName, + Dependencies: p.extractDependencies(&fileDescriptorSet), + } + p.mu.Lock() + p.descriptorCache[cacheKey] = schema + p.mu.Unlock() + return schema, fmt.Errorf("failed to find message descriptor for %s: %w", messageName, err) + } + + // Extract dependencies + dependencies := p.extractDependencies(&fileDescriptorSet) + + // Create the schema object + schema := &ProtobufSchema{ + FileDescriptorSet: &fileDescriptorSet, + MessageDescriptor: messageDesc, + MessageName: messageName, + PackageName: packageName, + Dependencies: dependencies, + } + + // Cache the result + p.mu.Lock() + p.descriptorCache[cacheKey] = schema + p.mu.Unlock() + + return schema, nil +} + +// validateDescriptorSet performs basic validation on the FileDescriptorSet +func (p *ProtobufDescriptorParser) validateDescriptorSet(fds *descriptorpb.FileDescriptorSet) error { + if len(fds.File) == 0 { + return fmt.Errorf("FileDescriptorSet contains no files") + } + + for i, file := range fds.File { + if file.Name == nil { + return fmt.Errorf("file descriptor %d has no name", i) + } + if file.Package == nil { + return fmt.Errorf("file descriptor %s has no package", *file.Name) + } + } + + return nil +} + +// findFirstMessageName finds the first message name in the FileDescriptorSet +func (p *ProtobufDescriptorParser) findFirstMessageName(fds *descriptorpb.FileDescriptorSet) string { + for _, file := range fds.File { + if len(file.MessageType) > 0 { + return file.MessageType[0].GetName() + } + } + return "" +} + +// findMessageDescriptor locates a specific message descriptor within the FileDescriptorSet +func (p *ProtobufDescriptorParser) findMessageDescriptor(fds *descriptorpb.FileDescriptorSet, messageName string) (protoreflect.MessageDescriptor, string, error) { + // This is a simplified implementation for Phase E1 + // In a complete implementation, we would: + // 1. Build a complete descriptor registry from the FileDescriptorSet + // 2. Resolve all imports and dependencies + // 3. Handle nested message types and packages correctly + // 4. Support fully qualified message names + + for _, file := range fds.File { + packageName := "" + if file.Package != nil { + packageName = *file.Package + } + + // Search for the message in this file + for _, messageType := range file.MessageType { + if messageType.Name != nil && *messageType.Name == messageName { + // Try to build a proper descriptor from the FileDescriptorProto + fileDesc, err := p.buildFileDescriptor(file) + if err != nil { + return nil, packageName, fmt.Errorf("failed to build file descriptor: %w", err) + } + + // Find the message descriptor in the built file + msgDesc := p.findMessageInFileDescriptor(fileDesc, messageName) + if msgDesc != nil { + return msgDesc, packageName, nil + } + + return nil, packageName, fmt.Errorf("message descriptor built but not found: %s", messageName) + } + + // Search nested messages (simplified) + if nestedDesc := p.searchNestedMessages(messageType, messageName); nestedDesc != nil { + // Try to build descriptor for nested message + fileDesc, err := p.buildFileDescriptor(file) + if err != nil { + return nil, packageName, fmt.Errorf("failed to build file descriptor for nested message: %w", err) + } + + msgDesc := p.findMessageInFileDescriptor(fileDesc, messageName) + if msgDesc != nil { + return msgDesc, packageName, nil + } + + return nil, packageName, fmt.Errorf("nested message descriptor built but not found: %s", messageName) + } + } + } + + return nil, "", fmt.Errorf("message %s not found in descriptor set", messageName) +} + +// buildFileDescriptor builds a protoreflect.FileDescriptor from a FileDescriptorProto +func (p *ProtobufDescriptorParser) buildFileDescriptor(fileProto *descriptorpb.FileDescriptorProto) (protoreflect.FileDescriptor, error) { + // Create a local registry to avoid conflicts + localFiles := &protoregistry.Files{} + + // Build the file descriptor using protodesc + fileDesc, err := protodesc.NewFile(fileProto, localFiles) + if err != nil { + return nil, fmt.Errorf("failed to create file descriptor: %w", err) + } + + return fileDesc, nil +} + +// findMessageInFileDescriptor searches for a message descriptor within a file descriptor +func (p *ProtobufDescriptorParser) findMessageInFileDescriptor(fileDesc protoreflect.FileDescriptor, messageName string) protoreflect.MessageDescriptor { + // Search top-level messages + messages := fileDesc.Messages() + for i := 0; i < messages.Len(); i++ { + msgDesc := messages.Get(i) + if string(msgDesc.Name()) == messageName { + return msgDesc + } + + // Search nested messages + if nestedDesc := p.findNestedMessageDescriptor(msgDesc, messageName); nestedDesc != nil { + return nestedDesc + } + } + + return nil +} + +// findNestedMessageDescriptor recursively searches for nested messages +func (p *ProtobufDescriptorParser) findNestedMessageDescriptor(msgDesc protoreflect.MessageDescriptor, messageName string) protoreflect.MessageDescriptor { + nestedMessages := msgDesc.Messages() + for i := 0; i < nestedMessages.Len(); i++ { + nestedDesc := nestedMessages.Get(i) + if string(nestedDesc.Name()) == messageName { + return nestedDesc + } + + // Recursively search deeper nested messages + if deeperNested := p.findNestedMessageDescriptor(nestedDesc, messageName); deeperNested != nil { + return deeperNested + } + } + + return nil +} + +// searchNestedMessages recursively searches for nested message types +func (p *ProtobufDescriptorParser) searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto { + for _, nested := range messageType.NestedType { + if nested.Name != nil && *nested.Name == targetName { + return nested + } + // Recursively search deeper nesting + if found := p.searchNestedMessages(nested, targetName); found != nil { + return found + } + } + return nil +} + +// extractDependencies extracts the list of dependencies from the FileDescriptorSet +func (p *ProtobufDescriptorParser) extractDependencies(fds *descriptorpb.FileDescriptorSet) []string { + dependencySet := make(map[string]bool) + + for _, file := range fds.File { + for _, dep := range file.Dependency { + dependencySet[dep] = true + } + } + + dependencies := make([]string, 0, len(dependencySet)) + for dep := range dependencySet { + dependencies = append(dependencies, dep) + } + + return dependencies +} + +// GetMessageFields returns information about the fields in the message +func (s *ProtobufSchema) GetMessageFields() ([]FieldInfo, error) { + if s.FileDescriptorSet == nil { + return nil, fmt.Errorf("no FileDescriptorSet available") + } + + // Find the message descriptor for this schema + messageDesc := s.findMessageDescriptor(s.MessageName) + if messageDesc == nil { + return nil, fmt.Errorf("message %s not found in descriptor set", s.MessageName) + } + + // Extract field information + fields := make([]FieldInfo, 0, len(messageDesc.Field)) + for _, field := range messageDesc.Field { + fieldInfo := FieldInfo{ + Name: field.GetName(), + Number: field.GetNumber(), + Type: s.fieldTypeToString(field.GetType()), + Label: s.fieldLabelToString(field.GetLabel()), + } + + // Set TypeName for message/enum types + if field.GetTypeName() != "" { + fieldInfo.TypeName = field.GetTypeName() + } + + fields = append(fields, fieldInfo) + } + + return fields, nil +} + +// FieldInfo represents information about a Protobuf field +type FieldInfo struct { + Name string + Number int32 + Type string + Label string // optional, required, repeated + TypeName string // for message/enum types +} + +// GetFieldByName returns information about a specific field +func (s *ProtobufSchema) GetFieldByName(fieldName string) (*FieldInfo, error) { + fields, err := s.GetMessageFields() + if err != nil { + return nil, err + } + + for _, field := range fields { + if field.Name == fieldName { + return &field, nil + } + } + + return nil, fmt.Errorf("field %s not found", fieldName) +} + +// GetFieldByNumber returns information about a field by its number +func (s *ProtobufSchema) GetFieldByNumber(fieldNumber int32) (*FieldInfo, error) { + fields, err := s.GetMessageFields() + if err != nil { + return nil, err + } + + for _, field := range fields { + if field.Number == fieldNumber { + return &field, nil + } + } + + return nil, fmt.Errorf("field number %d not found", fieldNumber) +} + +// findMessageDescriptor finds a message descriptor by name in the FileDescriptorSet +func (s *ProtobufSchema) findMessageDescriptor(messageName string) *descriptorpb.DescriptorProto { + if s.FileDescriptorSet == nil { + return nil + } + + for _, file := range s.FileDescriptorSet.File { + // Check top-level messages + for _, message := range file.MessageType { + if message.GetName() == messageName { + return message + } + // Check nested messages + if nested := searchNestedMessages(message, messageName); nested != nil { + return nested + } + } + } + + return nil +} + +// searchNestedMessages recursively searches for nested message types +func searchNestedMessages(messageType *descriptorpb.DescriptorProto, targetName string) *descriptorpb.DescriptorProto { + for _, nested := range messageType.NestedType { + if nested.Name != nil && *nested.Name == targetName { + return nested + } + // Recursively search deeper nesting + if found := searchNestedMessages(nested, targetName); found != nil { + return found + } + } + return nil +} + +// fieldTypeToString converts a FieldDescriptorProto_Type to string +func (s *ProtobufSchema) fieldTypeToString(fieldType descriptorpb.FieldDescriptorProto_Type) string { + switch fieldType { + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: + return "double" + case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: + return "float" + case descriptorpb.FieldDescriptorProto_TYPE_INT64: + return "int64" + case descriptorpb.FieldDescriptorProto_TYPE_UINT64: + return "uint64" + case descriptorpb.FieldDescriptorProto_TYPE_INT32: + return "int32" + case descriptorpb.FieldDescriptorProto_TYPE_FIXED64: + return "fixed64" + case descriptorpb.FieldDescriptorProto_TYPE_FIXED32: + return "fixed32" + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + return "bool" + case descriptorpb.FieldDescriptorProto_TYPE_STRING: + return "string" + case descriptorpb.FieldDescriptorProto_TYPE_GROUP: + return "group" + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: + return "message" + case descriptorpb.FieldDescriptorProto_TYPE_BYTES: + return "bytes" + case descriptorpb.FieldDescriptorProto_TYPE_UINT32: + return "uint32" + case descriptorpb.FieldDescriptorProto_TYPE_ENUM: + return "enum" + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: + return "sfixed32" + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: + return "sfixed64" + case descriptorpb.FieldDescriptorProto_TYPE_SINT32: + return "sint32" + case descriptorpb.FieldDescriptorProto_TYPE_SINT64: + return "sint64" + default: + return "unknown" + } +} + +// fieldLabelToString converts a FieldDescriptorProto_Label to string +func (s *ProtobufSchema) fieldLabelToString(label descriptorpb.FieldDescriptorProto_Label) string { + switch label { + case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL: + return "optional" + case descriptorpb.FieldDescriptorProto_LABEL_REQUIRED: + return "required" + case descriptorpb.FieldDescriptorProto_LABEL_REPEATED: + return "repeated" + default: + return "unknown" + } +} + +// ValidateMessage validates that a message conforms to the schema +func (s *ProtobufSchema) ValidateMessage(messageData []byte) error { + if s.MessageDescriptor == nil { + return fmt.Errorf("no message descriptor available for validation") + } + + // Create a dynamic message from the descriptor + msgType := dynamicpb.NewMessageType(s.MessageDescriptor) + msg := msgType.New() + + // Try to unmarshal the message data + if err := proto.Unmarshal(messageData, msg.Interface()); err != nil { + return fmt.Errorf("message validation failed: %w", err) + } + + // Basic validation passed - the message can be unmarshaled with the schema + return nil +} + +// ClearCache clears the descriptor cache +func (p *ProtobufDescriptorParser) ClearCache() { + p.mu.Lock() + defer p.mu.Unlock() + p.descriptorCache = make(map[string]*ProtobufSchema) +} + +// GetCacheStats returns statistics about the descriptor cache +func (p *ProtobufDescriptorParser) GetCacheStats() map[string]interface{} { + p.mu.RLock() + defer p.mu.RUnlock() + return map[string]interface{}{ + "cached_descriptors": len(p.descriptorCache), + } +} + +// Helper function for min +func min(a, b int) int { + if a < b { + return a + } + return b +} |
