2
0

ProtoElement.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. from __future__ import annotations
  2. from functools import partial
  3. from typing import ClassVar, List, Any, Dict
  4. from typing import Callable
  5. from google.protobuf import message, descriptor_pb2
  6. from google.protobuf.descriptor import Descriptor, FieldDescriptor, FileDescriptor
  7. from google.protobuf.descriptor_pb2 import FieldDescriptorProto,DescriptorProto,EnumDescriptorProto
  8. import google.protobuf.descriptor_pool as descriptor_pool
  9. import logging
  10. import copy
  11. # import custom_options_pb2 as custom
  12. RendererType = Callable[['ProtoElement'], Dict]
  13. class ProtoElement:
  14. childs:List[ProtoElement]
  15. descriptor:Descriptor|FieldDescriptor
  16. comments: Dict[str,str]
  17. enum_type:EnumDescriptorProto
  18. _comments: Dict[str,str] ={}
  19. pool:descriptor_pool.DescriptorPool
  20. prototypes: dict[str, type[message.Message]]
  21. renderer:RendererType
  22. package:str
  23. file:FileDescriptor
  24. message:str
  25. _positions: Dict[str,tuple]
  26. position: tuple
  27. options:Dict[str,any]
  28. _message_instance:ClassVar
  29. @classmethod
  30. def set_prototypes(cls,prototypes:dict[str, type[message.Message]]):
  31. cls.prototypes = prototypes
  32. @classmethod
  33. def set_comments_base(cls,comments:Dict[str,str]):
  34. cls._comments = comments
  35. @classmethod
  36. def set_positions_base(cls,positions:Dict[str,tuple]):
  37. cls._positions = positions
  38. @classmethod
  39. def set_pool(cls,pool:descriptor_pool.DescriptorPool):
  40. cls.pool = pool
  41. @classmethod
  42. def set_logger(cls,logger = None):
  43. if not logger and not cls.logger:
  44. cls.logger = logging.getLogger(__name__)
  45. logging.basicConfig(
  46. level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  47. )
  48. elif logger:
  49. cls.logger = logger
  50. @classmethod
  51. def set_render(cls,render):
  52. cls.render_class = render
  53. def __init__(self, descriptor: Descriptor|FieldDescriptor, parent=None):
  54. ProtoElement.set_logger()
  55. self.descriptor = descriptor
  56. self.file = descriptor.file
  57. self.package = getattr(descriptor,"file",parent).package
  58. self.descriptorname = descriptor.name
  59. self.json_name = getattr(descriptor,'json_name','')
  60. self.type_name = getattr(descriptor,'type_name',descriptor.name)
  61. self.parent = parent
  62. self.fullname = descriptor.full_name
  63. self.type = getattr(descriptor,'type',FieldDescriptor.TYPE_MESSAGE)
  64. if self.type ==FieldDescriptor.TYPE_MESSAGE:
  65. try:
  66. self._message_instance = self.prototypes[self.descriptor.message_type.full_name]()
  67. # self.logger.debug(f'Found instance for {self.descriptor.message_type.full_name}')
  68. except:
  69. # self.logger.error(f'Could not find instance for {self.descriptor.full_name}')
  70. self._message_instance = self.prototypes[self.descriptor.full_name]()
  71. self.label = getattr(descriptor,'label',None)
  72. self.childs = []
  73. if descriptor.has_options:
  74. self.options = {descr.name: value for descr, value in descriptor.GetOptions().ListFields()}
  75. else:
  76. self.options = {}
  77. try:
  78. if descriptor.containing_type.has_options:
  79. self.options.update({descr.name: value for descr, value in descriptor.containing_type.GetOptions().ListFields()})
  80. except:
  81. pass
  82. self.render = partial(self.render_class, self)
  83. self.comments = {comment.split('.')[-1]:self._comments[comment] for comment in self._comments.keys() if comment.startswith(self.path)}
  84. self.position = self._positions.get(self.path)
  85. @property
  86. def cpp_type(self)->str:
  87. return f'{self.package}_{self.descriptor.containing_type.name}'
  88. @property
  89. def cpp_member(self)->str:
  90. return self.name
  91. @property
  92. def cpp_type_member_prefix(self)->str:
  93. return f'{self.cpp_type}_{self.cpp_member}'
  94. @property
  95. def cpp_type_member(self)->str:
  96. return f'{self.cpp_type}.{self.cpp_member}'
  97. @property
  98. def main_message(self)->bool:
  99. return self.parent == None
  100. @property
  101. def parent(self)->ProtoElement:
  102. return self._parent
  103. @parent.setter
  104. def parent(self,value:ProtoElement):
  105. self._parent = value
  106. if value:
  107. self._parent.childs.append(self)
  108. @property
  109. def root(self)->ProtoElement:
  110. return self if not self.parent else self.parent
  111. @property
  112. def enum_type(self)->EnumDescriptorProto:
  113. return self.descriptor.enum_type
  114. @property
  115. def cpp_root(self):
  116. return f'{self.cpp_type}_ROOT'
  117. @property
  118. def cpp_child(self):
  119. return f'{self.cpp_type}_CHILD'
  120. @property
  121. def proto_file_line(self):
  122. # Accessing file descriptor to get source code info, adjusted for proper context
  123. if self.position:
  124. start_line, start_column, end_line = self.position
  125. return f"{self.file.name}:{start_line}"
  126. else:
  127. return f"{self.file.name}"
  128. @property
  129. def message_instance(self):
  130. return getattr(self,'_message_instance',getattr(self.parent,'message_instance',None))
  131. @property
  132. def new_message_instance(self):
  133. if self.type == FieldDescriptor.TYPE_MESSAGE:
  134. try:
  135. # Try to create a new instance using the full name of the message type
  136. return self.prototypes[self.descriptor.message_type.full_name]()
  137. except KeyError:
  138. # If the above fails, use an alternative method to create a new instance
  139. # Log the error if necessary
  140. # self.logger.error(f'Could not find instance for {self.descriptor.full_name}')
  141. return self.prototypes[self.descriptor.full_name]()
  142. else:
  143. # Return None or raise an exception if the type is not a message
  144. return None
  145. @property
  146. def tree(self):
  147. childs = '->('+', '.join(c.tree for c in self.childs ) + ')' if len(self.childs)>0 else ''
  148. return f'{self.name}{childs}'
  149. @property
  150. def name(self):
  151. return self.descriptorname if len(self.descriptorname)>0 else self.parent.name if self.parent else self.package
  152. @property
  153. def enum_values(self)->List[str]:
  154. return [n.name for n in getattr(self.enum_type,"values",getattr(self.enum_type,"value",[])) ]
  155. @property
  156. def enum_values_str(self)->str:
  157. return ', '.join(self.enum_values)
  158. @property
  159. def fields(self)->List[FieldDescriptor]:
  160. return getattr(self.descriptor,"fields",getattr(getattr(self.descriptor,"message_type",None),"fields",None))
  161. @property
  162. def _default_value(self):
  163. if 'default_value' in self.options:
  164. return self.options['default_value']
  165. if self.type in [FieldDescriptorProto.TYPE_INT32, FieldDescriptorProto.TYPE_INT64,
  166. FieldDescriptorProto.TYPE_UINT32, FieldDescriptorProto.TYPE_UINT64,
  167. FieldDescriptorProto.TYPE_SINT32, FieldDescriptorProto.TYPE_SINT64,
  168. FieldDescriptorProto.TYPE_FIXED32, FieldDescriptorProto.TYPE_FIXED64,
  169. FieldDescriptorProto.TYPE_SFIXED32, FieldDescriptorProto.TYPE_SFIXED64]:
  170. return 0
  171. elif self.type in [FieldDescriptorProto.TYPE_FLOAT, FieldDescriptorProto.TYPE_DOUBLE]:
  172. return 0.0
  173. elif self.type == FieldDescriptorProto.TYPE_BOOL:
  174. return False
  175. elif self.type in [FieldDescriptorProto.TYPE_STRING, FieldDescriptorProto.TYPE_BYTES]:
  176. return ""
  177. elif self.is_enum:
  178. return self.enum_values[0] if self.enum_values else 0
  179. @property
  180. def detached_leading_comments(self)->str:
  181. return self.comments["leading"] if "detached" in self.comments else ""
  182. @property
  183. def leading_comment(self)->str:
  184. return self.comments["leading"] if "leading" in self.comments else ""
  185. @property
  186. def trailing_comment(self)->str:
  187. return self.comments["trailing"] if "trailing" in self.comments else ""
  188. @property
  189. def is_enum(self):
  190. return self.type == FieldDescriptorProto.TYPE_ENUM
  191. @property
  192. def path(self) -> str:
  193. return self.descriptor.full_name
  194. @property
  195. def enum_name(self)-> str:
  196. return self.type_name.split('.', maxsplit=1)[-1]
  197. @property
  198. def repeated(self)->bool:
  199. return self.label== FieldDescriptor.LABEL_REPEATED