Coverage for pyodmongo/engines/utils.py: 100%

76 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 15:48 +0000

1from pydantic import BaseModel 

2from bson import ObjectId 

3from ..models.id_model import Id 

4from ..models.db_model import MainBaseModel 

5from ..models.db_field_info import DbField 

6from ..models.db_decimal import DbDecimal 

7from ..services.reference_pipeline import resolve_reference_pipeline 

8from ..services.verify_subclasses import is_subclass 

9from decimal import Decimal 

10from bson import Decimal128 

11 

12 

def _as_reference_id(value):
    """
    Convert one by-reference value into the ObjectId stored in MongoDB.

    Args:
        value: Either a model instance exposing an ``id`` attribute or a raw
            identifier accepted by ``bson.ObjectId`` (e.g. a str or ObjectId).

    Returns:
        ObjectId: The identifier to store for the reference.
    """
    # Only the attribute lookup is guarded, so a failure inside ObjectId()
    # is not mistaken for "value is a raw id".
    try:
        raw_id = value.id
    except AttributeError:
        raw_id = value
    # NOTE(review): bson's ObjectId(None) generates a brand-new id, so a
    # referenced model whose ``id`` is still None (not saved yet) produces a
    # dangling reference here; behaviour kept as-is, but it may deserve a guard.
    return ObjectId(raw_id)


def consolidate_dict(obj: MainBaseModel, dct: dict, populate: bool):
    """
    Recursively consolidate the fields of a PyODMongo model into a dictionary.

    Nested models, lists of models, references, ``Id`` fields and decimal
    fields are converted according to each field's ``DbField`` metadata so the
    result can be stored in MongoDB.

    Args:
        obj (MainBaseModel): The PyODMongo model instance to read data from.
        dct (dict): The dictionary to fill. It is modified in place; each
            field is written under its ``field_alias``.
        populate (bool): If True, by-reference fields are consolidated as full
            sub-documents instead of being reduced to ObjectIds.

    Returns:
        dict: ``dct`` itself, after consolidation.

    Raises:
        TypeError: If a field has no ``DbField`` attribute on the class and the
            class is a Pydantic ``BaseModel`` subclass (the message suggests
            switching to PyODMongo's ``MainBaseModel``).
        AttributeError: If a field has no ``DbField`` attribute and the class
            is not recognised as a ``BaseModel`` subclass.

    Conversion rules:
        - Model fields: ``None`` is kept; by-reference values become
          ``ObjectId`` (from the model's ``id`` or from a raw id); otherwise
          the model (or each model in a list) is consolidated recursively.
        - ``Id`` fields: values accepted by ``ObjectId.is_valid`` become
          ``ObjectId``; other values are stored unchanged (scalars and list
          elements alike).
        - ``Decimal`` / ``DbDecimal`` fields: converted to ``Decimal128``;
          ``None`` list elements are kept as ``None``.
        - Everything else is stored as-is.
    """
    model_cls = obj.__class__
    # Only field names are needed; the DbField metadata for each field is an
    # attribute of the class and is read with getattr below.
    for field in model_cls.model_fields:
        value = getattr(obj, field)
        try:
            db_field_info: DbField = getattr(model_cls, field)
        except AttributeError:
            if is_subclass(class_to_verify=model_cls, subclass=BaseModel):
                raise TypeError(
                    f"The {model_cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                )
            # Re-raise rather than continuing with an unbound (first field)
            # or stale (previous field's) db_field_info.
            raise
        alias = db_field_info.field_alias
        is_list = db_field_info.is_list
        # When populating, references are consolidated as full documents
        # instead of being reduced to their ObjectId.
        by_reference = db_field_info.by_reference and not populate

        if db_field_info.has_model_fields:
            if value is None:
                dct[alias] = None
            elif by_reference:
                # Converted element by element, so a list mixing model
                # instances and raw ids is handled correctly.
                if is_list:
                    dct[alias] = [_as_reference_id(item) for item in value]
                else:
                    dct[alias] = _as_reference_id(value)
            elif is_list:
                dct[alias] = [
                    consolidate_dict(obj=item, dct={}, populate=populate)
                    for item in value
                ]
            else:
                dct[alias] = consolidate_dict(obj=value, dct={}, populate=populate)
            continue

        field_type = db_field_info.field_type
        if value is not None and field_type == Id:
            # Valid ids become ObjectId; anything else is stored unchanged,
            # for list elements as well as for scalar values.
            if is_list:
                dct[alias] = [
                    ObjectId(item) if ObjectId.is_valid(item) else item
                    for item in value
                ]
            else:
                dct[alias] = ObjectId(value) if ObjectId.is_valid(value) else value
        elif value is not None and field_type in (Decimal, DbDecimal):
            # Decimal128 is constructed from the value's string form.
            if is_list:
                dct[alias] = [
                    None if item is None else Decimal128(str(item))
                    for item in value
                ]
            else:
                dct[alias] = Decimal128(str(value))
        else:
            dct[alias] = value
    return dct

90 

91 

def _skip_and_limit_stages(
    current_page: int, docs_per_page: int, *, max_docs_per_page: int = 1000
):
    """
    Build the ``$skip`` and ``$limit`` stages used to paginate an aggregation.

    Args:
        current_page (int): 1-based page number. Values below 1 are treated
            as page 1.
        docs_per_page (int): Requested page size. It is clamped to the range
            ``[1, max_docs_per_page]`` because MongoDB rejects a negative
            ``$skip`` and a non-positive ``$limit``.
        max_docs_per_page (int): Upper bound on the page size (default 1000).

    Returns:
        tuple[list[dict], list[dict]]: ``([{"$skip": n}], [{"$limit": m}])``,
        each a one-stage list ready to be concatenated into a pipeline.
    """
    current_page = max(current_page, 1)
    # Clamp the upper bound first, then the lower bound, so the limit is
    # always at least 1 even if max_docs_per_page is misconfigured.
    docs_per_page = max(1, min(docs_per_page, max_docs_per_page))
    skip = docs_per_page * (current_page - 1)
    return [{"$skip": skip}], [{"$limit": docs_per_page}]

109 

110 

def mount_base_pipeline(
    Model,
    query: dict,
    sort: dict,
    populate: bool,
    pipeline: list | None,
    populate_db_fields: list[DbField] | None,
    paginate: int,
    current_page: int,
    docs_per_page: int,
    no_paginate_limit: int | None,
):
    """
    Mount a base MongoDB aggregation pipeline for find operations.

    The result is always ``$match`` first, then a middle section, then the
    sort / pagination tail (``$sort``, ``$skip``, ``$limit``, each only when
    applicable). The middle section is chosen as follows:

    - If ``pipeline`` is non-empty, or otherwise ``Model._pipeline`` is
      non-empty, that pipeline is used and ``populate`` /
      ``populate_db_fields`` are ignored.
    - Else, if ``populate`` is truthy, the reference-population stages from
      ``resolve_reference_pipeline`` are used.
    - Else, the middle section is empty.

    Args:
        Model: The PyODMongo model class for the query.
        query: The MongoDB query dictionary for the ``$match`` stage.
        sort: The sort specification for the ``$sort`` stage. An empty (or
            ``None``) value adds no ``$sort`` stage.
        populate: If True, adds stages to populate referenced documents
            (only when no custom pipeline is in effect).
        pipeline: An optional list of aggregation stages inserted after
            ``$match``. If empty or None, the model's ``_pipeline`` is used.
        populate_db_fields: A list of specific ``DbField``s to populate.
            Used only when the populate stages are added.
        paginate: If truthy, adds ``$skip`` and ``$limit`` stages computed
            from ``current_page`` and ``docs_per_page``; ``no_paginate_limit``
            is then ignored.
        current_page: The page number to retrieve when pagination is enabled.
        docs_per_page: The number of documents per page for pagination.
        no_paginate_limit: The maximum number of documents to return when
            pagination is disabled. A falsy value adds no ``$limit`` stage.

    Returns:
        list[dict]: The complete MongoDB aggregation pipeline.
    """
    match_stage = [{"$match": query}]
    # Truthiness also covers sort=None, which would otherwise produce an
    # invalid {"$sort": None} stage.
    sort_stage = [{"$sort": sort}] if sort else []
    if paginate:
        skip_stage, limit_stage = _skip_and_limit_stages(
            current_page=current_page, docs_per_page=docs_per_page
        )
    else:
        skip_stage = []
        limit_stage = [{"$limit": no_paginate_limit}] if no_paginate_limit else []
    tail = sort_stage + skip_stage + limit_stage

    # List concatenation builds a new list, so Model._pipeline is not mutated.
    custom_pipeline = pipeline or Model._pipeline
    if custom_pipeline:
        return match_stage + custom_pipeline + tail
    if populate:
        reference_stage = resolve_reference_pipeline(
            cls=Model, pipeline=[], populate_db_fields=populate_db_fields
        )
        return match_stage + reference_stage + tail
    return match_stage + tail