ProtocParser.py

#!/usr/bin/env python
from functools import partial
import sys
import json
from typing import Callable, Dict, List
import argparse
from abc import ABC, abstractmethod
import google.protobuf.descriptor_pool as descriptor_pool
from google.protobuf import message_factory, message
from google.protobuf.message_factory import GetMessageClassesForFiles
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor import FieldDescriptor, Descriptor, FileDescriptor
from google.protobuf.descriptor_pb2 import FileDescriptorProto, DescriptorProto, FieldDescriptorProto, FieldOptions
from urllib import parse
from ProtoElement import ProtoElement
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


class ProtocParser(ABC):
    # Class-level attributes; note that the mutable defaults below are shared across instances.
    request: plugin.CodeGeneratorRequest
    response: plugin.CodeGeneratorResponse
    elements: List[ProtoElement] = []
    comments: Dict[str, str] = {}
    json_content = {}
    main_class_list: List[str] = []
    param_dict: Dict[str, str] = {}
    pool: descriptor_pool.DescriptorPool
    factory: message_factory.MessageFactory

    @abstractmethod
    def render(self, element: ProtoElement):
        pass

    @abstractmethod
    def get_name(self) -> str:
        pass

    # @abstractmethod
    # def start_element(self, element: ProtoElement):
    #     logger.debug(f'START Processing ELEMENT {element.path}')
    # @abstractmethod
    # def end_element(self, element: ProtoElement):
    #     logger.debug(f'END Processing ELEMENT {element.path}')

    @abstractmethod
    def end_message(self, classElement: ProtoElement):
        logger.debug(f'END Processing MESSAGE {classElement.name}')

    @abstractmethod
    def start_message(self, classElement: ProtoElement):
        logger.debug(f'START Processing MESSAGE {classElement.name}')

    @abstractmethod
    def start_file(self, file: FileDescriptor):
        logger.debug(f'START Processing file {file.name}')

    @abstractmethod
    def end_file(self, file: FileDescriptor):
        logger.debug(f'END Processing file {file.name}')

    def __init__(self, data):
        self.request = plugin.CodeGeneratorRequest.FromString(data)
        self.response = plugin.CodeGeneratorResponse()
        logger.info(f'Received {self.get_name()} parameter(s): {self.request.parameter}')
        # protoc passes all plugin options as a single comma-separated string;
        # each entry is expected to be key=value with a URL-encoded value.
        params = self.request.parameter.split(',')
        self.param_dict = {p.split('=')[0]: parse.unquote(p.split('=')[1]) for p in params if '=' in p}
        if 'const_prefix' not in self.param_dict:
            self.param_dict['const_prefix'] = ""
            logger.warning("No option passed for const_prefix. No prefix will be used for option init_from_mac")
        self.main_class_list = self.get_arg(name='main_class', split=True, split_char='!')
        if 'path' in self.param_dict:
            self.param_dict['path'] = self.param_dict['path'].split('?')
            for p in self.param_dict['path']:
                logger.debug(f'Adding to path: {p}')
                sys.path.append(p)
            # Importing the generated module registers its custom options with the default descriptor pool.
            import customoptions_pb2 as custom__options__pb2

    def get_arg(self, name: str, default=None, split: bool = False, split_char: str = ';'):
        result = self.param_dict.get(name, default)
        if not result:
            if not default:
                logger.error(f'Plugin parameter {name} not found')
                result = None
            else:
                result = default
                logger.warning(f'Plugin parameter {name} not found. Defaulting to {str(default)}')
        if split and result:
            result = result.split(split_char)
        logger.debug(f'Returning argument {name}={str(result)}')
        return result
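
    # Illustrative only (hypothetical option values, not taken from a real invocation):
    # a request.parameter string such as
    #     "main_class=app.Config!app.State,const_prefix=CFG_"
    # is parsed into
    #     param_dict == {'main_class': 'app.Config!app.State', 'const_prefix': 'CFG_'}
    # and get_arg('main_class', split=True, split_char='!') then returns ['app.Config', 'app.State'].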

    def get_name_attr(self, proto_element):
        attributes = ['package', 'name']
        for att in attributes:
            if hasattr(proto_element, att):
                return att
        return None

    def interpret_path(self, path, proto_element):
        if not path:
            if hasattr(proto_element, "name"):
                return proto_element.name
            else:
                return ''
        # Get the next path element
        path_elem = path[0]
        name_att = self.get_name_attr(proto_element)
        if name_att:
            elem_name = getattr(proto_element, name_att)
            elem_sep = '.'
        else:
            elem_name = ''
            elem_sep = ''
        # Ensure the proto_element has a DESCRIPTOR attribute
        if hasattr(proto_element, 'DESCRIPTOR'):
            # Use the DESCRIPTOR to access field information
            descriptor = proto_element.DESCRIPTOR
            # Get the field name from the descriptor
            try:
                field = descriptor.fields_by_number[path_elem]
            except KeyError:
                return None
            field_name = field.name
            field_name = field_name.lower().replace('_field_number', '')
            # Access the field if it exists
            if field_name == "extension":
                return field_name
            elif hasattr(proto_element, field_name):
                next_element = getattr(proto_element, field_name)
                if isinstance(next_element, list):
                    # If the next element is a list, path[1] is the index into it
                    # and the rest of the path resumes after that index.
                    return f'{elem_name}{elem_sep}{self.interpret_path(path[2:], next_element[path[1]])}'
                else:
                    # If it's not a list, just continue with the next path element
                    return f'{elem_name}{elem_sep}{self.interpret_path(path[1:], next_element)}'
        else:
            # Repeated-field containers have no DESCRIPTOR; index into them directly.
            return f'{elem_name}{elem_sep}{self.interpret_path(path[1:], proto_element[path_elem])}'
        # If the path cannot be interpreted, return None
        return None
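
    # For reference, using field numbers from descriptor.proto: a SourceCodeInfo path such as
    # (4, 0, 2, 1) denotes message_type[0].field[1], which interpret_path turns into a dotted
    # identifier of the form "<package>.<MessageName>.<field_name>" for the comments map.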

    def extract_comments(self, proto_file: FileDescriptorProto):
        for location in proto_file.source_code_info.location:
            # The path is a sequence of integers identifying the syntactic location
            path = tuple(location.path)
            leading_comments = location.leading_comments.strip()
            trailing_comments = location.trailing_comments.strip()
            leading_detached_comments = ''
            if len(location.leading_detached_comments) > 0:
                logger.debug('found detached comments')
                leading_detached_comments = '\r\n'.join(location.leading_detached_comments)
            if len(leading_comments) == 0 and len(trailing_comments) == 0 and len(leading_detached_comments) == 0:
                continue
            # Interpret the path and map it to a specific element
            element_identifier = self.interpret_path(path, proto_file)
            if element_identifier is not None:
                self.comments[f"{element_identifier}.leading"] = leading_comments
                self.comments[f"{element_identifier}.trailing"] = trailing_comments
                self.comments[f"{element_identifier}.detached"] = leading_detached_comments

    def get_comments(self, field: FieldDescriptorProto, proto_file: FileDescriptorProto, message: DescriptorProto):
        if hasattr(field, 'name'):
            name = getattr(field, 'name')
            commentspath = f"{proto_file.package}.{message.name}.{name}"
            if commentspath in self.comments:
                return commentspath, self.comments[commentspath]
        return None, None

    def get_nested_message(self, field: FieldDescriptorProto, proto_file: FileDescriptorProto):
        # Handle nested message types
        if field.type != FieldDescriptorProto.TYPE_MESSAGE:
            return None
        nested_message_name = field.type_name.split('.')[-1]
        # Look for the type in the current file first, then among already-processed elements.
        nested_message = next((m for m in proto_file.message_type if m.name == nested_message_name), None)
        if not nested_message:
            nested_message = next((m for m in self.elements if m.name == nested_message_name), None)
            if not nested_message:
                logger.error(f'Could not locate message class {field.type_name} ({nested_message_name})')
        return nested_message

    def process_message(self, message: ProtoElement, parent: ProtoElement = None) -> ProtoElement:
        if not message:
            return
        if not message.fields:
            logger.warning(f"{message.path} doesn't have fields!")
            return
        for field in message.fields:
            element = ProtoElement(
                parent=message,
                descriptor=field
            )
            logger.debug(f'Element: {element.path}')
            if getattr(field, "message_type", None):
                # Recurse into message-typed fields
                self.process_message(element, message)

    @property
    def packages(self) -> List[str]:
        return list(set([proto_file.package for proto_file in self.request.proto_file if proto_file.package]))

    @property
    def file_set(self) -> List[FileDescriptor]:
        return list(set([self.pool.FindMessageTypeByName(message).file for message in self.main_class_list if self.pool.FindMessageTypeByName(message)]))

    @property
    def proto_files(self) -> List[FileDescriptorProto]:
        return list(
            proto_file for proto_file in self.request.proto_file if
            not proto_file.name.startswith("google/")
            and not proto_file.name.startswith("nanopb")
            and not proto_file.package.startswith("google.protobuf")
        )

    def get_main_messages_from_file(self, fileDescriptor: FileDescriptor) -> List[Descriptor]:
        return [message for name, message in fileDescriptor.message_types_by_name.items() if message.full_name in self.main_class_list]

    def process(self) -> None:
        if len(self.proto_files) == 0:
            logger.error('No protocol buffer file selected for processing')
            return
        self.setup()
        logger.info(f'Processing message(s) {", ".join(self.main_class_list)}')
        for fileObj in self.file_set:
            self.start_file(fileObj)
            for message in self.get_main_messages_from_file(fileObj):
                element = ProtoElement(descriptor=message)
                self.start_message(element)
                self.process_message(element)
                self.end_message(element)
            self.end_file(fileObj)
        # The serialized CodeGeneratorResponse goes back to protoc on stdout.
        sys.stdout.buffer.write(self.response.SerializeToString())

    def setup(self):
        for proto_file in self.proto_files:
            logger.debug(f"Extracting comments from: {proto_file.name}")
            self.extract_comments(proto_file)
        self.pool = descriptor_pool.DescriptorPool()
        self.factory = message_factory.MessageFactory(self.pool)
        for proto_file in self.request.proto_file:
            logger.debug(f'Adding {proto_file.name} to pool')
            self.pool.Add(proto_file)
        self.messages = GetMessageClassesForFiles([f.name for f in self.request.proto_file], self.pool)
        ProtoElement.set_pool(self.pool)
        ProtoElement.set_render(self.render)
        ProtoElement.set_logger(logger)
        ProtoElement.set_comments_base(self.comments)
        ProtoElement.set_prototypes(self.messages)

    @property
    def main_messages(self) -> List[ProtoElement]:
        return [ele for ele in self.elements if ele.main_message]

    def get_message_descriptor(self, name) -> Descriptor:
        for package in self.packages:
            qualified_name = f'{package}.{name}' if package else name
            try:
                descriptor = self.pool.FindMessageTypeByName(qualified_name)
                if descriptor:
                    return descriptor
            except KeyError:
                pass
        return None

    @classmethod
    def get_data(cls):
        parser = argparse.ArgumentParser(description='Process protobuf and JSON files.')
        parser.add_argument('--source', help='File containing a serialized CodeGeneratorRequest (read from stdin when omitted)', default=None)
        args = parser.parse_args()
        if args.source:
            logger.info(f'Loading request data from {args.source}')
            with open(args.source, 'rb') as file:
                data = file.read()
        else:
            # Normal plugin mode: protoc streams the CodeGeneratorRequest on stdin.
            data = sys.stdin.buffer.read()
        return data
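
# Minimal usage sketch. `MyRenderer` is a hypothetical subclass shown for illustration only;
# this module defines just the abstract base, so a concrete generator is expected to implement
# the abstract hooks and drive the parser roughly like this:
#
#     class MyRenderer(ProtocParser):
#         def get_name(self) -> str: return 'myrenderer'
#         def render(self, element): ...
#         def start_file(self, file): ...
#         def end_file(self, file): ...
#         def start_message(self, classElement): ...
#         def end_message(self, classElement): ...
#
#     if __name__ == '__main__':
#         parser = MyRenderer(MyRenderer.get_data())   # read the CodeGeneratorRequest
#         parser.process()                             # write the CodeGeneratorResponse to stdout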