Coverage for pyodmongo/services/reference_pipeline.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 15:08 +0000

1from pydantic import BaseModel 

2from ..models.db_field_info import DbField 

3from .verify_subclasses import is_subclass 

4from .aggregate_stages import ( 

5 unwind, 

6 group_set_replace_root, 

7 unset, 

8 lookup, 

9 set_, 

10) 

11import copy 

12 

13 

14def _paths_to_ref_ids( 

15 cls: BaseModel, 

16 paths: list, 

17 db_field_path: list, 

18 populate_db_fields: list[DbField] | None, 

19): 

20 for field in cls.model_fields: 

21 try: 

22 db_field: DbField = getattr(cls, field) 

23 except AttributeError: 

24 if is_subclass(class_to_verify=cls, subclass=BaseModel): 

25 raise TypeError( 

26 f"The {cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class" 

27 ) 

28 db_field_path.append(db_field) 

29 if db_field.by_reference: 

30 try: 

31 if db_field in populate_db_fields: 

32 paths.append(copy.deepcopy(db_field_path)) 

33 except TypeError: 

34 paths.append(copy.deepcopy(db_field_path)) 

35 elif db_field.has_model_fields: 

36 _paths_to_ref_ids( 

37 cls=db_field.field_type, 

38 paths=paths, 

39 db_field_path=db_field_path, 

40 populate_db_fields=populate_db_fields, 

41 ) 

42 db_field_path.pop(-1) 

43 return paths 

44 

45 

def resolve_reference_pipeline(
    cls: BaseModel, pipeline: list, populate_db_fields: list[DbField] | None
):
    """
    Extend an aggregation pipeline with the stages that populate a model's references.

    Args:
        cls (BaseModel): Model class whose by-reference fields are resolved.
        pipeline (list): Existing stages; the new stages are appended to this same list.
        populate_db_fields (list[DbField] | None): Forwarded to the reference-path
            search and to the nested pipelines built for referenced models.

    Returns:
        list: ``pipeline`` with the reference resolution stages added.

    Description:
        For each path leading to a referenced field, every list-valued
        intermediate field gets an ``unwind`` stage, the reference itself gets a
        ``lookup`` stage against the ``_collection`` of the referenced model
        (with a nested pipeline built recursively for that model), and the
        unwound paths are then regrouped from the innermost one outwards. The
        unwind index fields used along the way are removed at the end.
    """
    reference_paths = _paths_to_ref_ids(
        cls=cls, paths=[], db_field_path=[], populate_db_fields=populate_db_fields
    )
    for reference_path in reference_paths:
        dotted_path = ""
        index_names = ["_id"]  # "_id" is always the outermost grouping key
        unwound_paths = []
        current: DbField = None
        for position, current in enumerate(reference_path):
            alias = current.field_alias
            dotted_path += alias if dotted_path == "" else "." + alias

            # Only embedded lists are unwound; a list of references is left
            # for the lookup stage below.
            if not current.is_list or current.by_reference:
                continue
            index_name = f"__unwind_{position}"
            index_names.append(index_name)
            unwound_paths.append(dotted_path)
            pipeline += unwind(
                path=dotted_path,
                array_index=index_name,
                preserve_empty=True,
            )

        # ``current`` is now the last element of the path.
        target_model = current.field_type
        pipeline += lookup(
            from_=target_model._collection,
            local_field=dotted_path,
            foreign_field="_id",
            as_=dotted_path,
            pipeline=resolve_reference_pipeline(
                cls=target_model,
                pipeline=[],
                populate_db_fields=populate_db_fields,
            ),
        )
        if not current.is_list:
            pipeline += set_(local_field=dotted_path, as_=dotted_path)

        used_indexes = []
        # Regroup from the deepest unwound path back up to the root.
        for depth, grouped_path in enumerate(reversed(unwound_paths), start=1):
            array_index = index_names[-depth]
            used_indexes.append(array_index)
            pipeline += group_set_replace_root(
                to_sort=dict.fromkeys(index_names[1:], 1),
                id_=[f"${name}" for name in index_names[:-depth]],
                array_index=array_index,
                field=grouped_path.rsplit(".", 1)[-1],
                path_str=grouped_path,
            )
        if used_indexes:
            pipeline += unset(fields=used_indexes)

    return pipeline