Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 197 additions & 55 deletions nodescraper/interfaces/dataplugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from nodescraper.interfaces.plugin import PluginInterface
from nodescraper.models import (
AnalyzerArgs,
CollectorArgs,
DataModel,
DataPluginResult,
PluginResult,
Expand All @@ -51,6 +52,17 @@
from .task import SystemCompatibilityError
from .taskresulthook import TaskResultHook

CollectorClasses = Union[
Type[DataCollector],
tuple[Type[DataCollector], ...],
list[Type[DataCollector]],
]

CollectorArgsClasses = Union[
Type[CollectorArgs],
dict[str, Type[CollectorArgs]],
]


class DataPlugin(
PluginInterface, Generic[TConnectionManager, TConnectArg, TDataModel, TCollectArg, TAnalyzeArg]
Expand All @@ -61,7 +73,9 @@ class DataPlugin(

CONNECTION_TYPE: Optional[Type[TConnectionManager]]

COLLECTOR: Optional[Type[DataCollector]] = None
COLLECTOR: Optional[CollectorClasses] = None

COLLECTOR_ARGS: Optional[CollectorArgsClasses] = None

ANALYZER: Optional[Type[DataAnalyzer]] = None

Expand Down Expand Up @@ -101,6 +115,43 @@ def __init__(
)
self._data: Optional[TDataModel] = None

@classmethod
def get_collector_classes(cls) -> tuple[Type[DataCollector], ...]:
"""Return all collector classes configured on this plugin."""
collector = cls.COLLECTOR
if collector is None:
return ()
if isinstance(collector, (tuple, list)):
return tuple(collector)
return (collector,)

@classmethod
def _collector_args_class(
cls, collector_cls: Type[DataCollector]
) -> Optional[Type[CollectorArgs]]:
collector_args = cls.COLLECTOR_ARGS
if isinstance(collector_args, dict):
return collector_args.get(collector_cls.__name__)
return collector_args

@classmethod
def _validate_collector_args(cls) -> None:
collector_args = cls.COLLECTOR_ARGS
if collector_args is None:
return
if isinstance(collector_args, dict):
for collector_name, args_cls in collector_args.items():
if not isinstance(args_cls, type) or not issubclass(args_cls, CollectorArgs):
raise TypeError(
f"COLLECTOR_ARGS[{collector_name!r}] must be a CollectorArgs subclass, "
f"got {args_cls!r}"
)
return
if not isinstance(collector_args, type) or not issubclass(collector_args, CollectorArgs):
raise TypeError(
f"COLLECTOR_ARGS must be a CollectorArgs subclass or dict, got {collector_args!r}"
)

@classmethod
def _validate_class_var(cls):
if not hasattr(cls, "DATA_MODEL"):
Expand All @@ -109,12 +160,96 @@ def _validate_class_var(cls):
if cls.DATA_MODEL is None:
raise TypeError("DATA_MODEL class variable not defined")

if not cls.COLLECTOR and not cls.ANALYZER:
if not cls.get_collector_classes() and not cls.ANALYZER:
raise TypeError("No collector or analyzer task defined")

if cls.COLLECTOR and not cls.CONNECTION_TYPE:
if cls.get_collector_classes() and not cls.CONNECTION_TYPE:
raise TypeError("CONNECTION_TYPE must be defined for collector")

for collector_cls in cls.get_collector_classes():
if not isinstance(collector_cls, type) or not issubclass(collector_cls, DataCollector):
raise TypeError(
f"COLLECTOR entries must be DataCollector subclasses, got {collector_cls!r}"
)

cls._validate_collector_args()

@classmethod
def _merge_collected_data(
cls,
existing: Optional[TDataModel],
new_data: Optional[TDataModel],
) -> Optional[TDataModel]:
if new_data is None:
return existing
if existing is None:
return new_data
if not isinstance(new_data, existing.__class__):
raise TypeError(
f"Collector returned {new_data.__class__.__name__}, "
f"expected {existing.__class__.__name__}"
)
merged = {
**existing.model_dump(exclude_unset=True),
**new_data.model_dump(exclude_unset=True),
}
return existing.__class__.model_validate(merged)

@classmethod
def _aggregate_collection_results(
cls,
plugin_name: str,
results: list[TaskResult],
) -> TaskResult:
if not results:
return TaskResult(
parent=plugin_name,
status=ExecutionStatus.NOT_RAN,
message=f"Data collection not ran for {plugin_name}",
)
if len(results) == 1:
return results[0]

aggregated = TaskResult(
parent=plugin_name,
status=max(result.status for result in results),
task=",".join(result.task for result in results if result.task),
)
messages = [result.message for result in results if result.message]
if messages:
aggregated.message = "; ".join(messages)
for result in results:
aggregated.artifacts.extend(result.artifacts)
aggregated.events.extend(result.events)
aggregated.details["collector_results"] = [
result.model_dump(exclude={"artifacts", "events"}) for result in results
]
return aggregated

def _resolve_collector_args(
self,
collector_cls: Type[DataCollector],
collection_args: Optional[Union[TCollectArg, dict]],
) -> Optional[Union[TCollectArg, dict]]:
if collection_args is None:
return None

collector_name = collector_cls.__name__
collector_names = {cls.__name__ for cls in self.get_collector_classes()}
raw_args: Optional[Union[TCollectArg, dict]] = collection_args

if isinstance(collection_args, dict) and collector_names.intersection(
collection_args.keys()
):
raw_args = collection_args.get(collector_name)
if raw_args is None:
return None

args_cls = self._collector_args_class(collector_cls)
if args_cls is not None and isinstance(raw_args, dict):
return args_cls.model_validate(raw_args)
return raw_args

@classmethod
def is_valid(cls) -> bool:
"""Check that all required class variables are set
Expand Down Expand Up @@ -167,19 +302,22 @@ def collect(
Returns:
TaskResult: task result for data collection
"""
if not self.COLLECTOR:
collector_classes = self.get_collector_classes()
if not collector_classes:
self.collection_result = TaskResult(
parent=self.__class__.__name__,
status=ExecutionStatus.NOT_RAN,
message=f"Data collection not supported for {self.__class__.__name__}",
)
return self.collection_result

primary_collector = collector_classes[0]

try:
if not self.connection_manager:
if not self.CONNECTION_TYPE:
self.collection_result = TaskResult(
task=self.COLLECTOR.__name__,
task=primary_collector.__name__,
parent=self.__class__.__name__,
status=ExecutionStatus.NOT_RAN,
message=f"No connection manager type provided for {self.__class__.__name__}",
Expand All @@ -203,49 +341,53 @@ def collect(

if self.connection_manager.result.status != ExecutionStatus.OK:
self.collection_result = TaskResult(
task=self.COLLECTOR.__name__,
task=primary_collector.__name__,
parent=self.__class__.__name__,
status=ExecutionStatus.NOT_RAN,
message="Connection not available, data collection skipped",
)
else:
if (
collection_args is not None
and isinstance(collection_args, dict)
and hasattr(self, "COLLECTOR_ARGS")
and self.COLLECTOR_ARGS is not None
):
collection_args = self.COLLECTOR_ARGS.model_validate(collection_args)

collection_task = self.COLLECTOR(
system_info=self.system_info,
logger=self.logger,
system_interaction_level=system_interaction_level,
connection=self.connection_manager.connection,
max_event_priority_level=max_event_priority_level,
parent=self.__class__.__name__,
task_result_hooks=self.task_result_hooks,
log_path=self.log_path,
event_reporter=self.event_reporter,
session_id=self.session_id,
collector_results: list[TaskResult] = []
merged_data: Optional[TDataModel] = None

for collector_cls in collector_classes:
collector_args = self._resolve_collector_args(collector_cls, collection_args)
collection_task = collector_cls(
system_info=self.system_info,
logger=self.logger,
system_interaction_level=system_interaction_level,
connection=self.connection_manager.connection,
max_event_priority_level=max_event_priority_level,
parent=self.__class__.__name__,
task_result_hooks=self.task_result_hooks,
log_path=self.log_path,
event_reporter=self.event_reporter,
session_id=self.session_id,
)
result, data = collection_task.collect_data(collector_args)
collector_results.append(result)
merged_data = self._merge_collected_data(merged_data, data)

self.collection_result = self._aggregate_collection_results(
self.__class__.__name__,
collector_results,
)
self.collection_result, self._data = collection_task.collect_data(collection_args)
self._data = merged_data

except SystemCompatibilityError as e:
self.collection_result = TaskResult(
task=self.COLLECTOR.__name__,
task=primary_collector.__name__,
parent=self.__class__.__name__,
status=ExecutionStatus.NOT_RAN,
message=str(e),
)
except Exception as e:
self.logger.exception(
"Unhandled exception running collector %s for plugin %s",
self.COLLECTOR.__name__,
"Unhandled exception running collectors for plugin %s",
self.__class__.__name__,
)
self.collection_result = TaskResult(
task=self.COLLECTOR.__name__,
task=primary_collector.__name__,
parent=self.__class__.__name__,
status=ExecutionStatus.EXECUTION_FAILURE,
message=f"Unhandled exception running data collector: {str(e)}",
Expand Down Expand Up @@ -422,33 +564,33 @@ def find_datamodel_path_in_run(cls, run_path: str) -> Optional[str]:
run_path = os.path.abspath(run_path)
if not os.path.isdir(run_path):
return None
collector_cls = getattr(cls, "COLLECTOR", None)
data_model_cls = getattr(cls, "DATA_MODEL", None)
if not collector_cls or not data_model_cls:
return None
collector_dir = os.path.join(
run_path,
pascal_to_snake(cls.__name__),
pascal_to_snake(collector_cls.__name__),
)
if not os.path.isdir(collector_dir):
return None
result_path = os.path.join(collector_dir, "result.json")
if not os.path.isfile(result_path):
return None
try:
res_payload = json.loads(Path(result_path).read_text(encoding="utf-8"))
if res_payload.get("parent") != cls.__name__:
return None
except (json.JSONDecodeError, OSError):
if not data_model_cls:
return None
want_json = data_model_cls.__name__.lower() + ".json"
for fname in os.listdir(collector_dir):
low = fname.lower()
if low.endswith("datamodel.json") or low == want_json:
return os.path.join(collector_dir, fname)
if low.endswith(".log"):
return os.path.join(collector_dir, fname)
for collector_cls in cls.get_collector_classes():
collector_dir = os.path.join(
run_path,
pascal_to_snake(cls.__name__),
pascal_to_snake(collector_cls.__name__),
)
if not os.path.isdir(collector_dir):
continue
result_path = os.path.join(collector_dir, "result.json")
if not os.path.isfile(result_path):
continue
try:
res_payload = json.loads(Path(result_path).read_text(encoding="utf-8"))
if res_payload.get("parent") != cls.__name__:
continue
except (json.JSONDecodeError, OSError):
continue
want_json = data_model_cls.__name__.lower() + ".json"
for fname in os.listdir(collector_dir):
low = fname.lower()
if low.endswith("datamodel.json") or low == want_json:
return os.path.join(collector_dir, fname)
if low.endswith(".log"):
return os.path.join(collector_dir, fname)
return None

@classmethod
Expand Down
Loading
Loading