protoc-gen-defaults.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. #!/opt/esp/python_env/idf4.4_py3.8_env/bin/python
  2. import os
  3. import logging
  4. import json
  5. from pathlib import Path
  6. from typing import Dict, List
  7. from google.protobuf.compiler import plugin_pb2 as plugin
  8. from google.protobuf.message_factory import GetMessageClass
  9. from google.protobuf.descriptor_pb2 import FileDescriptorProto, DescriptorProto, FieldDescriptorProto,FieldOptions
  10. from google.protobuf.descriptor import FieldDescriptor, Descriptor, FileDescriptor
  11. from ProtoElement import ProtoElement
  12. from ProtocParser import ProtocParser
  13. from google.protobuf.json_format import Parse
  14. logger = logging.getLogger(__name__)
  15. logging.basicConfig(
  16. level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  17. )
  18. def is_iterable(obj):
  19. try:
  20. iter(obj)
  21. return True
  22. except TypeError:
  23. return False
  24. class BinDefaultsParser(ProtocParser) :
  25. def start_message(self,message:ProtoElement) :
  26. super().start_message(message)
  27. def end_message(self,message:ProtoElement):
  28. super().end_message(message)
  29. self.has_error = False
  30. default_structure = message.render()
  31. if not default_structure:
  32. logger.warning(f'No default values for {message.name}')
  33. return
  34. respfile = self.response.file.add()
  35. outfilename = f'{message.name.lower()}.bin'
  36. output_directory = os.path.join(self.param_dict.get('defaultspath', '.'),"defaults")
  37. output_path = os.path.join(output_directory, outfilename)
  38. os.makedirs(output_directory, exist_ok=True)
  39. with open(output_path, 'wb') as bin_file:
  40. res = default_structure.SerializeToString()
  41. bin_file.write(res)
  42. logger.info(f'Wrote {bin_file.name}')
  43. respfile.name = f'{outfilename}.gen'
  44. logger.info(f"Creating binary file for defaults: {respfile.name}")
  45. respfile.content = f'Content written to {respfile.name}'
  46. def start_file(self,file:FileDescriptor) :
  47. super().start_file(file)
  48. def end_file(self,file:ProtoElement) :
  49. super().end_file(file)
  50. def get_name(self)->str:
  51. return 'protoc_plugin_defaults'
  52. def add_comment_if_exists(element, comment_type: str, path: str) -> dict:
  53. comment = getattr(element, f"{comment_type}_comment", "").strip()
  54. return {f"__{comment_type}_{path}": comment} if comment else {}
  55. def repeated_render(self,element:ProtoElement,obj:any):
  56. return [obj] if element.repeated else obj
  57. def render(self,element: ProtoElement):
  58. options = element.options['cust_field'] if 'cust_field' in element.options else None
  59. if len(element.childs)>0:
  60. oneof = getattr(element.descriptor,'containing_oneof',None)
  61. if oneof:
  62. # we probably shouldn't set default values here
  63. pass
  64. has_render = False
  65. for child in element.childs:
  66. try:
  67. rendered = child.render()
  68. if rendered:
  69. has_render = True
  70. # try:
  71. if child.type == FieldDescriptor.TYPE_MESSAGE:
  72. target_field = getattr(element.message_instance, child.name)
  73. if child.label == FieldDescriptor.LABEL_REPEATED:
  74. # If the field is repeated, iterate over the array and add each instance
  75. if is_iterable(rendered) and not isinstance(rendered, str):
  76. for instance in rendered:
  77. target_field.add().CopyFrom(instance)
  78. else:
  79. target_field.add().CopyFrom(rendered)
  80. else:
  81. # For non-repeated fields, use CopyFrom
  82. target_field.CopyFrom(rendered)
  83. elif child.repeated:
  84. try:
  85. getattr(element.message_instance,child.name).extend(rendered)
  86. except:
  87. getattr(element.message_instance,child.name).extend( [rendered])
  88. else:
  89. setattr(element.message_instance,child.name,rendered)
  90. # except:
  91. # logger.error(f'Unable to assign value from {child.fullname} to {element.fullname}')
  92. element.message_instance.SetInParent()
  93. except Exception as e:
  94. logger.error(f'{child.proto_file_line} Rendering default values failed for {child.name} of {child.path} in file {child.file.name}: {e}')
  95. raise e
  96. if getattr(options, 'v_msg', None):
  97. has_render = True
  98. v_msg = getattr(options, 'v_msg', None)
  99. try:
  100. if element.repeated:
  101. # Create a list to hold the message instances
  102. message_instances = []
  103. # Parse each element of the JSON array
  104. for json_element in json.loads(v_msg):
  105. new_instance = element.new_message_instance
  106. Parse(json.dumps(json_element), new_instance)
  107. message_instances.append(new_instance)
  108. element.message_instance.SetInParent()
  109. return message_instances
  110. # Copy each instance to the appropriate field in the parent message
  111. # repeated_field = getattr(element.message_instance, child.name)
  112. # for instance in message_instances:
  113. # repeated_field.add().CopyFrom(instance)
  114. else:
  115. # If the field is not repeated, parse the JSON string directly
  116. Parse(v_msg, element.message_instance)
  117. element.message_instance.SetInParent()
  118. except Exception as e:
  119. # Handle parsing errors, e.g., log them
  120. logger.error(f"{element.proto_file_line} Error parsing json default value {v_msg} as JSON. {e}")
  121. raise e
  122. if not has_render:
  123. return None
  124. else:
  125. default_value = None
  126. msg_options = element.options['cust_msg'] if 'cust_msg' in element.options else None
  127. init_from_mac = getattr(options,'init_from_mac', False) or getattr(msg_options,'init_from_mac', False)
  128. global_name = getattr(options,'global_name', None)
  129. const_prefix = getattr(options,'const_prefix', self.param_dict['const_prefix'])
  130. if init_from_mac:
  131. default_value = f'{const_prefix}@@init_from_mac@@'
  132. elif element.descriptor.cpp_type == FieldDescriptor.CPPTYPE_STRING:
  133. default_value = default_value = getattr(options,'v_string', None)
  134. elif element.descriptor.cpp_type == FieldDescriptor.CPPTYPE_ENUM:
  135. if options is not None:
  136. try:
  137. enum_value = getattr(options,'v_enum', None) or getattr(options,'v_string', None)
  138. if enum_value is not None:
  139. default_value = element.enum_values.index(enum_value)
  140. except:
  141. raise ValueError(f'Invalid default value {default_value} for {element.path}')
  142. # Handling integer types
  143. elif element.descriptor.cpp_type in [FieldDescriptor.CPPTYPE_INT32, FieldDescriptor.CPPTYPE_INT64,
  144. FieldDescriptor.CPPTYPE_UINT32, FieldDescriptor.CPPTYPE_UINT64]:
  145. if element.descriptor.cpp_type in [FieldDescriptor.CPPTYPE_INT32, FieldDescriptor.CPPTYPE_INT64]:
  146. default_value = getattr(options, 'v_int32', getattr(options, 'v_int64', None))
  147. else:
  148. default_value = getattr(options, 'v_uint32', getattr(options, 'v_uint64', None))
  149. if default_value is not None:
  150. int_value= int(default_value)
  151. if element.descriptor.cpp_type in [FieldDescriptor.CPPTYPE_UINT32, FieldDescriptor.CPPTYPE_UINT64] and int_value < 0:
  152. raise ValueError(f"Negative value for unsigned int type trying to assign {element.path} = {default_value}")
  153. default_value = int_value
  154. # Handling float and double types
  155. elif element.descriptor.cpp_type in [FieldDescriptor.CPPTYPE_DOUBLE, FieldDescriptor.CPPTYPE_FLOAT]:
  156. default_value = getattr(options, 'v_double', getattr(options, 'v_float', None))
  157. if default_value is not None:
  158. float_value = float(default_value)
  159. if '.' not in str(default_value):
  160. raise ValueError(f"Integer string for float/double type trying to assign {element.path} = {default_value}")
  161. default_value = float_value
  162. # Handling boolean type
  163. elif element.descriptor.cpp_type == FieldDescriptor.CPPTYPE_BOOL:
  164. if options is not None:
  165. default_value = getattr(options, 'v_bool', False)
  166. if isinstance(default_value, str):
  167. if default_value.lower() in ['true', 'false']:
  168. default_value = default_value.lower() == 'true'
  169. else:
  170. raise ValueError(f'Invalid boolean value trying to assign {element.path} = {default_value}')
  171. # Handling bytes type
  172. elif element.descriptor.cpp_type == FieldDescriptor.CPPTYPE_BYTES:
  173. default_value = getattr(options, 'v_bytes', b'')
  174. elif element.descriptor.cpp_type == FieldDescriptor.TYPE_MESSAGE:
  175. pass
  176. if default_value is not None:
  177. element.message_instance.SetInParent()
  178. return self.repeated_render(element, default_value)
  179. else:
  180. return None
  181. return element.message_instance
  182. if __name__ == '__main__':
  183. data = ProtocParser.get_data()
  184. logger.info(f"Generating binary files for defaults")
  185. protocParser:BinDefaultsParser = BinDefaultsParser(data)
  186. protocParser.process()
  187. logger.debug('Done generating JSON file(s)')