123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- from __future__ import annotations
- from functools import partial
- from typing import ClassVar, List, Any, Dict
- from typing import Callable
- from google.protobuf import message, descriptor_pb2
- from google.protobuf.descriptor import Descriptor, FieldDescriptor, FileDescriptor
- from google.protobuf.descriptor_pb2 import FieldDescriptorProto,DescriptorProto,EnumDescriptorProto
- import google.protobuf.descriptor_pool as descriptor_pool
- import logging
- import copy
- # import custom_options_pb2 as custom
# Renderer signature: takes a ProtoElement and returns a Dict.
RendererType = Callable[
    ['ProtoElement'],
    Dict,
]
class ProtoElement:
    """Node in a tree of protobuf descriptors, used as input to a renderer.

    Each element wraps one message ``Descriptor`` or ``FieldDescriptor`` and
    exposes naming helpers (C++-style identifiers, enum names), its options,
    attached comments and source position, and links itself into a
    parent/child tree.  Lookup tables shared by every element (message
    prototypes, comment/position tables, descriptor pool, renderer, logger)
    live on the class and are installed through the ``set_*`` classmethods.
    """

    childs: List[ProtoElement]
    descriptor: Descriptor | FieldDescriptor
    comments: Dict[str, str]
    enum_type: EnumDescriptorProto
    # Shared lookup tables.  They are empty by default, so a missing
    # set_*_base call yields "no comments" / "no position" instead of an
    # AttributeError.  They are replaced (never mutated) by the setters.
    _comments: Dict[str, str] = {}
    _positions: Dict[str, tuple] = {}
    pool: descriptor_pool.DescriptorPool
    prototypes: dict[str, type[message.Message]]
    renderer: RendererType
    package: str
    file: FileDescriptor
    message: str
    position: tuple | None
    options: Dict[str, Any]
    # Set per instance, and only for message-typed elements (see __init__).
    _message_instance: Any
    # Declared so set_logger() can test it before any logger was installed;
    # previously the first set_logger() call raised AttributeError.
    logger: logging.Logger | None = None

    @classmethod
    def set_prototypes(cls, prototypes: dict[str, type[message.Message]]):
        """Install the map from message full name to message class."""
        cls.prototypes = prototypes

    @classmethod
    def set_comments_base(cls, comments: Dict[str, str]):
        """Install the shared comment table.

        Keys are expected to be ``'<descriptor full name>.<kind>'``.  The
        ``<kind>`` parts read back by this class are 'leading', 'trailing'
        and 'detached'.
        """
        cls._comments = comments

    @classmethod
    def set_positions_base(cls, positions: Dict[str, tuple]):
        """Install the shared source-position table, keyed by descriptor full name."""
        cls._positions = positions

    @classmethod
    def set_pool(cls, pool: descriptor_pool.DescriptorPool):
        """Install the shared descriptor pool."""
        cls.pool = pool

    @classmethod
    def set_logger(cls, logger: logging.Logger | None = None):
        """Install *logger*, or create a default one if none is configured yet.

        An explicitly passed logger always replaces the current one.  Without
        one, and only when no logger is set yet, a logger named after this
        module is created and ``logging.basicConfig`` is applied at INFO level.
        """
        if logger:
            cls.logger = logger
        elif not cls.logger:
            cls.logger = logging.getLogger(__name__)
            logging.basicConfig(
                level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )

    @classmethod
    def set_render(cls, render):
        """Install the renderer; each element binds it to itself as ``self.render``."""
        cls.render_class = render

    def __init__(self, descriptor: Descriptor | FieldDescriptor, parent: ProtoElement | None = None):
        """Wrap *descriptor* and attach the new element under *parent* (if any).

        ``set_render`` must have been called first, and ``set_prototypes``
        too when *descriptor* is message-typed.
        """
        ProtoElement.set_logger()
        self.descriptor = descriptor
        self.file = descriptor.file
        self.package = self.file.package
        self.descriptorname = descriptor.name
        self.json_name = getattr(descriptor, 'json_name', '')
        self.type_name = getattr(descriptor, 'type_name', descriptor.name)

        # Assigning through the property also appends self to parent.childs.
        self.parent = parent
        self.fullname = descriptor.full_name
        # Descriptors without a 'type' attribute are treated as message-typed.
        self.type = getattr(descriptor, 'type', FieldDescriptor.TYPE_MESSAGE)

        if self.type == FieldDescriptor.TYPE_MESSAGE:
            self._message_instance = self._create_message_instance()
        self.label = getattr(descriptor, 'label', None)
        self.childs = []
        self.options = self._collect_options(descriptor)

        self.render = partial(self.render_class, self)
        self.comments = self._comments_for_path(self.path, self._comments)
        self.position = self._positions.get(self.path)

    def _create_message_instance(self):
        """Instantiate the registered prototype for this element's message type.

        The ``message_type`` full name is tried first.  When the descriptor
        has no ``message_type`` (AttributeError), or that name is not
        registered (KeyError), the descriptor's own full name is used instead.
        Raises KeyError if that name is not registered either.
        """
        try:
            return self.prototypes[self.descriptor.message_type.full_name]()
        except (KeyError, AttributeError):
            return self.prototypes[self.descriptor.full_name]()

    @staticmethod
    def _collect_options(descriptor) -> Dict[str, Any]:
        """Return the options set on *descriptor* by name, merged with its containing type's.

        The containing type's options are applied last, so on a name clash
        they override the descriptor's own value.
        """
        options: Dict[str, Any] = {}
        if descriptor.has_options:
            options = {descr.name: value for descr, value in descriptor.GetOptions().ListFields()}
        # Descriptors without a containing type (None or no such attribute)
        # contribute nothing.  This replaces a blanket ``except: pass``.
        containing = getattr(descriptor, 'containing_type', None)
        if containing is not None and containing.has_options:
            options.update({descr.name: value for descr, value in containing.GetOptions().ListFields()})
        return options

    @staticmethod
    def _comments_for_path(path: str, comments: Dict[str, str]) -> Dict[str, str]:
        """Select the comments attached directly to *path*, keyed by their kind.

        Only keys of the exact form ``'<path>.<kind>'`` are kept.  A bare
        prefix test would also pick up siblings sharing a name prefix
        (``field`` vs ``field_two``) and nested descendants.  Those entries
        would then overwrite this element's own 'leading'/'trailing' text.
        """
        prefix = f'{path}.'
        selected: Dict[str, str] = {}
        for key, text in comments.items():
            if not key.startswith(prefix):
                continue
            kind = key[len(prefix):]
            if kind and '.' not in kind:
                selected[kind] = text
        return selected

    @property
    def cpp_type(self) -> str:
        """``'<package>_<containing type name>'``.

        NOTE(review): raises AttributeError when the descriptor has no
        containing type (e.g. top-level messages) - confirm intended.
        """
        return f'{self.package}_{self.descriptor.containing_type.name}'

    @property
    def cpp_member(self) -> str:
        """Member identifier; same as ``name``."""
        return self.name

    @property
    def cpp_type_member_prefix(self) -> str:
        """``'<cpp_type>_<cpp_member>'``."""
        return f'{self.cpp_type}_{self.cpp_member}'

    @property
    def cpp_type_member(self) -> str:
        """``'<cpp_type>.<cpp_member>'``."""
        return f'{self.cpp_type}.{self.cpp_member}'

    @property
    def main_message(self) -> bool:
        """True when this element has no parent."""
        return self.parent is None

    @property
    def parent(self) -> ProtoElement:
        """Parent element, or None for a tree root."""
        return self._parent

    @parent.setter
    def parent(self, value: ProtoElement):
        # Registering with the parent keeps the childs lists in sync.
        self._parent = value
        if value:
            self._parent.childs.append(self)

    @property
    def root(self) -> ProtoElement:
        """The top-most ancestor of this element (itself when it has no parent).

        Previously this returned only the immediate parent, which is wrong
        for elements nested more than one level deep.
        """
        node = self
        while node.parent:
            node = node.parent
        return node

    @property
    def enum_type(self) -> EnumDescriptorProto:
        """The wrapped descriptor's ``enum_type``."""
        return self.descriptor.enum_type

    @property
    def cpp_root(self):
        """``'<cpp_type>_ROOT'``."""
        return f'{self.cpp_type}_ROOT'

    @property
    def cpp_child(self):
        """``'<cpp_type>_CHILD'``."""
        return f'{self.cpp_type}_CHILD'

    @property
    def proto_file_line(self):
        """``'<file name>:<start line>'`` when a position is known, else the file name.

        Only the first position element (the start line) is used.  This
        accepts both 3- and 4-element spans instead of requiring exactly 3.
        """
        if not self.position:
            return f"{self.file.name}"
        return f"{self.file.name}:{self.position[0]}"

    @property
    def message_instance(self):
        """This element's prototype instance, else the nearest ancestor's, else None."""
        # Checked lazily so ancestors are only consulted when needed.
        if hasattr(self, '_message_instance'):
            return self._message_instance
        return getattr(self.parent, 'message_instance', None)

    @property
    def new_message_instance(self):
        """A fresh prototype instance for message-typed elements, else None."""
        if self.type != FieldDescriptor.TYPE_MESSAGE:
            return None
        return self._create_message_instance()

    @property
    def tree(self):
        """Debug string ``'name->(child, child, ...)'``, built recursively."""
        if not self.childs:
            return self.name
        return f"{self.name}->({', '.join(child.tree for child in self.childs)})"

    @property
    def name(self):
        """Descriptor name; falls back to the parent's name, then to the package."""
        if self.descriptorname:
            return self.descriptorname
        if self.parent:
            return self.parent.name
        return self.package

    @property
    def enum_values(self) -> List[str]:
        """Names of the enum's values (``values`` or ``value`` attribute), else []."""
        return [n.name for n in getattr(self.enum_type, "values", getattr(self.enum_type, "value", []))]

    @property
    def enum_values_str(self) -> str:
        """Enum value names joined with ', '."""
        return ', '.join(self.enum_values)

    @property
    def fields(self) -> List[FieldDescriptor]:
        """The descriptor's ``fields``, else its ``message_type.fields``, else None."""
        return getattr(self.descriptor, "fields", getattr(getattr(self.descriptor, "message_type", None), "fields", None))

    @property
    def _default_value(self):
        """Return the 'default_value' option if set, otherwise a zero value for the type.

        Integer types give 0, float/double 0.0, bool False, string/bytes "".
        Enums give their first value name (0 when the enum has no values).
        Any other type (e.g. messages) gives None.
        """
        if 'default_value' in self.options:
            return self.options['default_value']
        if self.type in (FieldDescriptorProto.TYPE_INT32, FieldDescriptorProto.TYPE_INT64,
                         FieldDescriptorProto.TYPE_UINT32, FieldDescriptorProto.TYPE_UINT64,
                         FieldDescriptorProto.TYPE_SINT32, FieldDescriptorProto.TYPE_SINT64,
                         FieldDescriptorProto.TYPE_FIXED32, FieldDescriptorProto.TYPE_FIXED64,
                         FieldDescriptorProto.TYPE_SFIXED32, FieldDescriptorProto.TYPE_SFIXED64):
            return 0
        if self.type in (FieldDescriptorProto.TYPE_FLOAT, FieldDescriptorProto.TYPE_DOUBLE):
            return 0.0
        if self.type == FieldDescriptorProto.TYPE_BOOL:
            return False
        if self.type in (FieldDescriptorProto.TYPE_STRING, FieldDescriptorProto.TYPE_BYTES):
            return ""
        if self.is_enum:
            values = self.enum_values
            return values[0] if values else 0
        return None

    @property
    def detached_leading_comments(self) -> str:
        """The 'detached' comment text, or "".

        Previously this checked for 'detached' but returned the 'leading'
        entry.
        """
        return self.comments.get("detached", "")

    @property
    def leading_comment(self) -> str:
        """The 'leading' comment text, or ""."""
        return self.comments.get("leading", "")

    @property
    def trailing_comment(self) -> str:
        """The 'trailing' comment text, or ""."""
        return self.comments.get("trailing", "")

    @property
    def is_enum(self):
        """True when the element's type is TYPE_ENUM."""
        return self.type == FieldDescriptorProto.TYPE_ENUM

    @property
    def path(self) -> str:
        """Descriptor full name; the key into the shared comment/position tables."""
        return self.descriptor.full_name

    @property
    def enum_name(self) -> str:
        """``type_name`` with its first dot-separated component removed."""
        return self.type_name.split('.', maxsplit=1)[-1]

    @property
    def repeated(self) -> bool:
        """True when the field label is LABEL_REPEATED."""
        return self.label == FieldDescriptor.LABEL_REPEATED
-
-
|