ProtocParser.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. #!/opt/esp/python_env/idf4.4_py3.8_env/bin/python
  2. from functools import partial
  3. import sys
  4. import json
  5. from typing import Callable, Dict, List
  6. import argparse
  7. from abc import ABC, abstractmethod
  8. import google.protobuf.descriptor_pool as descriptor_pool
  9. from google.protobuf import message_factory,message
  10. from google.protobuf.message_factory import GetMessageClassesForFiles
  11. from google.protobuf.compiler import plugin_pb2 as plugin
  12. from google.protobuf.descriptor import FieldDescriptor, Descriptor, FileDescriptor
  13. from google.protobuf.descriptor_pb2 import FileDescriptorProto, DescriptorProto, FieldDescriptorProto,FieldOptions
  14. from urllib import parse
  15. from ProtoElement import ProtoElement
  16. import logging
  17. logger = logging.getLogger(__name__)
  18. logging.basicConfig(
  19. level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  20. )
class ProtocParser(ABC) :
    """Base class for protoc plugins that walk selected messages of a CodeGeneratorRequest.

    Subclasses implement the abstract hooks (``render``, ``get_name``,
    ``start_file``/``end_file``, ``start_message``/``end_message``); ``process``
    drives them over the messages named by the ``main_class`` plugin parameter.
    """
    # Request parsed from protoc input and the response serialized to stdout by process().
    request:plugin.CodeGeneratorRequest
    response:plugin.CodeGeneratorResponse
    # NOTE(review): the mutable defaults below are class attributes; unless an
    # instance rebinds them, every instance (and subclass) shares the same object.
    elements:List[ProtoElement] = []
    # Comment text keyed by "<element identifier>.leading|.trailing|.detached".
    comments: Dict[str, str] = {}
    # One-based position tuples keyed by element identifier.
    positions={}
    json_content = {}
    # Fully-qualified message names requested through the "main_class" parameter.
    main_class_list:List[str] = []
    # Plugin parameters parsed from "key=value" pairs of request.parameter.
    param_dict:Dict[str,str] = {}
    # Descriptor pool and message factory, created in setup().
    pool:descriptor_pool.DescriptorPool
    factory:message_factory
    # Names of top-level messages collected in setup(); listed in error messages.
    message_type_names:set = set()
  33. @abstractmethod
  34. def render(self,element: ProtoElement):
  35. pass
  36. @abstractmethod
  37. def get_name(self)->str:
  38. pass
  39. @abstractmethod
  40. # def start_element(self,element:ProtoElement):
  41. # logger.debug(f'START Processing ELEMENT {element.path}')
  42. # @abstractmethod
  43. # def end_element(self,element:ProtoElement):
  44. # logger.debug(f'END Processing ELEMENT {element.path}')
  45. @abstractmethod
  46. def end_message(self,classElement:ProtoElement):
  47. logger.debug(f'END Processing MESSAGE {classElement.name}')
  48. @abstractmethod
  49. def start_message(self,classElement:ProtoElement) :
  50. logger.debug(f'START Processing MESSAGE {classElement.name}')
  51. @abstractmethod
  52. def start_file(self,file:FileDescriptor) :
  53. logger.debug(f'START Processing file {file.name}')
  54. @abstractmethod
  55. def end_file(self,file:FileDescriptor) :
  56. logger.debug(f'END Processing file {file.name}')
  57. def __init__(self,data):
  58. self.request = plugin.CodeGeneratorRequest.FromString(data)
  59. self.response = plugin.CodeGeneratorResponse()
  60. logger.debug(f'Received ${self.get_name()} parameter(s): {self.request.parameter}')
  61. params = self.request.parameter.split(',')
  62. self.param_dict = {p.split('=')[0]: parse.unquote(p.split('=')[1]) for p in params if '=' in p}
  63. if not 'const_prefix' in self.param_dict:
  64. self.param_dict['const_prefix'] = ""
  65. logger.warn("No option passed for const_prefix. No prefix will be used for option init_from_mac")
  66. self.main_class_list = self.get_arg(name= 'main_class',split=True,split_char='!')
  67. if 'path' in self.param_dict:
  68. self.param_dict['path'] = self.param_dict['path'].split('?')
  69. for p in self.param_dict['path']:
  70. logger.debug(f'Adding to path: {p}')
  71. sys.path.append(p)
  72. import customoptions_pb2 as custom__options__pb2
  73. def get_arg(self,name:str,default=None,split:bool=False,split_char:str=';'):
  74. result = self.param_dict.get(name, default)
  75. if result and len(result) == 0:
  76. if not default:
  77. logger.error(f'Plugin parameter {name} not found')
  78. result = None
  79. else:
  80. result = default
  81. logger.warn(f'Plugin parameter {name} not found. Defaulting to {str(default)}')
  82. if split and result:
  83. result = result.split(split_char)
  84. logger.debug(f'Returning argument {name}={str(result)}')
  85. return result
  86. def get_name_attr(self,proto_element):
  87. attributes = ['package','name']
  88. for att in attributes:
  89. if hasattr(proto_element, att):
  90. return att
  91. return None
  92. def interpret_path(self,path, proto_element):
  93. if not path:
  94. if hasattr(proto_element,"name"):
  95. return proto_element.name
  96. else:
  97. return ''
  98. # Get the next path element
  99. path_elem = path[0]
  100. name_att = self.get_name_attr(proto_element)
  101. if name_att:
  102. elem_name = getattr(proto_element, name_att)
  103. elem_sep = '.'
  104. else:
  105. elem_name = ''
  106. elem_sep = ''
  107. # Ensure the proto_element has a DESCRIPTOR attribute
  108. if hasattr(proto_element, 'DESCRIPTOR'):
  109. # Use the DESCRIPTOR to access field information
  110. descriptor = proto_element.DESCRIPTOR
  111. # Get the field name from the descriptor
  112. try:
  113. field = descriptor.fields_by_number[path_elem]
  114. except:
  115. return None
  116. field_name = field.name
  117. field_name = field_name.lower().replace('_field_number', '')
  118. # Access the field if it exists
  119. if field_name == "extension" :
  120. return field_name
  121. elif hasattr(proto_element, field_name):
  122. next_element = getattr(proto_element, field_name)
  123. if isinstance(next_element,list):
  124. # If the next element is a list, use the next path element as an index
  125. return f'{elem_name}{elem_sep}{self.interpret_path(path[1:], next_element[path[1]])}'
  126. else:
  127. # If it's not a list, just continue with the next path element
  128. return f'{elem_name}{elem_sep}{self.interpret_path(path[1:], next_element)}'
  129. else:
  130. return f'{elem_name}{elem_sep}{self.interpret_path(path[1:], proto_element[path_elem])}'
  131. # If the path cannot be interpreted, return None or raise an error
  132. return None
  133. def extract_comments(self,proto_file: FileDescriptorProto):
  134. for location in proto_file.source_code_info.location:
  135. # The path is a sequence of integers identifying the syntactic location
  136. path = tuple(location.path)
  137. leading_comments = location.leading_comments.strip()
  138. trailing_comments = location.trailing_comments.strip()
  139. if len(location.leading_detached_comments)>0:
  140. logger.debug('found detached comments')
  141. leading_detached_comments = '\r\n'.join(location.leading_detached_comments)
  142. if len(leading_comments) == 0 and len(trailing_comments) == 0 and len(leading_detached_comments) == 0:
  143. continue
  144. # Interpret the path and map it to a specific element
  145. # This is where you'll need to add logic based on your protobuf structure
  146. element_identifier = self.interpret_path(path, proto_file)
  147. if element_identifier is not None:
  148. self.comments[f"{element_identifier}.leading"] = leading_comments
  149. self.comments[f"{element_identifier}.trailing"] = trailing_comments
  150. self.comments[f"{element_identifier}.detached"] = leading_detached_comments
  151. def extract_positions(self, proto_file: FileDescriptorProto):
  152. for location in proto_file.source_code_info.location:
  153. # The path is a sequence of integers identifying the syntactic location
  154. path = tuple(location.path)
  155. # Interpret the path and map it to a specific element
  156. element_identifier = self.interpret_path(path, proto_file)
  157. if element_identifier is not None and not element_identifier.endswith('.'):
  158. # Extracting span information for position
  159. if len(location.span) >= 3: # Ensure span has at least start line, start column, and end line
  160. start_line, start_column, end_line = location.span[:3]
  161. # Adjusting for 1-indexing and storing the position
  162. self.positions[element_identifier] = (start_line + 1, start_column + 1, end_line + 1)
  163. def get_comments(self,field: FieldDescriptorProto, proto_file: FileDescriptorProto,message: DescriptorProto):
  164. if hasattr(field,'name') :
  165. name = getattr(field,'name')
  166. commentspath = f"{proto_file.package}.{message.name}.{name}"
  167. if commentspath in self.comments:
  168. return commentspath,self.comments[commentspath]
  169. return None,None
  170. def get_nested_message(self, field: FieldDescriptorProto, proto_file: FileDescriptorProto):
  171. # Handle nested message types
  172. if field.type != FieldDescriptorProto.TYPE_MESSAGE:
  173. return None
  174. nested_message_name = field.type_name.split('.')[-1]
  175. # logger.debug(f'Looking for {field.type_name} ({nested_message_name}) in {nested_list}')
  176. nested_message= next((m for m in proto_file.message_type if m.name == nested_message_name), None)
  177. if not nested_message:
  178. # logger.debug(f'Type {nested_message_name} was not found in file {proto_file.name}. Checking in processed list: {processed_list}')
  179. nested_message = next((m for m in self.elements if m.name == nested_message_name), None)
  180. if not nested_message:
  181. logger.error(f'Could not locate message class {field.type_name} ({nested_message_name})')
  182. return nested_message
  183. def process_message(self,message: ProtoElement, parent:ProtoElement = None )->ProtoElement:
  184. if not message:
  185. return
  186. if not message.fields:
  187. logger.warn(f"{message.path} doesn't have fields!")
  188. return
  189. for field in message.fields:
  190. element = ProtoElement(
  191. parent=message,
  192. descriptor=field
  193. )
  194. logging.debug(f'Element: {element.path}')
  195. if getattr(field,"message_type",None):
  196. self.process_message(element,message)
  197. @property
  198. def packages(self)->List[str]:
  199. return list(set([proto_file.package for proto_file in self.request.proto_file if proto_file.package]))
  200. @property
  201. def file_set(self)->List[FileDescriptor]:
  202. file_set = []
  203. missing_messages = []
  204. for message in self.main_class_list:
  205. try:
  206. message_descriptor = self.pool.FindMessageTypeByName(message)
  207. if message_descriptor:
  208. file_set.append(message_descriptor.file)
  209. else:
  210. missing_messages.append(message)
  211. except Exception as e:
  212. missing_messages.append(message)
  213. if missing_messages:
  214. sortedstring="\n".join(sorted(self.message_type_names))
  215. logger.error(f'Error retrieving message definitions for: {", ".join(missing_messages)}. Valid messages are: \n{sortedstring}')
  216. raise Exception(f"Invalid message(s) {missing_messages}")
  217. # Deduplicate file descriptors
  218. unique_file_set = list(set(file_set))
  219. return unique_file_set
  220. @property
  221. def proto_files(self)->List[FileDescriptorProto]:
  222. return list(
  223. proto_file for proto_file in self.request.proto_file if
  224. not proto_file.name.startswith("google/")
  225. and not proto_file.name.startswith("nanopb")
  226. and not proto_file.package.startswith("google.protobuf")
  227. )
  228. def get_main_messages_from_file(self,fileDescriptor:FileDescriptor)->List[Descriptor]:
  229. return [message for name,message in fileDescriptor.message_types_by_name.items() if message.full_name in self.main_class_list]
  230. def process(self) -> None:
  231. if len(self.proto_files) == 0:
  232. logger.error('No protocol buffer file selected for processing')
  233. return
  234. self.setup()
  235. logger.info(f'Processing message(s) {", ".join([name for name in self.main_class_list ])}')
  236. try:
  237. for fileObj in self.file_set :
  238. self.start_file(fileObj)
  239. for message in self.get_main_messages_from_file(fileObj):
  240. element = ProtoElement( descriptor=message )
  241. self.start_message(element)
  242. self.process_message(element)
  243. self.end_message(element)
  244. self.end_file(fileObj)
  245. sys.stdout.buffer.write(self.response.SerializeToString())
  246. except Exception as e:
  247. # Log the error and exit gracefully
  248. error_message = str(e)
  249. logger.error(f'Failed to process protocol buffer files: {error_message}')
  250. sys.stderr.write(error_message + '\n')
  251. sys.exit(1) # Exit with a non-zero status code to indicate failure
  252. def setup(self):
  253. for proto_file in self.proto_files:
  254. logger.debug(f"Extracting comments from : {proto_file.name}")
  255. self.extract_positions(proto_file)
  256. self.extract_comments(proto_file)
  257. self.pool = descriptor_pool.DescriptorPool()
  258. self.factory = message_factory.MessageFactory(self.pool)
  259. for proto_file in self.request.proto_file:
  260. logger.debug(f'Adding {proto_file.name} to pool')
  261. self.pool.Add(proto_file)
  262. # Iterate over all message types in the proto file and add them to the list
  263. for message_type in proto_file.message_type:
  264. # Assuming proto_file.message_type gives you message descriptors or similar
  265. # You may need to adjust based on how proto_file is structured
  266. self.message_type_names.add(f"{proto_file.package}.{message_type.name}")
  267. self.messages = GetMessageClassesForFiles([f.name for f in self.request.proto_file], self.pool)
  268. ProtoElement.set_pool(self.pool)
  269. ProtoElement.set_render(self.render)
  270. ProtoElement.set_logger(logger)
  271. ProtoElement.set_comments_base(self.comments)
  272. ProtoElement.set_positions_base(self.positions)
  273. ProtoElement.set_prototypes(self.messages)
  274. @property
  275. def main_messages(self)->List[ProtoElement]:
  276. return [ele for ele in self.elements if ele.main_message ]
  277. def get_message_descriptor(self, name) -> Descriptor:
  278. for package in self.packages:
  279. qualified_name = f'{package}.{name}' if package else name
  280. try:
  281. descriptor = self.pool.FindMessageTypeByName(qualified_name)
  282. if descriptor:
  283. return descriptor
  284. except:
  285. pass
  286. return None
  287. @classmethod
  288. def get_data(cls):
  289. parser = argparse.ArgumentParser(description='Process protobuf and JSON files.')
  290. parser.add_argument('--source', help='Python source file', default=None)
  291. args = parser.parse_args()
  292. if args.source:
  293. logger.info(f'Loading request data from {args.source}')
  294. with open(args.source, 'rb') as file:
  295. data = file.read()
  296. else:
  297. data = sys.stdin.buffer.read()
  298. return data