diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f18c3128..f222c17a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,6 +29,11 @@ repos: hooks: - id: isort name: isort (python) + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.18.2 + hooks: + - id: mypy + additional_dependencies: [attrs] - repo: https://github.com/google/yamlfmt rev: v0.15.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d151f66..b020544e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,11 @@ ## Upcoming Changes ### Breaking +* pre detector events now also include host.name if the field value is None ### Features * add support for python 3.14 +* allow pre-detector to copy a configurable list of fields from log to detection event ### Improvements * add workflow to partially run & check the compose example @@ -15,6 +17,7 @@ * fix docker-compose and k8s example setups * fix handling of non-string values (e.g. int) as replacement argument for `generic_resolver` * fix documentation for `generic_resolver` rule `append_to_list -> merge_with_target` option +* fix grokker using a fixed directory for downloaded patterns, potentially leading to conflicts between processes ## 17.0.3 ### Breaking diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 5878b172b..a33f60270 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -40,6 +40,7 @@ contribute to them. git clone https://github.com/fkie-cad/Logprep.git cd Logprep pip install . + pip install ".[dev]" # if you intend to contribute To see if the installation was successful run :code:`logprep --version`. diff --git a/logprep/abc/processor.py b/logprep/abc/processor.py index f89f5abc0..048cf6327 100644 --- a/logprep/abc/processor.py +++ b/logprep/abc/processor.py @@ -3,7 +3,7 @@ import logging import os from abc import abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Type +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Type, cast from attrs import define, field, validators @@ -105,27 +105,43 @@ class Config(Component.Config): __slots__ = [ "_event", "_rule_tree", - "result", + "_result", "_bypass_rule_tree", ] - rule_class: ClassVar["Type[Rule] | None"] = None + rule_class: ClassVar[Type["Rule"] | None] = None _event: dict _rule_tree: RuleTree _strategy = None _bypass_rule_tree: bool - result: ProcessorResult | None + _result: ProcessorResult | None def __init__(self, name: str, configuration: "Processor.Config"): super().__init__(name, configuration) self._rule_tree = RuleTree(config=self._config.tree_config) self.load_rules(rules_targets=self._config.rules) - self.result = None + self._result = None self._bypass_rule_tree = False if os.environ.get("LOGPREP_BYPASS_RULE_TREE"): self._bypass_rule_tree = True logger.debug("Bypassing rule tree for processor %s", self.name) + @property + def result(self) -> ProcessorResult: + """Returns the current result object which is guaranteed to be non-None + during processing of an event. + + Returns + ------- + ProcessorResult + The current result to be modified in-place + """ + return cast(ProcessorResult, self._result) + + @result.setter + def result(self, value: ProcessorResult): + self._result = value + @property def rules(self): """Returns all rules @@ -161,7 +177,7 @@ def process(self, event: dict) -> ProcessorResult: extra data and a list of target outputs. 
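+
+        Examples
+        --------
+        A minimal sketch, assuming ``processor`` is a configured instance::
+
+            result = processor.process({"message": "test"})
+            assert result.processor_name == processor.name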
""" - self.result = ProcessorResult(processor_name=self.name, event=event) # type: ignore + self._result = ProcessorResult(processor_name=self.name, event=event) # type: ignore logger.debug("%s processing event %s", self.describe(), event) if self._bypass_rule_tree: self._process_all_rules(event) @@ -225,8 +241,8 @@ def _apply_rules_wrapper(self, event: dict, rule: "Rule"): event.clear() if not hasattr(rule, "delete_source_fields"): return - if rule.delete_source_fields: - for dotted_field in rule.source_fields: + if getattr(rule, "delete_source_fields", False): + for dotted_field in getattr(rule, "source_fields", []): pop_dotted_field_value(event, dotted_field) @abstractmethod @@ -298,12 +314,13 @@ def _has_missing_values(self, event, rule, source_field_dict): return False def _write_target_field(self, event: dict, rule: "Rule", result: Any) -> None: - add_fields_to( - event, - fields={rule.target_field: result}, - merge_with_target=rule.merge_with_target, - overwrite_target=rule.overwrite_target, - ) + if hasattr(rule, "target_field"): + add_fields_to( + event, + fields={getattr(rule, "target_field"): result}, + merge_with_target=getattr(rule, "merge_with_target", False), + overwrite_target=getattr(rule, "overwrite_target", False), + ) def setup(self): super().setup() diff --git a/logprep/ng/processor/grokker/processor.py b/logprep/ng/processor/grokker/processor.py index 9db91188d..8fbece481 100644 --- a/logprep/ng/processor/grokker/processor.py +++ b/logprep/ng/processor/grokker/processor.py @@ -31,6 +31,7 @@ import logging import re +import tempfile from pathlib import Path from zipfile import ZipFile @@ -99,11 +100,13 @@ def setup(self) -> None: super().setup() custom_patterns_dir = self._config.custom_patterns_dir if re.search(r"http(s)?:\/\/.*?\.zip", custom_patterns_dir): - patterns_tmp_path = Path("/tmp/grok_patterns") - self._download_zip_file(source_file=custom_patterns_dir, target_dir=patterns_tmp_path) - for rule in self.rules: - rule.set_mapping_actions(patterns_tmp_path) - return + with tempfile.TemporaryDirectory("grok") as patterns_tmp_path: + self._download_zip_file( + source_file=custom_patterns_dir, target_dir=Path(patterns_tmp_path) + ) + for rule in self.rules: + rule.set_mapping_actions(patterns_tmp_path) + return if custom_patterns_dir: for rule in self.rules: rule.set_mapping_actions(custom_patterns_dir) @@ -111,12 +114,10 @@ def setup(self) -> None: for rule in self.rules: rule.set_mapping_actions() - def _download_zip_file(self, source_file: str, target_dir: Path) -> None: - if not target_dir.exists(): - logger.debug("start grok pattern download...") - archive = Path(f"{target_dir}.zip") - archive.touch() - archive.write_bytes(GetterFactory.from_string(source_file).get_raw()) + def _download_zip_file(self, source_file: str, target_dir: Path): + logger.debug("start grok pattern download...") + with tempfile.TemporaryFile("wb+") as archive: + archive.write(GetterFactory.from_string(source_file).get_raw()) logger.debug("finished grok pattern download.") - with ZipFile(str(archive), mode="r") as zip_file: + with ZipFile(archive, mode="r") as zip_file: zip_file.extractall(target_dir) diff --git a/logprep/ng/processor/pre_detector/processor.py b/logprep/ng/processor/pre_detector/processor.py index 64843ca88..3489d5e11 100644 --- a/logprep/ng/processor/pre_detector/processor.py +++ b/logprep/ng/processor/pre_detector/processor.py @@ -29,6 +29,7 @@ """ from functools import cached_property +from typing import cast from uuid import uuid4 from attr import define, field, 
validators @@ -38,7 +39,12 @@ from logprep.processor.base.exceptions import ProcessingWarning from logprep.processor.pre_detector.ip_alerter import IPAlerter from logprep.processor.pre_detector.rule import PreDetectorRule -from logprep.util.helper import add_fields_to, get_dotted_field_value +from logprep.util.helper import ( + FieldValue, + add_fields_to, + copy_fields_to_event, + get_dotted_field_value, +) from logprep.util.time import TimeParser, TimeParserException @@ -92,16 +98,16 @@ class Config(Processor.Config): def _ip_alerter(self) -> IPAlerter: return IPAlerter(self._config.alert_ip_list_path) - def normalize_timestamp(self, rule: PreDetectorRule, timestamp: str) -> str: + def normalize_timestamp(self, rule: PreDetectorRule, timestamp: FieldValue) -> str: """method for normalizing the timestamp""" try: parsed_datetime = TimeParser.parse_datetime( - timestamp, rule.source_format, rule.source_timezone + cast(str, timestamp), rule.source_format, rule.source_timezone ) return ( parsed_datetime.astimezone(rule.target_timezone).isoformat().replace("+00:00", "Z") ) - except TimeParserException as error: + except (TimeParserException, TypeError) as error: raise ProcessingWarning( "Could not parse timestamp", rule, @@ -132,7 +138,7 @@ def _get_detection_result(self, event: dict, rule: PreDetectorRule) -> None: @staticmethod def _generate_detection_result( - pre_detection_id: str, event: dict, rule: PreDetectorRule + pre_detection_id: FieldValue, event: dict, rule: PreDetectorRule ) -> dict: detection_result = { **rule.detection_data, @@ -140,7 +146,11 @@ def _generate_detection_result( "description": rule.description, "pre_detection_id": pre_detection_id, } - - if host_name := get_dotted_field_value(event, "host.name"): - detection_result.update({"host": {"name": host_name}}) + copy_fields_to_event( + target_event=detection_result, + source_event=event, + dotted_field_names=rule.copy_fields_to_detection_event, + rule=rule, + skip_missing=True, + ) return detection_result diff --git a/logprep/processor/base/rule.py b/logprep/processor/base/rule.py index 35ea56c2c..d085704e4 100644 --- a/logprep/processor/base/rule.py +++ b/logprep/processor/base/rule.py @@ -230,13 +230,15 @@ class Metrics(Component.Metrics): ) """Number of errors that occurred while processing events""" - special_field_types = [ - "regex_fields", - "sigma_fields", - "ip_fields", - "tests", - "tag_on_failure", - ] + special_field_types = frozenset( + ( + "regex_fields", + "sigma_fields", + "ip_fields", + "tests", + "tag_on_failure", + ) + ) rule_type: str = "" @@ -272,7 +274,7 @@ def __init__(self, filter_rule: FilterExpression, config: Config, processor_name raise InvalidRuleDefinitionError("config is not a Config class") if not config.tag_on_failure: config.tag_on_failure = [f"_{self.rule_type}_failure"] - self.__class__.__hash__ = Rule.__hash__ + self.__class__.__hash__ = Rule.__hash__ # type: ignore self._processor_name = processor_name self.filter_str = str(filter_rule) self._filter = filter_rule @@ -287,10 +289,12 @@ def metrics(self): """create and return metrics object""" return self.Metrics(labels=self.metric_labels) - def __eq__(self, other: "Rule") -> bool: - return all([other.filter == self._filter, other._config == self._config]) + def __eq__(self, other: object) -> bool: + if not isinstance(other, Rule): + return NotImplemented + return other.filter == self._filter and other._config == self._config - def __hash__(self) -> int: # pylint: disable=function-redefined + def __hash__(self) -> int: return id(self) 
def __repr__(self) -> str: @@ -366,7 +370,7 @@ def _check_rule_validity( rule: dict, *extra_keys: str, optional_keys: Optional[Set[str]] = None ): optional_keys = optional_keys if optional_keys else set() - keys = [i for i in rule if i not in ["description"] + Rule.special_field_types] + keys = [i for i in rule if i not in {"description", *Rule.special_field_types}] required_keys = ["filter"] + list(extra_keys) if not keys or set(keys) != set(required_keys): diff --git a/logprep/processor/grokker/processor.py b/logprep/processor/grokker/processor.py index 24a1ac17a..34c453396 100644 --- a/logprep/processor/grokker/processor.py +++ b/logprep/processor/grokker/processor.py @@ -31,6 +31,7 @@ import logging import re +import tempfile from pathlib import Path from zipfile import ZipFile @@ -106,16 +107,18 @@ def _apply_rules(self, event: dict, rule: GrokkerRule): if not matches: raise ProcessingWarning("no grok pattern matched", rule, event) - def setup(self): + def setup(self) -> None: """Loads the action mapping. Has to be called before processing""" super().setup() custom_patterns_dir = self._config.custom_patterns_dir if re.search(r"http(s)?:\/\/.*?\.zip", custom_patterns_dir): - patterns_tmp_path = Path("/tmp/grok_patterns") - self._download_zip_file(source_file=custom_patterns_dir, target_dir=patterns_tmp_path) - for rule in self.rules: - rule.set_mapping_actions(patterns_tmp_path) - return + with tempfile.TemporaryDirectory("grok") as patterns_tmp_path: + self._download_zip_file( + source_file=custom_patterns_dir, target_dir=Path(patterns_tmp_path) + ) + for rule in self.rules: + rule.set_mapping_actions(patterns_tmp_path) + return if custom_patterns_dir: for rule in self.rules: rule.set_mapping_actions(custom_patterns_dir) @@ -124,11 +127,9 @@ def setup(self): rule.set_mapping_actions() def _download_zip_file(self, source_file: str, target_dir: Path): - if not target_dir.exists(): - logger.debug("start grok pattern download...") - archive = Path(f"{target_dir}.zip") - archive.touch() - archive.write_bytes(GetterFactory.from_string(source_file).get_raw()) + logger.debug("start grok pattern download...") + with tempfile.TemporaryFile("wb+") as archive: + archive.write(GetterFactory.from_string(source_file).get_raw()) logger.debug("finished grok pattern download.") - with ZipFile(str(archive), mode="r") as zip_file: + with ZipFile(archive, mode="r") as zip_file: zip_file.extractall(target_dir) diff --git a/logprep/processor/pre_detector/processor.py b/logprep/processor/pre_detector/processor.py index 15f93c066..e6998dcba 100644 --- a/logprep/processor/pre_detector/processor.py +++ b/logprep/processor/pre_detector/processor.py @@ -29,6 +29,7 @@ """ from functools import cached_property +from typing import cast from uuid import uuid4 from attr import define, field, validators @@ -37,7 +38,12 @@ from logprep.processor.base.exceptions import ProcessingWarning from logprep.processor.pre_detector.ip_alerter import IPAlerter from logprep.processor.pre_detector.rule import PreDetectorRule -from logprep.util.helper import add_fields_to, get_dotted_field_value +from logprep.util.helper import ( + FieldValue, + add_fields_to, + copy_fields_to_event, + get_dotted_field_value, +) from logprep.util.time import TimeParser, TimeParserException @@ -101,19 +107,19 @@ class Config(Processor.Config): rule_class = PreDetectorRule @cached_property - def _ip_alerter(self): + def _ip_alerter(self) -> IPAlerter: return IPAlerter(self._config.alert_ip_list_path) - def normalize_timestamp(self, rule: 
PreDetectorRule, timestamp: str) -> str: + def normalize_timestamp(self, rule: PreDetectorRule, timestamp: FieldValue) -> str: """method for normalizing the timestamp""" try: parsed_datetime = TimeParser.parse_datetime( - timestamp, rule.source_format, rule.source_timezone + cast(str, timestamp), rule.source_format, rule.source_timezone ) return ( parsed_datetime.astimezone(rule.target_timezone).isoformat().replace("+00:00", "Z") ) - except TimeParserException as error: + except (TimeParserException, TypeError) as error: raise ProcessingWarning( "Could not parse timestamp", rule, @@ -143,16 +149,19 @@ def _get_detection_result(self, event: dict, rule: PreDetectorRule): @staticmethod def _generate_detection_result( - pre_detection_id: str, event: dict, rule: PreDetectorRule + pre_detection_id: FieldValue, event: dict, rule: PreDetectorRule ) -> dict: - detection_result = rule.detection_data - detection_result.update( - { - "rule_filter": rule.filter_str, - "description": rule.description, - "pre_detection_id": pre_detection_id, - } + detection_result = { + **rule.detection_data, + "rule_filter": rule.filter_str, + "description": rule.description, + "pre_detection_id": pre_detection_id, + } + copy_fields_to_event( + target_event=detection_result, + source_event=event, + dotted_field_names=rule.copy_fields_to_detection_event, + rule=rule, + skip_missing=True, ) - if host_name := get_dotted_field_value(event, "host.name"): - detection_result.update({"host": {"name": host_name}}) return detection_result diff --git a/logprep/processor/pre_detector/rule.py b/logprep/processor/pre_detector/rule.py index e8a80f90f..082883f11 100644 --- a/logprep/processor/pre_detector/rule.py +++ b/logprep/processor/pre_detector/rule.py @@ -121,25 +121,46 @@ """ from functools import cached_property -from typing import Optional, Union +from types import MappingProxyType +from typing import cast from zoneinfo import ZoneInfo -from attrs import asdict, define, field, validators +from attrs import asdict, define, field, fields, validators from logprep.processor.base.rule import Rule - -class PreDetectorRule(Rule): - """Check if documents match a filter.""" - - special_field_types = { +SPECIAL_FIELD_TYPES = frozenset( + ( *Rule.special_field_types, "source_format", "source_timezone", "target_timezone", "timestamp_field", "failure_tags", - } + "copy_fields_to_detection_event", + ) +) + + +def _validate_copy_fields_to_detection_event(config: "PreDetectorRule.Config", _, value: set[str]): + field_names_set_by_processor = {"rule_filter", "description", "pre_detection_id"} + + rule_config_field_names = set(f.name for f in fields(type(config))) + field_names_set_by_rule = rule_config_field_names - SPECIAL_FIELD_TYPES + + illegal_field_names = field_names_set_by_processor | field_names_set_by_rule + + if value & illegal_field_names: + raise ValueError( + f"Illegal fields specified for `copy_fields_to_detection_event`. " + f"Fields ({', '.join(value & illegal_field_names)}) are not allowed. " + ) + + +class PreDetectorRule(Rule): + """Check if documents match a filter.""" + + special_field_types = SPECIAL_FIELD_TYPES @define(kw_only=True) class Config(Rule.Config): # pylint: disable=too-many-instance-attributes @@ -160,15 +181,15 @@ class Config(Rule.Config): # pylint: disable=too-many-instance-attributes which can be configured in the pipeline for the pre_detector. 
        If this field was specified, then the rule will *only* trigger in case one of
        the IPs from the list is also available in the specified fields."""
-        sigma_fields: Union[list, bool] = field(
+        sigma_fields: list | bool = field(
             validator=validators.instance_of((list, bool)), factory=list
         )
         """tbd"""
-        link: Optional[str] = field(
+        link: str | None = field(
             validator=validators.optional(validators.instance_of(str)), default=None
         )
         """A link to the rule if applicable."""
-        source_format: list = field(
+        source_format: str = field(
             validator=validators.instance_of(str),
             default="ISO8601",
         )
@@ -187,6 +208,28 @@ class Config(Rule.Config):  # pylint: disable=too-many-instance-attributes
             validator=validators.instance_of(list), default=["pre_detector_failure"]
         )
         """ tags to be added if processing of the rule fails"""
+        copy_fields_to_detection_event: set[str] = field(
+            validator=[
+                validators.deep_iterable(
+                    member_validator=validators.instance_of(str),
+                    iterable_validator=validators.or_(
+                        validators.instance_of(set), validators.instance_of(list)
+                    ),
+                ),
+                _validate_copy_fields_to_detection_event,
+            ],
+            converter=set,
+            default={"host.name"},
+        )
+        """
+        Field names from the triggering event to be copied into the detection events.
+        Defaults to {"host.name"} for backwards compatibility.
+        """
+
+    @property
+    def config(self) -> Config:
+        """Provides the properly typed rule configuration object"""
+        return cast("PreDetectorRule.Config", self._config)
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PreDetectorRule):
@@ -200,36 +243,40 @@ def __eq__(self, other: object) -> bool:
 
     # pylint: disable=C0111
     @cached_property
-    def detection_data(self) -> dict:
+    def detection_data(self) -> MappingProxyType:
         detection_data = asdict(
-            self._config, filter=lambda attribute, _: attribute.name not in self.special_field_types
+            self.config, filter=lambda attribute, _: attribute.name not in self.special_field_types
         )
-        if self._config.link is None:
+        if self.config.link is None:
             del detection_data["link"]
-        return detection_data
+        return MappingProxyType(detection_data)
 
     @property
     def ip_fields(self) -> list:
-        return self._config.ip_fields
+        return self.config.ip_fields
 
     @property
     def description(self) -> str:
-        return self._config.description
+        return self.config.description
 
     @property
     def source_format(self) -> str:
-        return self._config.source_format
+        return self.config.source_format
 
     @property
     def target_timezone(self) -> str:
-        return self._config.target_timezone
+        return self.config.target_timezone
 
     @property
     def source_timezone(self) -> str:
-        return self._config.source_timezone
+        return self.config.source_timezone
 
     @property
     def timestamp_field(self) -> str:
-        return self._config.timestamp_field
+        return self.config.timestamp_field
+
+    @property
+    def copy_fields_to_detection_event(self) -> set[str]:
+        return self.config.copy_fields_to_detection_event
 
     # pylint: enable=C0111
diff --git a/logprep/util/helper.py b/logprep/util/helper.py
index d7ca60f9d..70d90cad4 100644
--- a/logprep/util/helper.py
+++ b/logprep/util/helper.py
@@ -3,10 +3,11 @@
 import itertools
 import re
 import sys
+from enum import Enum, auto
 from functools import lru_cache, partial, reduce
 from importlib.metadata import version
 from os import remove
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, TypeAlias, Union, cast
 
 from logprep.processor.base.exceptions import FieldExistsWarning
 from logprep.util.ansi import AnsiBack, AnsiFore, Back, Fore
@@ -17,9 +18,32 @@
     from logprep.util.configuration import Configuration
 
 
-def color_print_line(
-    back: Optional[Union[str, AnsiBack]], fore: Optional[Union[str, AnsiBack]], message: str
-):
+class Missing(Enum):
+    """Sentinel type for indicating missing fields."""
+
+    MISSING = auto()
+
+
+MISSING = Missing.MISSING  # pylint: disable=invalid-name
+"""Sentinel value for indicating missing fields."""
+
+
+class Skip(Enum):
+    """Sentinel type used to instruct field lookups to skip (omit) a field."""
+
+    SKIP = auto()
+
+
+SKIP = Skip.SKIP  # pylint: disable=invalid-name
+"""Sentinel value used to instruct field lookups to skip (omit) a field."""
+
+
+FieldValue: TypeAlias = Union[
+    dict[str, "FieldValue"], list["FieldValue"], str, int, float, bool, None
+]
+
+
+def color_print_line(back: str | AnsiBack | None, fore: str | AnsiFore | None, message: str):
     """Print string with colors and reset the color afterwards."""
     color = ""
     if back:
@@ -30,7 +54,8 @@ def color_print_line(
     print(color + message + Fore.RESET + Back.RESET)
 
 
-def color_print_title(background: Union[str, AnsiBack], message: str):
+def color_print_title(background: str | AnsiBack, message: str):
+    """Print a dashed title line with black foreground color and reset the color afterwards."""
     message = f"------ {message} ------"
     color_print_line(background, Fore.BLACK, message)
 
@@ -61,32 +86,36 @@ def _add_and_not_overwrite_key(sub_dict, key):
 def _add_field_to(
     event: dict,
     field: tuple,
-    rule: "Rule",
+    rule: Optional["Rule"],
     merge_with_target: bool = False,
     overwrite_target: bool = False,
 ) -> None:
     """
     Add content to the target_field in the given event. target_field can be a dotted subfield.
     In case of missing fields, all intermediate fields will be created.
+
     Parameters
     ----------
     event: dict
         Original log-event that logprep is currently processing
     field: tuple
-        A key value pair describing the field that should be added. The key is the dotted subfield string indicating
-        the target. The value is the content that should be added to the named target. The content can be of type
-        str, float, int, list, dict.
+        A key-value pair describing the field that should be added. The key is the dotted subfield string
+        indicating the target. The value is the content that should be added to the named target.
+        The content can be of type str, float, int, list, dict.
     rule: Rule
         A rule that initiated the field addition, is used for proper error handling.
-    merge_with_target: bool
-        Flag that determines whether the content should be merged with an existing target_field
-    overwrite_target: bool
-        Flag that determines whether the target_field should be overwritten by content
+    merge_with_target: bool, optional
+        Flag that determines whether the content should be merged with an existing target_field.
+        Defaults to False.
+    overwrite_target: bool, optional
+        Flag that determines whether the target_field should be overwritten by content.
+        Defaults to False.
+
     Raises
     ------
     FieldExistsWarning
-        If the target_field already exists and overwrite_target_field is False, or if extends_lists is True but
-        the existing field is not a list.
+        If the target_field already exists and overwrite_target is False,
+        or if merge_with_target is True but the existing field is not a list.
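+
+    Examples
+    --------
+    A minimal doctest-style sketch with made-up field names:
+
+    >>> event = {}
+    >>> _add_field_to(event, ("winlog.event_id", 123), None)
+    >>> event
+    {'winlog': {'event_id': 123}}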
""" if merge_with_target and overwrite_target: raise ValueError("Can't merge with and overwrite a target field at the same time") @@ -126,20 +155,19 @@ def _add_field_to( def _add_field_to_silent_fail(*args, **kwargs) -> None | str: """ - Adds a field to an object, ignoring the FieldExistsWarning if the field already exists. Is only needed in the - add_batch_to map function. Without this, the map would terminate early. + Adds a field to an object, ignoring the FieldExistsWarning if the field already exists. + Is only needed in the add_batch_to map function. Without this, the map would terminate early. - Parameters: + Parameters + ---------- args: tuple Positional arguments to pass to the add_field_to function. kwargs: dict Keyword arguments to pass to the add_field_to function. - Returns: + Returns + ------- The field that was attempted to be added, if the field already exists. - - Raises: - FieldExistsWarning: If the field already exists, but this warning is caught and ignored. """ try: _add_field_to(*args, **kwargs) @@ -151,33 +179,42 @@ def _add_field_to_silent_fail(*args, **kwargs) -> None | str: def add_fields_to( event: dict, fields: dict, - rule: "Rule" = None, + rule: Optional["Rule"] = None, merge_with_target: bool = False, overwrite_target: bool = False, + skip_none: bool = True, ) -> None: """ - Handles the batch addition operation while raising a FieldExistsWarning with all unsuccessful targets. + Handles the batch addition operation while raising a FieldExistsWarning with + all unsuccessful targets. - Parameters: + Parameters + ---------- event: dict The event object to which fields are to be added. fields: dict - A dict with key value pairs describing the fields that should be added. The key is the dotted subfield - string indicating the target. The value is the content that should be added to the named target. The - content can be of type: str, float, int, list, dict. - rule: Rule + A dict with key value pairs describing the fields that should be added. + The key is the dotted subfield string indicating the target. + The value is the content that should be added to the named target. + The content can be of type: str, float, int, list, dict. + rule: Rule, optional A rule that initiated the field addition, is used for proper error handling. - merge_with_target: bool + merge_with_target: bool, optional A boolean indicating whether to merge if the target field already exists. - overwrite_target: bool + Defaults to False. + overwrite_target: bool, optional A boolean indicating whether to overwrite the target field if it already exists. + Defaults to False. + skip_none: bool, optional + A boolean indicating whether to filter out None-valued fields. Defaults to True. - Raises: - FieldExistsWarning: If there are targets to which the content could not be added due to field - existence restrictions. + Raises + ------ + FieldExistsWarning: If there are targets to which the content could not be added due to + field existence restrictions. 
""" # filter out None values - fields = {key: value for key, value in fields.items() if value is not None} + fields = {key: value for key, value in fields.items() if not skip_none or value is not None} number_fields = len(dict(fields)) if number_fields == 1: _add_field_to(event, list(fields.items())[0], rule, merge_with_target, overwrite_target) @@ -195,50 +232,166 @@ def add_fields_to( raise FieldExistsWarning(rule, event, unsuccessful_targets_resolved) -def _get_slice_arg(slice_item): +def _get_slice_arg(slice_item) -> int | None: return int(slice_item) if slice_item else None -def _get_item(items, item): +def _get_item(container: FieldValue, key: str) -> FieldValue: + """ + Retrieves the value associated with given key from the container. + + This function supports: + - Getting a value by name from a dict ({ "K": X }, "K") -> X + - Getting a value by index from a list ([X, Y, Z], "1") -> Y + - Getting a value by slice from a list ([X, Y, Z], "1:") -> [Y, Z] + + The retrieved value itself can be a container type, + thus this function can be used to traverse a nested data structure. + + Parameters + ---------- + container : FieldValue + Container object where data is read from + key : str + Dictionary key, index or slice spec refering to the container + + Returns + ------- + FieldValue + The container value which is referenced by the key + + Raises + ------ + KeyError + The container is a dict, but key does not exist in it + IndexError + The container is a list, but key does not represent a valid index in it + ValueError + The container is not a dict, but key is neither slice nor integer index + TypeError + The key is not a valid slice or the container is neither a dict nor a list + """ try: - return dict.__getitem__(items, item) + return dict.__getitem__(cast(dict[str, FieldValue], container), key) except TypeError: - if ":" in item: - slice_args = map(_get_slice_arg, item.split(":")) - item = slice(*slice_args) + index_or_slice: slice | int + if ":" in key: + slice_args = map(_get_slice_arg, key.split(":")) + index_or_slice = slice(*slice_args) else: - item = int(item) - return list.__getitem__(items, item) + index_or_slice = int(key) + return list.__getitem__(cast(list[FieldValue], container), index_or_slice) -def get_dotted_field_value(event: dict, dotted_field: str) -> Optional[Union[dict, list, str]]: +def get_dotted_field_value(event: dict[str, FieldValue], dotted_field: str) -> FieldValue: """ Returns the value of a requested dotted_field by iterating over the event dictionary until the field was found. In case the field could not be found None is returned. Parameters ---------- - event: dict + event: dict[str, FieldValue] The event from which the dotted field value should be extracted dotted_field: str The dotted field name which identifies the requested value Returns ------- - dict_: dict, list, str - The value of the requested dotted field. + FieldValue + The value of the requested dotted field, which can be None. + None is also returnd when the field could not be found and silent_fail is True. + + Raises + ------ + KeyError, ValueError, TypeError, IndexError + Different errors which can be raised on missing fields and silent_fail is False. 
""" + current: FieldValue = event try: for field in get_dotted_field_list(dotted_field): - event = _get_item(event, field) - return event + current = _get_item(current, field) + return current except (KeyError, ValueError, TypeError, IndexError): return None +def get_dotted_field_value_with_explicit_missing( + event: dict[str, FieldValue], dotted_field: str +) -> FieldValue | Missing: + """ + Returns the value of a requested dotted_field by iterating over the event dictionary until the + field was found. In case the field could not be found None is returned. + + Parameters + ---------- + event: dict[str, FieldValue] + The event from which the dotted field value should be extracted + dotted_field: str + The dotted field name which identifies the requested value + + Returns + ------- + FieldValue | Missing + The value of the requested dotted field, which can be None. + None is also returnd when the field could not be found and silent_fail is True. + + Raises + ------ + KeyError, ValueError, TypeError, IndexError + Different errors which can be raised on missing fields and silent_fail is False. + """ + current: FieldValue = event + try: + for field in get_dotted_field_list(dotted_field): + current = _get_item(current, field) + return current + except (KeyError, ValueError, TypeError, IndexError): + return MISSING + + +def get_dotted_field_values( + event: dict, + dotted_fields: Iterable[str], + on_missing: Callable[[str], FieldValue | Skip] = lambda _: None, +) -> dict[str, FieldValue]: + """ + Extract the subset of fields from the dict by using the list of (potentially dotted) + field names as an allow list. + The behavior for fields targeted by the list but missing in the dict can be controlled + by a callback. + The callback allows for providing a replacement value, or - by returning SKIP - can + instruct the method to omit the field entirely from the extracted dict. + + + Parameters + ---------- + event : dict + The (potentially nested) dict where the values are sourced from + dotted_fields : Iterable[str] + The (potentially dotted) list of field names to extract + on_missing : Callable[[str], FieldValue | Skip], optional + The callback to control the behavior for missing fields, by default + `lambda _: None` which returns missing fields with `None` value + + Returns + ------- + dict[str, FieldValue] + The (potentially nested) sub-dict + """ + result: dict[str, FieldValue] = {} + for field_to_copy in dotted_fields: + value = get_dotted_field_value_with_explicit_missing(event, field_to_copy) + if value is MISSING: + value = on_missing(field_to_copy) + if value is SKIP: + continue + result[field_to_copy] = value + return result + + @lru_cache(maxsize=100000) def get_dotted_field_list(dotted_field: str) -> list[str]: - """make lookup of dotted field in the dotted_field_lookup_table and ensures + """Make lookup of dotted field in the dotted_field_lookup_table and ensures it is added if not found. Additionally, the string will be interned for faster followup lookups. @@ -255,7 +408,7 @@ def get_dotted_field_list(dotted_field: str) -> list[str]: return dotted_field.split(".") -def pop_dotted_field_value(event: dict, dotted_field: str) -> Optional[Union[dict, list, str]]: +def pop_dotted_field_value(event: dict, dotted_field: str) -> FieldValue: """ Remove and return dotted field. Returns None is field does not exist. 
@@ -340,7 +493,7 @@ def remove_file_if_exists(test_output_path):
 
 
 def camel_to_snake(camel: str) -> str:
-    """ensures that the input string is snake_case"""
+    """Ensures that the input string is snake_case"""
     _underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
     _underscorer2 = re.compile("([a-z0-9])([A-Z])")
 
@@ -350,7 +503,7 @@ def camel_to_snake(camel: str) -> str:
 
 
 def snake_to_camel(snake: str) -> str:
-    """ensures that the input string is CamelCase"""
+    """Ensures that the input string is CamelCase"""
     components = snake.split("_")
 
     if len(components) == 1:
@@ -364,13 +517,58 @@ def snake_to_camel(snake: str) -> str:
 append_as_list = partial(add_fields_to, merge_with_target=True)
 
 
+def copy_fields_to_event(
+    target_event: dict,
+    source_event: dict,
+    dotted_field_names: Iterable[str],
+    *,
+    skip_missing: bool = True,
+    merge_with_target: bool = False,
+    overwrite_target: bool = False,
+    rule: Optional["Rule"] = None,
+) -> None:
+    """
+    Copies fields from source_event to target_event.
+    The function behaves similarly to add_fields_to.
+
+    Parameters
+    ----------
+    target_event : dict
+        The field dictionary where fields are being added to in-place
+    source_event : dict
+        The field dictionary where field values are being read from
+    dotted_field_names : Iterable[str]
+        The list of (potentially dotted) field names to copy
+    skip_missing : bool, optional
+        Controls whether missing fields should be skipped or defaulted to None, by default True
+    merge_with_target : bool, optional
+        Controls whether already existing fields should be merged as a list, by default False
+    overwrite_target : bool, optional
+        Controls whether already existing fields should be overwritten, by default False
+    rule : Rule, optional
+        Contextual info for error handling, by default None
+    """
+    on_missing_result = SKIP if skip_missing else None
+    source_fields = get_dotted_field_values(
+        source_event, dotted_field_names, on_missing=lambda _: on_missing_result
+    )
+    add_fields_to(
+        target_event,
+        source_fields,
+        rule=rule,
+        overwrite_target=overwrite_target,
+        merge_with_target=merge_with_target,
+        skip_none=False,
+    )
+
+
 def add_and_overwrite(event, fields, rule, *_):
-    """wrapper for add_field_to"""
+    """Wrapper for add_fields_to"""
    add_fields_to(event, fields, rule, overwrite_target=True)
 
 
 def append(event, field, separator, rule):
-    """appends to event"""
+    """Appends to event"""
     target_field, content = list(field.items())[0]
     target_value = get_dotted_field_value(event, target_field)
     if not isinstance(target_value, list):
@@ -382,7 +580,7 @@ def append(event, field, separator, rule):
 
 
 def get_source_fields_dict(event, rule):
-    """returns a dict with dotted fields as keys and target values as values"""
+    """Returns a dict with dotted fields as keys and target values as values"""
     source_fields = rule.source_fields
     source_field_values = map(partial(get_dotted_field_value, event), source_fields)
     source_field_dict = dict(zip(source_fields, source_field_values))
diff --git a/pyproject.toml b/pyproject.toml
index 3a9e20e1c..8c4a21c99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -109,7 +109,7 @@ dev = [
     "asgiref",
     "pytest-asyncio",
     "pre-commit",
-    "mypy",
+    "mypy>=1.18.2",
     "types-requests",
     "psutil",
     "types-psutil",
diff --git a/tests/unit/ng/processor/pre_detector/test_pre_detector.py b/tests/unit/ng/processor/pre_detector/test_pre_detector.py
index 51c0e81d7..280f96006 100644
--- a/tests/unit/ng/processor/pre_detector/test_pre_detector.py
+++ b/tests/unit/ng/processor/pre_detector/test_pre_detector.py
@@ -1,9 +1,11 @@
 #
pylint: disable=missing-module-docstring # pylint: disable=protected-access +# pylint: disable=duplicate-code import re from copy import deepcopy import pytest +from deepdiff import DeepDiff from logprep.ng.event.log_event import LogEvent from logprep.ng.event.sre_event import SreEvent @@ -25,18 +27,15 @@ def test_perform_successful_pre_detection(self): document = {"winlog": {"event_id": 123, "event_data": {"ServiceName": "VERY BAD"}}} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_ONE_ID", - "title": "RULE_ONE", - "severity": "critical", - "mitre": ["attack.test1", "attack.test2"], - "case_condition": "directly", - "description": "Test rule one", - "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -50,18 +49,15 @@ def test_perform_pre_detection_that_fails_if_filter_children_were_slots(self): expected = deepcopy(document) event = LogEvent(document, original=b"") expected_detection_results = [ - ( - { - "case_condition": "directly", - "description": "Test rule four", - "id": "RULE_FOUR_ID", - "mitre": ["attack.test1", "attack.test2"], - "rule_filter": '(A:"*bar*" AND NOT ((A:"foo*" AND A:"*baz")))', - "severity": "critical", - "title": "RULE_FOUR", - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "case_condition": "directly", + "description": "Test rule four", + "id": "RULE_FOUR_ID", + "mitre": ["attack.test1", "attack.test2"], + "rule_filter": '(A:"*bar*" AND NOT ((A:"foo*" AND A:"*baz")))', + "severity": "critical", + "title": "RULE_FOUR", + }, ] event = self.object.process(event) _ = event.extra_data[0] @@ -79,19 +75,16 @@ def test_perform_successful_pre_detection_with_host_name(self): } expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_ONE_ID", - "title": "RULE_ONE", - "severity": "critical", - "mitre": ["attack.test1", "attack.test2"], - "case_condition": "directly", - "host": {"name": "Test hostname"}, - "description": "Test rule one", - "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "host": {"name": "Test hostname"}, + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -101,18 +94,15 @@ def test_perform_successful_pre_detection_with_same_existing_pre_detection(self) document = {"winlog": {"event_id": 123, "event_data": {"ServiceName": "VERY BAD"}}} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_ONE_ID", - "title": "RULE_ONE", - "severity": "critical", - "mitre": ["attack.test1", "attack.test2"], - "case_condition": "directly", - "description": "Test rule one", - "rule_filter": '(winlog.event_id:"123" AND 
winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + }, ] document["pre_detection_id"] = "11fdfc1f-8e00-476e-b88f-753d92af989c" @@ -124,19 +114,16 @@ def test_perform_successful_pre_detection_with_pre_detector_complex_rule_suceeds document = {"tags": "test", "process": {"program": "test"}, "message": "test1*xyz"} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_TWO_ID", - "title": "RULE_TWO", - "severity": "critical", - "mitre": [], - "case_condition": "directly", - "description": "Test rule two", - "rule_filter": '(tags:"test" AND process.program:"test" AND ' - '(message:"test1*xyz" OR message:"test2*xyz"))', - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_TWO_ID", + "title": "RULE_TWO", + "severity": "critical", + "mitre": [], + "case_condition": "directly", + "description": "Test rule two", + "rule_filter": '(tags:"test" AND process.program:"test" AND ' + '(message:"test1*xyz" OR message:"test2*xyz"))', + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -146,19 +133,16 @@ def test_perform_successful_pre_detection_with_pre_detector_complex_rule_succeed document = {"tags": "test2", "process": {"program": "test"}, "message": "test2Xxyz"} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_THREE_ID", - "title": "RULE_THREE", - "severity": "critical", - "mitre": [], - "case_condition": "directly", - "description": "Test rule three", - "rule_filter": '(tags:"test2" AND process.program:"test" AND ' - '(message:"test1*xyz" OR message:"test2?xyz"))', - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_THREE_ID", + "title": "RULE_THREE", + "severity": "critical", + "mitre": [], + "case_condition": "directly", + "description": "Test rule three", + "rule_filter": '(tags:"test2" AND process.program:"test" AND ' + '(message:"test1*xyz" OR message:"test2?xyz"))', + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -168,30 +152,24 @@ def test_perform_successful_pre_detection_with_two_rules(self): document = {"first_match": "something", "second_match": "something"} expected = deepcopy(document) expected_detection_results = [ - ( - { - "case_condition": "directly", - "id": "RULE_TWO_ID", - "mitre": ["attack.test2", "attack.test4"], - "description": "Test two rules two", - "rule_filter": '"second_match": *', - "severity": "suspicious", - "title": "RULE_TWO", - }, - ({"kafka": "pre_detector_alerts"},), - ), - ( - { - "case_condition": "directly", - "id": "RULE_ONE_ID", - "mitre": ["attack.test1", "attack.test2"], - "description": "Test two rules one", - "rule_filter": '"first_match": *', - "severity": "critical", - "title": "RULE_ONE", - }, - ({"kafka": "pre_detector_alerts"},), - ), + { + "case_condition": "directly", + "id": "RULE_ONE_ID", + "mitre": ["attack.test1", "attack.test2"], + "description": "Test two rules one", + "rule_filter": '"first_match": *', + "severity": "critical", + "title": "RULE_ONE", + }, + { + "case_condition": "directly", + "id": "RULE_TWO_ID", + "mitre": ["attack.test2", "attack.test4"], + "description": "Test two rules two", + "rule_filter": 
'"second_match": *', + "severity": "suspicious", + "title": "RULE_TWO", + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -275,18 +253,15 @@ def test_ignores_case(self): document = {"tags": "test", "process": {"program": "test"}, "message": "TEST2*xyz"} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_TWO_ID", - "title": "RULE_TWO", - "severity": "critical", - "mitre": [], - "case_condition": "directly", - "description": "Test rule two", - "rule_filter": '(tags:"test" AND process.program:"test" AND (message:"test1*xyz" OR message:"test2*xyz"))', # pylint: disable=line-too-long - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_TWO_ID", + "title": "RULE_TWO", + "severity": "critical", + "mitre": [], + "case_condition": "directly", + "description": "Test rule two", + "rule_filter": '(tags:"test" AND process.program:"test" AND (message:"test1*xyz" OR message:"test2*xyz"))', # pylint: disable=line-too-long + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -296,18 +271,15 @@ def test_ignores_case_list(self): document = {"tags": "test", "process": {"program": "test"}, "message": ["TEST2*xyz"]} expected = deepcopy(document) expected_detection_results = [ - ( - { - "id": "RULE_TWO_ID", - "title": "RULE_TWO", - "severity": "critical", - "mitre": [], - "case_condition": "directly", - "description": "Test rule two", - "rule_filter": '(tags:"test" AND process.program:"test" AND (message:"test1*xyz" OR message:"test2*xyz"))', # pylint: disable=line-too-long - }, - ({"kafka": "pre_detector_alerts"},), - ) + { + "id": "RULE_TWO_ID", + "title": "RULE_TWO", + "severity": "critical", + "mitre": [], + "case_condition": "directly", + "description": "Test rule two", + "rule_filter": '(tags:"test" AND process.program:"test" AND (message:"test1*xyz" OR message:"test2*xyz"))', # pylint: disable=line-too-long + }, ] event = LogEvent(document, original=b"") event = self.object.process(event) @@ -331,14 +303,16 @@ def _assert_equality_of_results( assert result_pre_detection_id is not None assert pre_detection_id == result_pre_detection_id - sorted_detection_results = sorted( - [(frozenset(sre_event.data), sre_event.outputs) for sre_event in event.extra_data] - ) - sorted_expected_detection_results = sorted( - [(frozenset(result[0]), result[1]) for result in expected_detection_results] - ) + for detection_result, expected_detection_result in zip( + detection_results, expected_detection_results + ): + diff = DeepDiff( + detection_result, + expected_detection_result, + exclude_paths=["root['id']", "root['rule_filter']"], + ) - assert sorted_detection_results == sorted_expected_detection_results + assert not diff def test_adds_timestamp_to_extra_data_if_provided_by_event(self): rule = { @@ -501,3 +475,188 @@ def test_generate_detection_result_does_not_modify_rule_data(self): assert ( "rule_filter" not in self.object.rules[0].detection_data ), "rule_filter should not be in detection data" + + @pytest.mark.parametrize( + "extra_rule_config, event_data, expected_extra_fields_in_output", + [ + pytest.param( + {}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {"host": {"name": "Test hostname"}}, + id="optional with default host.name", + ), + pytest.param( + {"copy_fields_to_detection_event": set()}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + 
}, + {}, + id="empty set is allowed", + ), + pytest.param( + {"copy_fields_to_detection_event": []}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {}, + id="empty list is allowed", + ), + pytest.param( + {"copy_fields_to_detection_event": ["custom", "winlog.custom"]}, + { + "host": {"name": "Test hostname"}, + "custom": "test toplevel", + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + {"custom": "test toplevel", "winlog": {"custom": "test nested"}}, + id="copy plain and nested field with list", + ), + pytest.param( + {"copy_fields_to_detection_event": {"custom", "winlog.custom"}}, + { + "host": {"name": "Test hostname"}, + "custom": "test toplevel", + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + {"custom": "test toplevel", "winlog": {"custom": "test nested"}}, + id="copy plain and nested field with set", + ), + pytest.param( + {"copy_fields_to_detection_event": {"int", "float", "str", "dict", "list"}}, + { + "dict": {"name": "Test hostname"}, + "list": [1, 2, 3], + "str": "test toplevel", + "int": 123, + "float": 12.0, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + { + "dict": {"name": "Test hostname"}, + "list": [1, 2, 3], + "str": "test toplevel", + "int": 123, + "float": 12.0, + }, + id="copy fields with different types", + ), + pytest.param( + {"copy_fields_to_detection_event": {"host.name", "custom"}}, + { + "host": {"name": None}, + "custom": 0, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + { + "host": {"name": None}, + "custom": 0, + }, + id="None-valued fields not skipped", + ), + pytest.param( + {"copy_fields_to_detection_event": {"host.name", "custom"}}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {"host": {"name": "Test hostname"}}, + id="missing fields skipped", + ), + ], + ) + def test_copy_fields_to_detection_event_matrix( + self, extra_rule_config: dict, event_data: dict, expected_extra_fields_in_output: dict + ): + self._load_rule( + { + "filter": "*", + "pre_detector": { + "id": "ac1f47e4-9f6f-4cd4-8738-795df8bd5d4f", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + **extra_rule_config, + }, + "description": "Test rule one", + } + ) + expected_detection_results = [ + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + **expected_extra_fields_in_output, + }, + ] + event = LogEvent(event_data, original=b"") + event = self.object.process(event) + self._assert_equality_of_results(event, event_data, expected_detection_results) + + @pytest.mark.parametrize( + "field_name", + [ + "rule_filter", + "description", + "pre_detection_id", + "id", + "title", + "severity", + "mitre", + "case_condition", + "link", + ], + ) + def test_copy_fields_to_detection_event_fails_on_illegal_fields(self, field_name: str): + with pytest.raises(ValueError, match="Illegal fields") as exc_info: + self._load_rule( + { + "filter": "*", + "pre_detector": { + 
"id": "ac1f47e4-9f6f-4cd4-8738-795df8bd5d4f", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "copy_fields_to_detection_event": {field_name}, + }, + "description": "Test rule one", + } + ) + assert exc_info is not None diff --git a/tests/unit/processor/pre_detector/test_pre_detector.py b/tests/unit/processor/pre_detector/test_pre_detector.py index cdb06fa21..441b6ca28 100644 --- a/tests/unit/processor/pre_detector/test_pre_detector.py +++ b/tests/unit/processor/pre_detector/test_pre_detector.py @@ -4,6 +4,7 @@ from copy import deepcopy import pytest +from deepdiff import DeepDiff from tests.unit.processor.base import BaseProcessorTestCase @@ -92,6 +93,32 @@ def test_perform_successful_pre_detection_with_host_name(self): document, expected, detection_results.data, expected_detection_results ) + def test_perform_successful_pre_detection_without_host_name(self): + document = { + # "host": {"name": "Test hostname"}, + "winlog": {"event_id": 123, "event_data": {"ServiceName": "VERY BAD"}}, + } + expected = deepcopy(document) + expected_detection_results = [ + ( + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + # "host": {"name": "Test hostname"}, + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + }, + ({"kafka": "pre_detector_alerts"},), + ) + ] + detection_results = self.object.process(document) + self._assert_equality_of_results( + document, expected, detection_results.data, expected_detection_results + ) + def test_perform_successful_pre_detection_with_same_existing_pre_detection(self): document = {"winlog": {"event_id": 123, "event_data": {"ServiceName": "VERY BAD"}}} expected = deepcopy(document) @@ -166,18 +193,6 @@ def test_perform_successful_pre_detection_with_two_rules(self): document = {"first_match": "something", "second_match": "something"} expected = deepcopy(document) expected_detection_results = [ - ( - { - "case_condition": "directly", - "id": "RULE_TWO_ID", - "mitre": ["attack.test2", "attack.test4"], - "description": "Test two rules two", - "rule_filter": '"second_match": *', - "severity": "suspicious", - "title": "RULE_TWO", - }, - ({"kafka": "pre_detector_alerts"},), - ), ( { "case_condition": "directly", @@ -190,6 +205,18 @@ def test_perform_successful_pre_detection_with_two_rules(self): }, ({"kafka": "pre_detector_alerts"},), ), + ( + { + "case_condition": "directly", + "id": "RULE_TWO_ID", + "mitre": ["attack.test2", "attack.test4"], + "description": "Test two rules two", + "rule_filter": '"second_match": *', + "severity": "suspicious", + "title": "RULE_TWO", + }, + ({"kafka": "pre_detector_alerts"},), + ), ] detection_results = self.object.process(document) self._assert_equality_of_results( @@ -319,14 +346,16 @@ def _assert_equality_of_results( assert result_pre_detection_id is not None assert pre_detection_id == result_pre_detection_id - sorted_detection_results = sorted( - [(frozenset(result[0]), result[1]) for result in detection_results] - ) - sorted_expected_detection_results = sorted( - [(frozenset(result[0]), result[1]) for result in expected_detection_results] - ) + for detection_result, expected_detection_result in zip( + detection_results, expected_detection_results + ): + diff = DeepDiff( + detection_result, + expected_detection_result, + exclude_paths=["root[0]['id']", 
"root[0]['rule_filter']"], + ) - assert sorted_detection_results == sorted_expected_detection_results + assert not diff def test_adds_timestamp_to_extra_data_if_provided_by_event(self): rule = { @@ -464,3 +493,194 @@ def test_appends_processing_warning_if_timestamp_could_not_be_parsed(self): assert "tags" in document assert "_pre_detector_failure" in document["tags"] assert "_pre_detector_timeparsing_failure" in document["tags"] + + @pytest.mark.parametrize( + "extra_rule_config, event_data, expected_extra_fields_in_output", + [ + pytest.param( + {}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {"host": {"name": "Test hostname"}}, + id="optional with default host.name", + ), + pytest.param( + {"copy_fields_to_detection_event": set()}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {}, + id="empty set is allowed", + ), + pytest.param( + {"copy_fields_to_detection_event": []}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {}, + id="empty list is allowed", + ), + pytest.param( + {"copy_fields_to_detection_event": ["custom", "winlog.custom"]}, + { + "host": {"name": "Test hostname"}, + "custom": "test toplevel", + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + {"custom": "test toplevel", "winlog": {"custom": "test nested"}}, + id="copy plain and nested field with list", + ), + pytest.param( + {"copy_fields_to_detection_event": {"custom", "winlog.custom"}}, + { + "host": {"name": "Test hostname"}, + "custom": "test toplevel", + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + {"custom": "test toplevel", "winlog": {"custom": "test nested"}}, + id="copy plain and nested field with set", + ), + pytest.param( + {"copy_fields_to_detection_event": {"int", "float", "str", "dict", "list"}}, + { + "dict": {"name": "Test hostname"}, + "list": [1, 2, 3], + "str": "test toplevel", + "int": 123, + "float": 12.0, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + "custom": "test nested", + }, + }, + { + "dict": {"name": "Test hostname"}, + "list": [1, 2, 3], + "str": "test toplevel", + "int": 123, + "float": 12.0, + }, + id="copy fields with different types", + ), + pytest.param( + {"copy_fields_to_detection_event": {"host.name", "custom"}}, + { + "host": {"name": None}, + "custom": 0, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + { + "host": {"name": None}, + "custom": 0, + }, + id="None-valued fields not skipped", + ), + pytest.param( + {"copy_fields_to_detection_event": {"host.name", "custom"}}, + { + "host": {"name": "Test hostname"}, + "winlog": { + "event_id": 123, + "event_data": {"ServiceName": "VERY BAD"}, + }, + }, + {"host": {"name": "Test hostname"}}, + id="missing fields skipped", + ), + ], + ) + def test_copy_fields_to_detection_event_matrix( + self, extra_rule_config: dict, event_data: dict, expected_extra_fields_in_output: dict + ): + self._load_rule( + { + "filter": "*", + "pre_detector": { + "id": "ac1f47e4-9f6f-4cd4-8738-795df8bd5d4f", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + **extra_rule_config, + }, + "description": "Test rule one", + } + ) + 
document = event_data + expected = deepcopy(document) + expected_detection_results = [ + ( + { + "id": "RULE_ONE_ID", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "description": "Test rule one", + "rule_filter": '(winlog.event_id:"123" AND winlog.event_data.ServiceName:"VERY BAD")', # pylint: disable=line-too-long + **expected_extra_fields_in_output, + }, + ({"kafka": "pre_detector_alerts"},), + ) + ] + detection_results = self.object.process(document) + self._assert_equality_of_results( + document, expected, detection_results.data, expected_detection_results + ) + + @pytest.mark.parametrize( + "field_name", + [ + "rule_filter", + "description", + "pre_detection_id", + "id", + "title", + "severity", + "mitre", + "case_condition", + "link", + ], + ) + def test_copy_fields_to_detection_event_fails_on_illegal_fields(self, field_name: str): + with pytest.raises(ValueError, match="Illegal fields") as exc_info: + self._load_rule( + { + "filter": "*", + "pre_detector": { + "id": "ac1f47e4-9f6f-4cd4-8738-795df8bd5d4f", + "title": "RULE_ONE", + "severity": "critical", + "mitre": ["attack.test1", "attack.test2"], + "case_condition": "directly", + "copy_fields_to_detection_event": {field_name}, + }, + "description": "Test rule one", + } + ) + assert exc_info is not None