# Source code for plynx.db.node_collection_manager

"""Node collection manager and utils"""

import logging
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union

from past.builtins import basestring
from pymongo import ReturnDocument

from plynx.constants import Collections, HubSearchParams, NodeRunningStatus, NodeStatus
from plynx.db.node import Node
from plynx.utils.common import ObjectId, parse_search_string, to_object_id
from plynx.utils.db_connector import get_db_connector
from plynx.utils.hub_node_registry import registry

# Fields copied from the referenced sub-node documents onto the embedded
# sub-node dicts when a Run is loaded (used by `get_db_node`).
_PROPERTIES_TO_GET_FROM_SUBS: List[str] = ['node_running_status', 'logs', 'outputs', 'cache_url']
class NodeCollectionManager:
    """NodeCollectionManager contains all the operations to work with Nodes in the database."""

    def __init__(self, collection: str):
        """
        Args:
            collection: Name of the MongoDB collection this manager operates on.
        """
        super().__init__()
        # Collection name; every query in this manager targets it unless
        # an explicit `collection` argument overrides it.
        self.collection: str = collection

    # pylint: disable=too-many-arguments
[docs] def get_db_objects( self, status: Union[List[str], str] = '', node_kinds: Union[None, str, List[str]] = None, search: str = '', per_page: int = 20, offset: int = 0, user_id: Optional[ObjectId] = None, ) -> Optional[List[Dict]]: """Get subset of the Objects. Args: status (str, None): Node Running Status search (str, None): Search pattern per_page (int): Number of Nodes per page offset (int): Offset Return: (list of dict) List of Nodes in dict format """ # pylint: disable=too-many-branches if status and isinstance(status, basestring): status = [status] if node_kinds and isinstance(node_kinds, basestring): node_kinds = [node_kinds] aggregate_list: List[Dict[str, Any]] = [] search_parameters, search_string = parse_search_string(search) # Match and_query: Dict[str, Union[ObjectId, Dict[str, Union[str, List[str], Dict]]]] = {} if node_kinds: and_query['kind'] = {'$in': node_kinds} if status: and_query['node_status'] = {'$in': status} if search_string: and_query['$text'] = {'$search': search_string} if 'original_node_id' in search_parameters: and_query['original_node_id'] = to_object_id(search_parameters['original_node_id']) if HubSearchParams.INPUT_FILE_TYPE in search_parameters: and_query['inputs'] = {"$elemMatch": {"file_type": search_parameters[HubSearchParams.INPUT_FILE_TYPE]}} if HubSearchParams.OUTPUT_FILE_TYPE in search_parameters: and_query['outputs'] = {"$elemMatch": {"file_type": search_parameters[HubSearchParams.OUTPUT_FILE_TYPE]}} if len(and_query) > 0: aggregate_list.append({"$match": and_query}) # Join with users aggregate_list.append({ '$lookup': { 'from': 'users', 'localField': 'author', 'foreignField': '_id', 'as': '_user' } }) # rm password hash aggregate_list.append({ "$project": { "_user.password_hash": 0, } }) # Match username and_query = {} if 'author' in search_parameters: and_query['_user.username'] = search_parameters['author'] if len(and_query): aggregate_list.append({"$match": and_query}) # sort sort_dict = OrderedDict() if 'sort' in 
search_parameters: # TODO more sort options if search_parameters['sort'] == 'starred': sort_dict['starred'] = -1 sort_dict['insertion_date'] = -1 aggregate_list.append({ "$sort": sort_dict } ) aggregate_list.append({ "$addFields": { '_readonly': {'$ne': ["$author", to_object_id(user_id)]}, } }) # counts and pagination aggregate_list.append({ '$facet': { "metadata": [{"$count": "total"}], "list": [{"$skip": int(offset)}, {"$limit": int(per_page)}], } }) # TODO: remove `allowDiskUse=True` and update indexes instead return next(get_db_connector()[self.collection].aggregate(aggregate_list, allowDiskUse=True), None)
[docs] def get_db_objects_by_ids(self, ids: Union[List[ObjectId], List[str]], collection: Optional[str] = None) -> List[dict]: """Find all the Objects with a given IDs. Args: ids (list of ObjectID): Object Ids """ if collection == Collections.HUB_NODE_REGISTRY: # TODO separate types of ids into different functions db_objects = map(lambda node: node.to_dict(), registry.find_nodes(ids)) # type: ignore else: db_objects = get_db_connector()[collection or self.collection].find({ '_id': { '$in': list(ids) } }) return list(db_objects)
    def _update_sub_nodes_fields(
            self,
            sub_nodes_dicts: List[Dict],
            reference_node_id: str,
            target_props: List[str],
            reference_collection: Optional[str] = None
    ):
        """Refresh `target_props` on each sub-node dict from its reference object.

        Each dict in `sub_nodes_dicts` names its reference object in the field
        `reference_node_id` (e.g. '_id', 'original_node_id', or
        'code_function_location'); the referenced objects are looked up in
        `reference_collection` and the listed properties copied over in place.

        Args:
            sub_nodes_dicts: Embedded sub-node dicts to update (mutated in place).
            reference_node_id: Name of the field holding the reference ID.
            target_props: Property names to copy from the referenced objects.
            reference_collection: Collection to resolve references against;
                defaults to this manager's collection.
        """
        reference_collection = reference_collection or self.collection
        # Referenced objects indexed two ways: by '_id' for DB collections,
        # and by 'code_function_location' for the hub registry special case.
        id_to_updated_node_dict = {}
        function_location_to_updated_node_dict = {}
        upd_node_ids = list(map(lambda node_dict: node_dict.get(reference_node_id, "unknown"), sub_nodes_dicts))
        for upd_node_dict in self.get_db_objects_by_ids(upd_node_ids, collection=reference_collection):
            id_to_updated_node_dict[upd_node_dict['_id']] = upd_node_dict
            function_location_to_updated_node_dict[upd_node_dict.get("code_function_location", "unknown")] = upd_node_dict

        # Copy the requested properties from each referenced object; sub-nodes
        # whose reference was not found are left untouched.
        for sub_node_dict in sub_nodes_dicts:
            if sub_node_dict.get(reference_node_id, "unknown") not in id_to_updated_node_dict:
                continue
            for prop in target_props:
                sub_node_dict[prop] = id_to_updated_node_dict[sub_node_dict.get(reference_node_id, "unknown")][prop]

        if reference_collection == Collections.HUB_NODE_REGISTRY:
            # special case: we need to compare not target_props, but rather assign it
            # Here the reference is a function location, and instead of copying
            # the property we mark the sub-node DEPRECATED when its code hash
            # no longer matches the registered operation's hash.
            assert len(target_props) == 1, "Only node_status can be assigned"
            assert target_props[0] == 'node_status', "Only node_status can be assigned"
            for sub_node_dict in sub_nodes_dicts:
                if sub_node_dict.get(reference_node_id, "unknown") is None:
                    continue
                if sub_node_dict.get(reference_node_id, "unknown") not in function_location_to_updated_node_dict:
                    logging.warning(f"`{sub_node_dict.get(reference_node_id, 'unknown')}` is not found in the list of operation locations")
                    continue
                if sub_node_dict['code_hash'] != function_location_to_updated_node_dict[sub_node_dict.get(reference_node_id, "unknown")]["code_hash"]:
                    sub_node_dict['node_status'] = NodeStatus.DEPRECATED
[docs] def get_db_node(self, node_id: ObjectId, user_id: Optional[ObjectId] = None) -> Optional[Dict]: """Get dict representation of a Node. Args: node_id (ObjectId, str): Object ID user_id (str, ObjectId, None): User ID Return: (dict) dict representation of the Object """ res = self.get_db_object(node_id, user_id) if not res: return None sub_nodes_dicts: Optional[List[Dict]] = None for parameter in res['parameters']: if parameter['name'] == '_nodes': sub_nodes_dicts = parameter['value']['value'] break if sub_nodes_dicts: # TODO join collections using database capabilities if self.collection == Collections.RUNS: self._update_sub_nodes_fields(sub_nodes_dicts, '_id', _PROPERTIES_TO_GET_FROM_SUBS) self._update_sub_nodes_fields(sub_nodes_dicts, 'original_node_id', ['node_status'], reference_collection=Collections.TEMPLATES) self._update_sub_nodes_fields(sub_nodes_dicts, 'code_function_location', ['node_status'], reference_collection=Collections.HUB_NODE_REGISTRY) return res
[docs] def get_db_object(self, object_id: ObjectId, user_id: Optional[ObjectId] = None) -> Optional[Dict]: """Get dict representation of an Object. Args: object_id (ObjectId): Object ID user_id (ObjectId, None): User ID Return: (dict) dict representation of the Object """ res = get_db_connector()[self.collection].find_one({'_id': to_object_id(object_id)}) if not res: return None res['_readonly'] = (user_id != to_object_id(res['author'])) return res
@staticmethod
[docs] def _transplant_node(node: Node, dest_node: Node) -> Node: if dest_node._id == node.original_node_id: return node dest_node.apply_properties(node) dest_node.original_node_id = dest_node._id dest_node.parent_node_id = dest_node.successor_node_id = None dest_node._id = node._id return dest_node
    def upgrade_sub_nodes(self, main_node: Node) -> int:
        """Upgrade deprecated Nodes.

        The function does not change the original graph in the database.

        Args:
            main_node: Graph node whose `_nodes` parameter holds the sub-nodes
                to upgrade (mutated in place with the upgraded list).

        Return:
            (int): Number of upgraded Nodes
        """
        assert self.collection == Collections.TEMPLATES
        sub_nodes = main_node.get_sub_nodes()
        node_ids = [node.original_node_id for node in sub_nodes if node.original_node_id]
        db_nodes = self.get_db_objects_by_ids(node_ids)
        new_node_db_mapping = {}

        # ------------------------
        # Update nodes from the DB
        # ------------------------
        for db_node in db_nodes:
            original_node_id = to_object_id(db_node['_id'])
            new_db_node = db_node
            if original_node_id not in new_node_db_mapping:
                # Walk the successor chain until a READY node (or the end of the chain).
                while new_db_node['node_status'] != NodeStatus.READY and 'successor_node_id' in new_db_node and new_db_node['successor_node_id']:
                    tmp_node = self.get_db_node(new_db_node['successor_node_id'])
                    if tmp_node:
                        new_db_node = tmp_node
                    else:
                        break
                new_node_db_mapping[original_node_id] = new_db_node

        # Transplant each sub-node onto the newest version of its template;
        # sub-nodes without a resolvable mapping are kept unchanged.
        new_nodes = [
            NodeCollectionManager._transplant_node(
                node,
                Node.from_dict(new_node_db_mapping[to_object_id(node.original_node_id)])
            ) if to_object_id(node.original_node_id) in new_node_db_mapping else node
            for node in sub_nodes
        ]

        # -------------------------
        # Update nodes from the hub
        # -------------------------
        node_locations = list(set(map(lambda node: node.code_function_location, new_nodes)))
        node_locations_non_empty: List[str] = [
            node_location for node_location in node_locations if node_location is not None
        ]
        hub_nodes_mapping = {
            node.code_function_location: node for node in registry.find_nodes(node_locations_non_empty)
        }

        # Transplant again for hub-registered operations whose code changed
        # (detected by a code_hash mismatch).
        new_nodes = [
            NodeCollectionManager._transplant_node(
                node,
                hub_nodes_mapping[node.code_function_location]
            ) if node.code_function_location in hub_nodes_mapping and node.code_hash != hub_nodes_mapping[node.code_function_location].code_hash else node
            for node in new_nodes
        ]

        # A sub-node counts as upgraded when its original_node_id changed.
        upgraded_nodes_count = sum(
            1 for node, new_node in zip(sub_nodes, new_nodes) if node.original_node_id != new_node.original_node_id
        )

        main_node.get_parameter_by_name('_nodes').value.value = new_nodes
        return upgraded_nodes_count
[docs] def pick_node(self, kinds: List[str]) -> Dict: """Get node and set status to RUNNING in atomic way""" node = get_db_connector()[self.collection].find_one_and_update( { '$and': [ { 'kind': { '$in': kinds, } }, { 'node_running_status': { '$in': [ NodeRunningStatus.READY, NodeRunningStatus.IN_QUEUE, ] } }, ], }, { '$set': { 'node_running_status': NodeRunningStatus.RUNNING } }, return_document=ReturnDocument.AFTER ) return node