Coverage for pyodmongo/services/model_init.py: 100%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 15:08 +0000

1from pydantic import BaseModel 

2from pymongo import IndexModel, ASCENDING, TEXT 

3from typing import Any, Union, get_origin, get_args 

4from types import UnionType 

5from ..models.id_model import Id 

6from ..models.db_field_info import DbField 

7from ..services.verify_subclasses import is_subclass 

8 

9 

def __ordinary_index_and_text_keys(
    cls: BaseModel,
    indexes: list,
    text_keys: list,
    indexes_path: list,
    text_indexes_path: list,
):
    """
    Recursively collect ordinary indexes and text-index keys for a model class.

    Args:
        cls (BaseModel): Model class whose fields are inspected. Each field is
            read back from the class with ``getattr`` and used as a ``DbField``.
        indexes (list): Accumulator that receives ``IndexModel`` objects.
        text_keys (list): Accumulator that receives ``(dotted_path, TEXT)``
            tuples.
        indexes_path (list): Aliases of the enclosing embedded models, joined
            with "." to build index names. Restored to its input state on a
            normal return.
        text_indexes_path (list): Same as ``indexes_path``, for text keys.

    Returns:
        tuple[list, list]: The ``indexes`` and ``text_keys`` accumulators.

    Raises:
        TypeError: If a field cannot be read from ``cls`` and ``cls`` is a
            Pydantic ``BaseModel`` subclass rather than a PyODMongo model.
        AttributeError: If a field cannot be read from ``cls`` otherwise.
    """
    for key, model_field in cls.model_fields.items():
        try:
            db_field_info: DbField = getattr(cls, key)
        except AttributeError as err:
            if is_subclass(class_to_verify=cls, subclass=BaseModel):
                raise TypeError(
                    f"The {cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                ) from err
            # Without re-raising, ``db_field_info`` below would be unbound or
            # stale from the previous iteration.
            raise

        # Embedded (non-reference, non-list) sub-model: recurse with this
        # field's alias pushed so nested keys get dotted names.
        if (
            db_field_info.has_model_fields
            and not db_field_info.by_reference
            and not db_field_info.is_list
        ):
            indexes_path.append(db_field_info.field_alias)
            text_indexes_path.append(db_field_info.field_alias)
            __ordinary_index_and_text_keys(
                cls=db_field_info.field_type,
                indexes=indexes,
                text_keys=text_keys,
                indexes_path=indexes_path,
                text_indexes_path=text_indexes_path,
            )
            indexes_path.pop()
            text_indexes_path.pop()

        extra = model_field.json_schema_extra
        # Pydantic also allows a callable here; only a non-empty dict can carry
        # the index / unique / text_index flags.
        if not isinstance(extra, dict) or not extra:
            continue
        is_index = extra.get("index") or False
        is_unique = extra.get("unique") or False
        is_text_index = extra.get("text_index") or False

        if is_index:
            indexes_path.append(db_field_info.field_alias)
            index_name = ".".join(indexes_path)
            # TODO: this special case needs a more efficient and robust
            # solution. A unique flag on a nested field (path depth > 1) whose
            # declaring class is itself a DbModel subclass is downgraded to a
            # non-unique index.
            if is_unique and len(indexes_path) > 1:
                # Local import, presumably to avoid a circular import with
                # db_model — NOTE(review): confirm.
                from ..models.db_model import DbModel

                if is_subclass(class_to_verify=cls, subclass=DbModel):
                    is_unique = False
            indexes.append(
                IndexModel([(index_name, ASCENDING)], name=index_name, unique=is_unique)
            )
            indexes_path.pop()

        if is_text_index:
            text_indexes_path.append(db_field_info.field_alias)
            text_keys.append((".".join(text_indexes_path), TEXT))
            text_indexes_path.pop()
    return indexes, text_keys

73 

74 

def resolve_indexes(cls: BaseModel):
    """
    Build the MongoDB index models declared on the fields of a model class.

    Args:
        cls (BaseModel): The model class whose field-level index flags are read.

    Returns:
        list[IndexModel]: One ascending ``IndexModel`` per field flagged with
        ``index`` (unique when flagged ``unique``, except for the nested-DbModel
        case handled by the collector). If any field is flagged ``text_index``,
        a single combined text ``IndexModel`` named "texts" is appended last.

    Description:
        Field flags are gathered recursively through embedded sub-models. The
        text index receives ``default_language`` only when
        ``cls._default_language`` is truthy.
    """
    indexes, text_keys = __ordinary_index_and_text_keys(
        cls=cls,
        indexes=[],
        text_keys=[],
        indexes_path=[],
        text_indexes_path=[],
    )

    if text_keys:
        text_index_options = {"name": "texts"}
        language = cls._default_language
        if language:
            text_index_options["default_language"] = language
        indexes.append(IndexModel(text_keys, **text_index_options))

    return indexes

111 

112 

113def _is_union(field_type: Any): 

114 return get_origin(field_type) is UnionType or get_origin(field_type) is Union 

115 

116 

117def _has_a_list_in_union(field_type: Any): 

118 for ft in get_args(field_type): 

119 if get_origin(ft) is list: 

120 return ft 

121 

122 

def _union_collector_info(field, args):
    """
    Pick the representative type of a union and detect by-reference fields.

    Args:
        field (str): Field name; the ``"id"`` field is never a reference.
        args: A union type (e.g. ``Model | Id``); its members are read with
            ``get_args``.

    Returns:
        tuple: ``(field_type, by_reference)``.

        ``by_reference`` is True when the union contains ``Id`` and at least
        one member exposing ``model_fields``, and the field is not ``"id"``.
        In that case ``field_type`` is the first such member.

        Otherwise ``field_type`` is the first member that is not ``NoneType``,
        so ``None | Model`` resolves like ``Model | None``. It falls back to
        the first member if every member is ``NoneType``.
    """
    members = get_args(args)
    model_members = [member for member in members if hasattr(member, "model_fields")]
    by_reference = (Id in members) and (field != "id") and bool(model_members)
    if by_reference:
        return model_members[0], by_reference
    # Skip NoneType so an Optional written with None first still resolves to
    # the meaningful type (and keeps has_model_fields for sub-models).
    non_none = [member for member in members if member is not type(None)]
    field_type = non_none[0] if non_none else members[0]
    return field_type, by_reference

140 

141 

def field_annotation_infos(field, field_info) -> DbField:
    """
    Build a DbField describing the database-related shape of one model field.

    Args:
        field (str): The field's name on the model.
        field_info (FieldInfo): Pydantic field info; its ``annotation`` and
            ``alias`` are read.

    Returns:
        DbField: The field's name, storage alias, resolved type, and flags.
        - The alias is "_id" for the ``id`` field, otherwise the Pydantic
          alias or the field name.
        - The type is the list item type for lists, or the union's
          representative type.
        - ``by_reference`` comes from the union analysis.
        - ``is_list`` is set for list annotations.
        - ``has_model_fields`` is True when the resolved type exposes
          ``model_fields``.

    Description:
        A union containing a ``list[...]`` member (e.g. ``list[Model] | None``)
        is treated as that list. List item types and non-list unions are both
        resolved through the union helper.
    """
    annotation = field_info.annotation
    resolved_type = annotation
    reference = False

    if _is_union(field_type=annotation):
        list_member = _has_a_list_in_union(field_type=annotation)
        if list_member is not None:
            annotation = list_member

    is_list = get_origin(annotation) is list
    if is_list:
        item_type = get_args(annotation)[0]
        if _is_union(item_type):
            resolved_type, reference = _union_collector_info(
                field=field, args=item_type
            )
        else:
            resolved_type = item_type
    elif _is_union(annotation):
        resolved_type, reference = _union_collector_info(
            field=field, args=annotation
        )

    storage_alias = "_id" if field == "id" else (field_info.alias or field)
    return DbField(
        field_name=field,
        field_alias=storage_alias,
        field_type=resolved_type,
        by_reference=reference,
        is_list=is_list,
        has_model_fields=hasattr(resolved_type, "model_fields"),
    )

190 

191 

def _recursice_db_fields_info(db_field_info: DbField, path: list) -> DbField:
    """
    Attach nested DbField descriptors for each field of an embedded model.

    Args:
        db_field_info (DbField): Descriptor to enrich. It is returned
            unchanged when ``has_model_fields`` is False.
        path (list): Aliases leading to ``db_field_info``. Each child gets
            ``path_str`` set to the dotted join of this path plus its own
            alias. The list is restored to its input state on a normal return.

    Returns:
        DbField: ``db_field_info`` itself, with one attribute per sub-field
        (named after the sub-field, not its alias) holding that sub-field's
        recursively resolved DbField.
    """
    if not db_field_info.has_model_fields:
        return db_field_info
    # NOTE(review): there is no visited set, so a model that embeds itself
    # (directly or indirectly) would recurse without bound — confirm whether
    # that can occur.
    for sub_name, sub_info in db_field_info.field_type.model_fields.items():
        child = field_annotation_infos(field=sub_name, field_info=sub_info)
        path.append(child.field_alias)
        child.path_str = ".".join(path)
        child = _recursice_db_fields_info(db_field_info=child, path=path)
        setattr(db_field_info, sub_name, child)
        path.pop()
    return db_field_info

207 

208 

def resolve_class_fields_db_info(cls: BaseModel):
    """
    Attach a resolved DbField tree to ``cls`` for every one of its fields.

    Args:
        cls (BaseModel): The model class to annotate.

    Description:
        For each field, a DbField is built from its annotation. Its
        ``path_str`` is set to the field alias, and nested sub-model fields
        are attached recursively. The result is stored on the class as
        ``<field>__pyodmongo``.
    """
    for name, info in cls.model_fields.items():
        root = field_annotation_infos(field=name, field_info=info)
        alias = root.field_alias
        root.path_str = alias
        resolved = _recursice_db_fields_info(db_field_info=root, path=[alias])
        setattr(cls, name + "__pyodmongo", resolved)

229 

230 

231# TODO finish resolve_single_db_field 

232# def resolve_single_db_field(value: Any) -> DbField: 

233# by_reference = False 

234# return value 

235 

236 

237# def resolve_db_fields(bases: tuple[Any], db_fields: dict): 

238# for base in bases: 

239# if base is object: 

240# continue 

241# base_annotations = base.__dict__.get("__annotations__") 

242# for base_field, value in base_annotations.items(): 

243# if base_field not in db_fields.keys() and not base_field.startswith("_"): 

244# db_fields[base_field] = resolve_single_db_field(value=value) 

245# resolve_db_fields(bases=base.__bases__, db_fields=db_fields) 

246# return db_fields