Coverage for pyodmongo/engines/utils.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-03-27 14:31 +0000

1from pydantic import BaseModel 

2from bson import ObjectId 

3from ..models.id_model import Id 

4from ..models.db_model import MainBaseModel 

5from ..models.db_field_info import DbField 

6from ..models.db_decimal import DbDecimal 

7from ..services.reference_pipeline import resolve_reference_pipeline 

8from ..services.verify_subclasses import is_subclass 

9from decimal import Decimal 

10from bson import Decimal128 

11 

12 

def _as_object_id(ref):
    """
    Convert a single reference into an ObjectId.

    ``ref`` may be a model instance, in which case its ``id`` attribute is
    converted, or a raw id value accepted by ``ObjectId``, which is converted
    directly. The check is done per element so lists mixing both forms work.

    NOTE(review): bson's ``ObjectId(None)`` generates a brand-new id, so a model
    whose ``id`` is still None is stored as a fresh, unrelated ObjectId instead
    of raising. Confirm callers always save referenced documents first.
    """
    try:
        ref_id = ref.id
    except AttributeError:
        return ObjectId(ref)
    return ObjectId(ref_id)


def consolidate_dict(obj: MainBaseModel, dct: dict, populate: bool):
    """
    Recursively consolidate the fields of a PyODMongo model into a dictionary.

    Each field's value is written under the alias given by its ``DbField``
    descriptor on the model class:

    * Model-typed fields are recursed into (embedded documents). If the field is
      a reference and ``populate`` is False, they are collapsed to ObjectIds.
    * ``Id`` fields are converted to ObjectId when valid.
    * ``Decimal`` / ``DbDecimal`` fields are converted to BSON ``Decimal128``.
    * Every other value (including None) is copied as-is.

    Args:
        obj (MainBaseModel): The PyODMongo model instance to read from.
        dct (dict): Target dictionary; it is modified in place.
        populate (bool): When True, reference fields are written as full embedded
            documents instead of ObjectIds.

    Returns:
        dict: The same ``dct`` object, now holding the consolidated data.

    Raises:
        TypeError: If a field has no ``DbField`` descriptor on the class and the
            class is a Pydantic ``BaseModel`` subclass (for example, it derives
            from ``BaseModel`` instead of ``MainBaseModel``).
        AttributeError: If a field has no ``DbField`` descriptor and the class is
            not a ``BaseModel`` subclass.
    """
    model_cls = obj.__class__
    # Read model_fields from the class: instance access is deprecated in Pydantic 2.11.
    for field in model_cls.model_fields:
        value = getattr(obj, field)
        try:
            db_field_info: DbField = getattr(model_cls, field)
        except AttributeError as err:
            if is_subclass(class_to_verify=model_cls, subclass=BaseModel):
                raise TypeError(
                    f"The {obj.__class__.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                ) from err
            # Never continue without a descriptor. Doing so would reuse the
            # previous field's DbField (wrong alias) or leave it unbound.
            raise

        alias = db_field_info.field_alias
        is_list = db_field_info.is_list
        # References are only collapsed to ObjectIds when not populating.
        by_reference = db_field_info.by_reference and not populate

        if db_field_info.has_model_fields:
            if value is None:
                dct[alias] = None
            elif by_reference:
                if is_list:
                    dct[alias] = [_as_object_id(ref) for ref in value]
                else:
                    dct[alias] = _as_object_id(value)
            elif is_list:
                dct[alias] = [
                    consolidate_dict(obj=item, dct={}, populate=populate)
                    for item in value
                ]
            else:
                dct[alias] = consolidate_dict(obj=value, dct={}, populate=populate)
            continue

        field_type = db_field_info.field_type
        if field_type == Id and value is not None:
            if is_list:
                # NOTE: invalid ids are dropped from lists, whereas an invalid
                # scalar id is stored unchanged (existing behavior).
                dct[alias] = [ObjectId(o) for o in value if ObjectId.is_valid(o)]
            else:
                dct[alias] = ObjectId(value) if ObjectId.is_valid(value) else value
        elif (field_type == Decimal or field_type == DbDecimal) and value is not None:
            # Converted via str() into BSON Decimal128.
            if is_list:
                dct[alias] = [Decimal128(str(o)) for o in value]
            else:
                dct[alias] = Decimal128(str(value))
        else:
            dct[alias] = value
    return dct

90 

91 

92def _skip_and_limit_stages(current_page: int, docs_per_page: int): 

93 """ 

94 Add pagination stages to the aggregation pipeline. 

95 

96 Args: 

97 current_page (int): The current page number. 

98 docs_per_page (int): The number of documents per page. 

99 """ 

100 max_docs_per_page = 1000 

101 current_page = 1 if current_page <= 0 else current_page 

102 docs_per_page = ( 

103 max_docs_per_page if docs_per_page > max_docs_per_page else docs_per_page 

104 ) 

105 skip = (docs_per_page * current_page) - docs_per_page 

106 skip_stage = [{"$skip": skip}] 

107 limit_stage = [{"$limit": docs_per_page}] 

108 return skip_stage, limit_stage 

109 

110 

def mount_base_pipeline(
    Model,
    query: dict,
    sort: dict,
    populate: bool,
    pipeline: list | None,
    populate_db_fields: list[DbField] | None,
    paginate: int,
    current_page: int,
    docs_per_page: int,
    is_find_one: bool,
):
    """
    Build the MongoDB aggregation pipeline for a find operation.

    Args:
        Model: The model class being queried. Its ``_pipeline`` attribute is read,
            and it is passed to ``resolve_reference_pipeline`` when populating.
        query (dict): Filter placed in the leading ``$match`` stage.
        sort (dict): Sort specification. An empty (or otherwise falsy) value adds
            no ``$sort`` stage.
        populate (bool): Whether to append the reference-resolution stages
            returned by ``resolve_reference_pipeline``.
        pipeline (list | None): Caller-supplied stages. When non-empty they are
            used instead of the default stages (see Note).
        populate_db_fields (list[DbField] | None): Forwarded unchanged to
            ``resolve_reference_pipeline``.
        paginate (int): Truthy to add ``$skip``/``$limit`` stages computed from
            ``current_page`` and ``docs_per_page``.
        current_page (int): Page number used when paginating.
        docs_per_page (int): Page size used when paginating.
        is_find_one (bool): When not paginating, adds ``{"$limit": 1}``.

    Returns:
        list: The assembled pipeline stages.

    Note:
        A non-empty ``pipeline`` (checked first) or a non-empty ``Model._pipeline``
        returns early with ``$match`` + those stages + ``$sort``. Pagination, the
        find-one ``$limit`` and reference population are NOT applied in that case.
    """
    match_stage = [{"$match": query}]
    # "$sort" needs at least one key, so an empty or missing spec adds no stage.
    sort_stage = [{"$sort": sort}] if sort else []
    model_stage = Model._pipeline

    if pipeline:
        return match_stage + pipeline + sort_stage
    if model_stage:
        return match_stage + model_stage + sort_stage

    if paginate:
        skip_stage, limit_stage = _skip_and_limit_stages(
            current_page=current_page, docs_per_page=docs_per_page
        )
    else:
        skip_stage = []
        limit_stage = [{"$limit": 1}] if is_find_one else []

    # Sort must precede skip/limit so that pages are taken from the sorted order.
    stages = match_stage + sort_stage + skip_stage + limit_stage
    if populate:
        stages += resolve_reference_pipeline(
            cls=Model, pipeline=[], populate_db_fields=populate_db_fields
        )
    return stages