diff --git a/.dockerignore b/.dockerignore index b46d4cbc..f17bf4d6 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,14 +5,17 @@ .pytest_cache .ruff_cache .vscode +__pycache__ +build/ dist/ docs/ examples/ figures/ logs/ -mlruns/ +netsecgame.egg-info/ notebooks/ NetSecGameAgents/ +site/ tests/ trajectories/ readme_images/ diff --git a/AIDojoCoordinator/docs/Components.md b/AIDojoCoordinator/docs/Components.md deleted file mode 100644 index bbe13c3d..00000000 --- a/AIDojoCoordinator/docs/Components.md +++ /dev/null @@ -1,97 +0,0 @@ -# Game Components -Here, you can see the details of all components of the NetSetEnvironment and their usage. These components are located in [game_components.py](game_components.py). - -## Building blocks -The following classes are used in the game to hold information about the state of the game. They are used both in the [Actions](#actions) and [GameState](#gamestate). - -### IP -IP is immutable object that represents an IPv4 object in the NetSecGame. It has a single parameter of the address in a dot-decimal notation (4 octet represeted as decimal value separeted by dots). - -Example: `ip = IP("192.168.1.1")` - -### Network -Network is immutable object that represents an IPv4 network object in the NetSecGame. It has 2 parameters: -- `network_ip:str` representing the IPv4 address of the network. -- `mask:int` representing the mask in the CIDR notation. - -Example: `net = Network("192.168.1.0", 24)` - -## Service -Service class holds information about services running in hosts. Each Service has four parameters: -- `name`:str - Name of the service (e.g., "SSH") -- `type`:str - `passive` or `active`. Currently not being used. -- `version`:str - version of the service. -- `is_local`:bool - flag specifying if the service is local only. (if `True`, service is NOT visible without controlling the host). - -Example: `s = Service('postgresql', 'passive', '14.3.0', False)` - -## Data -Data class holds information about datapoints (files) present in the NetSecGame. Datapoints DO NOT hold the content of files. -Each data instance has two parameters: -- `owner`:str - specifying the user who owns this datapoint -- `id`: str - unique identifier of the datapoint in a host -- `size`: int - size of the datapoint (optional, default=0) -- `type`: str - identification of a type of the file (optional, default="") - -Examples:`Data("User1", "DatabaseData")`, `Data("User1", "DatabaseData", size=42, type="txt")` - -## GameState -GameState is an object that represents a view of the NetSecGame environment in a given state. It is constructed as a collection of 'assets' available to the agent. GameState has following parts: -- `known_networks`: Set of [Network](#network) objects that the agent is aware of -- `known_hosts`: Set of [IP](#ip) objects that the agent is aware of -- `controlled_hosts`: Set of [IP](#ip) objetcs that the agent has control over. Note that `controlled_hosts` is a subset of `known_hosts`. -- `known_services`: Dictionary of services that the agent is aware of. -The dictionary format: {`IP`: {`Service`}} where [IP](#ip) object is a key and the value is a set of [Service](#service) objects located in the `IP`. -- `known_data`: Dictionary of data instances that the agent is aware of. The dictionary format: {`IP`: {`Data`}} where [IP](#ip) object is a key and the value is a set of [Data](#data) objects located in the `IP`. -- `known_blocks`: Dictionary of firewall blocks the agent is aware of. It is a dictionary with format: {`target_IP`: {`blocked_IP`, `blocked_IP`}}. Where `target_IP` is the [IP](#ip) where the FW rule was applied (usually a router) and `blocked_IP` is the IP address that is blocked. For now the blocks happen in both input and output direction simultaneously. - - -## Actions -Actions are the objects sent by the agents to the environment. Each action is evaluated by AIDojo and executed if -1. Is a valid Action -2. Can be processed in the current state of the environment - -In all cases, when an agent sends an action to AIDojo, it is given a response. -### Action format -The Action class is defined in `game_components.py`. It has two basic parts: -1. ActionType:Enum -2. parameters:dict - -ActionType is unique Enum that determines what kind of action is agent playing. Parameters are passed in a dictionary as follows. -### List of actions -- **JoinGame**, params={`agent_info`:AgentInfo(\, \)}: Used to register agent in a game with a given \. -- **QuitGame**, params={}: Used for termination of agent's interaction. -- **ResetGame**, params={`request_trajectory`:`bool`(default=`False`), `randomize_topology`:`bool` (default=`False`)}: Used for requesting reset of the game to it's initial position. If `request_trajectory = True`, the coordinator will send back the complete trajectory of the previous run in the next message. -If `randomize_topology=True`, the agent requests randomization of IPs for the next episode. NOTE: randomization takes place only if all playing agents request it. ---- -- **ScanNetwork**, params{`source_host`:\, `target_network`:\}: Scans the given \ from a specified source host. Discovers ALL hosts in a network that are accessible from \. If successful, returns set of discovered \ objects. -- **FindServices**, params={`source_host`:\, `target_host`:\}: Used to discover ALL services running in the `target_host` if the host is accessible from `source_host`. If successful, returns a set of all discovered \ objects. -- **FindData**, params={`source_host`:\, `target_host`:\}: Searches `target_host` for data. If `source_host` differs from `target_host`, success depends on accessability from the `source_host`. If successful, returns a set of all discovered \ objects. -- **ExploitService**, params={`source_host`:\, `target_host`:\, `taget_service`:\}: Exploits `target_service` in a specified `target_host`. If successful, the attacker gains control of the `target_host`. -- **ExfiltrateData**, params{`source_host`:\, `target_host`:\, `data`:\}: Copies `data` from the `source_host` to `target_host` IF both are controlled and `target_host` is accessible from `source_host`. - -### Action preconditions and effects -In the following table, we describe the effects of selected actions and their preconditions. Note that if the preconditions are not satisfied, the actions's effects are not applied. - -| Action | Params | Preconditions | Effects | -|----------------------|----------------------|----------------------|----------------------| -| ScanNetwork| `source_host`, `target_network`| `source_host` ∈ `controlled_hosts`| extends `known_networks`| -|FindServices| `source_host`, `target_host`| `source_host` ∈ `controlled_hosts`| extends `known_services` AND `known_hosts`| -|FindData| `source_host`, `target_host`| `source_host`, `target_host` ∈ `controlled_hosts`| extends `known_data`| -|Exploit Service | `source_host`, `target_host`, `target_service`|`source_host` ∈ `controlled_hosts`| extends `controlled_hosts` with `target_host`| -ExfiltrateData| `source_host`,`target_host`, `data` |`source_host`, `target_host` ∈ `controlled_hosts` AND `data` ∈ `known_data`| extends `known_data[target_host]` with `data`| - -#### Assumption and Conditions for Actions -1. When playing the `ExploitService` action, it is expected that the agent has discovered this service before (by playing `FindServices` in the `target_host` before this action) -2. The `Find Data` action finds all the available data in the host if successful. -3. The `Find Data` action requires ownership of the target host. -4. Playing `ExfiltrateData` requires controlling **BOTH** source and target hosts -5. Playing `Find Services` can be used to discover hosts (if those have any active services) -6. Parameters of `ScanNetwork` and `FindServices` can be chosen arbitrarily (they don't have to be listed in `known_newtworks`/`known_hosts`) - -## Observations -After submitting Action `a` to the environment, agents receive an `Observation` in return. Each observation consists of 4 parts: -- `state`:`Gamestate` - with the current view of the environment [state](#gamestate) -- `reward`: `int` - with the immediate reward agent gets for playing Action `a` -- `end`:`bool` - indicating if the interaction can continue after playing Action `a` -- `info`: `dict` - placeholder for any information given to the agent (e.g., the reason why `end is True` ) diff --git a/AIDojoCoordinator/docs/Coordinator.md b/AIDojoCoordinator/docs/Coordinator.md deleted file mode 100644 index 28a524fb..00000000 --- a/AIDojoCoordinator/docs/Coordinator.md +++ /dev/null @@ -1,49 +0,0 @@ -# Coordinator -Coordinator is the centerpiece of the game orchestration. It provides an interface between the agents and the AIDojo world. - -1. Registration of new agents in the game -2. Verification of agents' action format -3. Recording (and storing) trajectories of agents -4. Detection of episode ends (either by reaching timout or agents reaching their respective goals) -5. Assigning rewards for each action and at the end of each episode -6. Removing agents from the game -7. Registering the GameReset requests and handelling the game resets. - -## Connction to other game components -Coordinator, having the role of the middle man in all communication between the agent and the world uses several queues for massing passing and handelling. - -1. `Action queue` is a queue in which the agents submit their actions. It provides N:1 communication channel in which the coordinator receives the inputs. -2. `Answer queues` is a separeate queue **per agent** in which the results of the actions are send to the agent. - - -## Main components of the coordinator -`self._actions_queue`: asycnio queue for agents -> coordinator communication -`self._answer_queues`: dictionary of asycnio queues for coordinator -> agent communication (1 queue per agent) -`self._world_action_queue`: asycnio queue for coordinator -> world queue communication -`self._world_response_queue`: asycnio queue for world -> coordinator queue communication -`self.task_config`: Object with the configuration of the scenario -`self.ALLOWED_ROLES`: list of allowed agent roles [`Attacker`, `Defender`, `Benign`] -`self._world`: Instance of `AIDojoWorld`. Implements the dynamics of the world -`self._CONFIG_FILE_HASH`: hash of the configuration file used in the interaction (scenario, topology, etc.). Used for better reproducibility of results -`self._starting_positions_per_role`: dictionary of starting position of each agent type from `self.ALLOWED_ROLES` -`self._win_conditions_per_role`: dictionary of goal state for each agent type from `self.ALLOWED_ROLES` -`self._goal_description_per_role`: dictionary of textual description of goal of each agent type from `self.ALLOWED_ROLES` -`self._steps_limit_per_role`: dictionary of maximum allowed steps per episode for of each agent type from `self.ALLOWED_ROLES` -`self._use_global_defender`: Inditaction of presence of Global defender (deprecated) - -### Agent information components -`self.agents`: information about connected agents {`agent address`: (`agent_name`,`agent_role`)} -`self._agent_steps`: step counter for each agent in the current episode -`self._reset_requests`: dictionary where requests for episode reset are collected (the world resets only if **all** active agents request reset) -`self._randomize_topology_requests`: dictionary where requests for topology randomization are collected (the world randomizes the topology only if **all** active agents request reset) -`self._agent_observations`: current observation per agent -`self._agent_starting_position`: starting position (with wildcards, see [configuration](../README.md#task-configuration)) per agent -`self._agent_states`: current GameState per agent -`self._agent_last_action`: last Action per agent -`self._agent_statuses`: status of each agent. One of AgentStatus -`self._agent_rewards`: dictionary of final reward of each agent in the current episod. Only agent's which can't participate in the ongoing episode are listed. -`self._agent_trajectories`: complete trajectories for each agent in the ongoing episode - - -## Episode -The episode starts with sufficient amount of agents registering in the game. Each agent role has a maximum allowed number of steps defined in the task configuration. An episode ends if all agents reach the goal \ No newline at end of file diff --git a/AIDojoCoordinator/docs/Trajectory_analysis.md b/AIDojoCoordinator/docs/Trajectory_analysis.md deleted file mode 100644 index e7c330de..00000000 --- a/AIDojoCoordinator/docs/Trajectory_analysis.md +++ /dev/null @@ -1,31 +0,0 @@ -# Trajectories and Trajectory analusis -Trajectories capture interactions of agents in AI Dojo. They can be stored in a file for future analysis using the configuration option `save_trajectories: True` in `env` section of the task configuration file. Trajectories are stored in a JSON format, one JSON object per line using [jsonlines](https://jsonlines.readthedocs.io/en/latest/). - -### Example of the trajectory -Below we show an example of a trajectory consisting only from 1 step. Starting from state *S1*, the agent takes action*A1* and moves to state *S2* and is awarded with immediate reward `r = -1`: -```json -{ - "agent_name": "ExampleAgent", - "agent_role": "Attacker", - "end_reason": "goal_reached", - "trajectory": - { - "states":[ - "", - "" - ], - "actions":[ - "" - ], - "rewards":[-1] - } -} -``` -`agent_name` and `agent_role` are provided by the agent upon registration in the game. `end_reason` identifies how did the episode end. Currently there are four options: -1. `goal_reached` - the attacker succcessfully reached the goal state and won the game -2. `detected` - the attacker was detected by the defender subsequently lost the game -3. `max_steps` - the agent used the max allowed amount of steps and the episode was terminated -4. `None` - the episode was interrupted before ending and the trajectory is incomplete. - -## Trajectory analysis - diff --git a/AIDojoCoordinator/docs/figures/architecture_diagram.jpg b/AIDojoCoordinator/docs/figures/architecture_diagram.jpg deleted file mode 100644 index 7c6db506..00000000 Binary files a/AIDojoCoordinator/docs/figures/architecture_diagram.jpg and /dev/null differ diff --git a/AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg b/AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg deleted file mode 100644 index 703ef42f..00000000 Binary files a/AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg and /dev/null differ diff --git a/AIDojoCoordinator/worlds/__init__.py b/AIDojoCoordinator/worlds/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/Dockerfile b/Dockerfile index 3be8548a..58a470ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM python:3.12.10-slim # Set the working directory in the container -ENV DESTINATION_DIR=/aidojo +ENV DESTINATION_DIR=/netsecgame # Install system dependencies @@ -20,13 +20,13 @@ WORKDIR ${DESTINATION_DIR} # Install any necessary Python dependencies # If a requirements.txt file is in the repository -RUN if [ -f pyproject.toml ]; then pip install . ; fi +RUN if [ -f pyproject.toml ]; then pip install .[server] ; fi # Expose the port the coordinator will run on EXPOSE 9000 # Run the Python script when the container launches (with default arguments --task_config=netsecenv_conf.yaml --game_port=9000 --game_host=0.0.0.0) -ENTRYPOINT ["python3", "-m", "AIDojoCoordinator.worlds.NSEGameCoordinator", "--task_config=netsecenv_conf.yaml", "--game_port=9000", "--game_host=0.0.0.0"] +ENTRYPOINT ["python3", "-m", "netsecgame.game.worlds.NetSecGame", "--task_config=netsecenv_conf.yaml", "--game_port=9000", "--game_host=0.0.0.0"] # Default command arguments (can be overridden at runtime) CMD ["--debug_level=INFO"] diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..2410222d --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +exclude netsecgame/game/worlds/CYSTCoordinator.py +exclude netsecgame/game/worlds/RealWorldNetSecGame.py +exclude netsecgame/utils/trajectory_analysis.py +exclude netsecgame/utils/actions_parser.py +exclude netsecgame/utils/gamaplay_graphs.py +exclude netsecgame/utils/log_parser.py diff --git a/NetSecGameAgents b/NetSecGameAgents index 3fa45825..8ac0af82 160000 --- a/NetSecGameAgents +++ b/NetSecGameAgents @@ -1 +1 @@ -Subproject commit 3fa45825b466e9555aa466e4126b3adc1e9c668d +Subproject commit 8ac0af82d6d9769e73aab3926ff98ee996ed66d2 diff --git a/README.md b/README.md index ab640482..6d42ed0f 100755 --- a/README.md +++ b/README.md @@ -8,42 +8,47 @@ The NetSecGame (Network Security Game) is a framework for training and evaluation of AI agents in network security tasks (both offensive and defensive). It is built with [CYST](https://pypi.org/project/cyst/) network simulator and enables rapid development and testing of AI agents in highly configurable scenarios. Examples of implemented agents can be seen in the submodule [NetSecGameAgents](https://github.com/stratosphereips/NetSecGameAgents/tree/main). ## Installation Guide -It is recommended to install the NetSecGame in a virtual environment: -### Python venv -1. +It is recommended to run the environment in the Docker container. The up-to-date image can be found in [Dockerhub](https://hub.docker.com/r/stratosphereips/netsecgame). +```bash +docker pull stratosphereips/netsecgame +``` +#### Building the image locally +Optionally, you can build the image locally with: +```bash +docker build -t netsecgame:local . +``` + +### Installing from source +In case you need to modify the envirment and run directly, we recommed to insall it in a virtual environemnt (Python vevn or Conda): +#### Python venv +1. Create new virtual environment ```bash python -m venv ``` -2. +2. Activate newly created venv ```bash source /bin/activate ``` -### Conda -1. +#### Conda +1. Create new conda environment ```bash conda create --name aidojo python==3.12 ``` -2. +2. Activate newly created conda env ```bash conda activate aidojo ``` -After the virtual environment is activated, install using pip: +### After preparing virutual environment, install using pip: ```bash pip install -e . ``` -### With Docker -The NetSecGame can be run in a Docker container. You can build the image locally with: -```bash -docker build -t aidojo-nsg-coordinator:latest . -``` -or use the available image from [Dockerhub](https://hub.docker.com/r/stratosphereips/netsecgame). -```bash -docker pull stratosphereips/netsecgame -``` + ## Quick Start -A task configuration needs to be specified to start the NetSecGame (see [Configuration](configuration.md)). For the first step, the example task configuration is recommended: +A task configuration YAML file is required for starting the NetSecGame environment. For the first step, the example task configuration is recommended: + +### Example Configuration ```yaml # Example of the task configuration for NetSecGame # The objective of the Attacker in this task is to locate specific data @@ -51,11 +56,11 @@ A task configuration needs to be specified to start the NetSecGame (see [Configu # The scenario starts AFTER the initial breach of the local network # (the attacker controls 1 local device + the remote C&C server). -coordinator: - agents: +coordinator: + agents: Attacker: # Configuration of 'Attacker' agents - max_steps: 25 - goal: + max_steps: 25 # timout set for the role `Attacker` + goal: # Definition of the goal state description: "Exfiltrate data from Samba server to remote C&C server." is_any_part_of_goal_random: True known_networks: [] @@ -64,10 +69,10 @@ coordinator: known_services: {} known_data: {213.47.23.195: [[User1,DataFromServer1]]} # winning condition known_blocks: {} - start_position: # Defined starting position of the attacker + start_position: # Definition of the starting state (keywords "random" and "all" can be used) known_networks: [] known_hosts: [] - controlled_hosts: [213.47.23.195, random] # + controlled_hosts: [213.47.23.195, random] # keyword 'random' will be replaced by randomly selected IP during initilization known_services: {} known_data: {} known_blocks: {} @@ -92,93 +97,117 @@ coordinator: blocked_ips: {} known_blocks: {} -env: +env: # Environment configuraion scenario: 'two_networks_tiny' # use the smallest topology for this example use_global_defender: False # Do not use global SIEM Defender use_dynamic_addresses: False # Do not randomize IP addresses use_firewall: True # Use firewall save_trajectories: False # Do not store trajectories - required_players: 1 + required_players: 1 # Minimal amount of agents requiered to start the game rewards: # Configurable reward function success: 100 step: -1 fail: -10 false_positive: -5 ``` +For detailed configuration instructions, please refer to the [Configuration Documentation](https://stratosphereips.github.io/NetSecGame/configuration/). -The game can be started with: -```bash -python3 -m AIDojoCoordinator.worlds.NSEGameCoordinator \ - --task_config=./examples/example_config.yaml \ - --game_port=9000 -``` -Upon which the game server is created on `localhost:9000` to which the agents can connect to interact in the NetSecGame. -### Docker Container -When running in the Docker container, the NetSecGame can be started with: +### Starting the NetSecGame +With the configuration ready the environment can be started in selected port +#### In Docker container ```bash -docker run -it --rm \ - -v $(pwd)/examples/example_task_configuration.yaml:/aidojo/netsecenv_conf.yaml \ - -v $(pwd)/logs:/aidojo/logs \ +docker run -d --rm --name nsg-server\ + -v $(pwd)/examples/example_task_configuration.yaml:/netsecgame/netsecenv_conf.yaml \ + -v $(pwd)/logs:/netsecgame/logs \ -p 9000:9000 stratosphereips/netsecgame + --debug_level="INFO" +``` +`--name nsg-server`: specifies the name of the container + +`-v :/netsecgame/netsecenv_conf.yaml` : Mapping of the configuration file + +`-v $(pwd)/logs:/netsecgame/logs`: Mapping of the folder where logs are stored + +` -p :9000`: Mapping of the port in which the server runs + +`--debug_level` is an optional parameter to control the logging level `--debug_level=["DEBUG", "INFO", "WARNING", "CRITICAL"]` (defaul=`"INFO"`): +##### Running on Windows (with Docker desktop) +When running on Windows, Docker desktop is required. +```cmd +docker run -d --rm --name netsecgame-server ^ + -p 9000:9000 ^ + -v "%cd%\examples\example_task_configuration.yaml:/netsecgame/netsecenv_conf.yaml" ^ + -v "%cd%\logs:/netsecgame/logs" ^ + stratosphereips/netsecgame:latest + --debug_level="INFO" ``` -optionally, you can set the logging level with `--debug_level=["DEBUG", "INFO", "WARNING", "CRITICAL"]` (defaul=`"INFO"`): +#### Locally +The environment can be started locally with from the root folder of the repository with following command: ```bash -docker run -it --rm \ - -v $(pwd)/examples/example_task_configuration.yaml:/aidojo/netsecenv_conf.yaml \ - -v $(pwd)/logs:/aidojo/logs \ - -p 9000:9000 stratosphereips/netsecgame \ - --debug_level="WARNING" +python3 -m netsecgame.game.worlds.NetSecGame \ + --task_config=./examples/example_task_configuration.yaml \ + --game_port=9000 + --debug_level="INFO" ``` +Upon which the game server is created on `localhost:9000` to which the agents can connect to interact in the NetSecGame. ## Documentation You can find user documentation at [https://stratosphereips.github.io/NetSecGame/](https://stratosphereips.github.io/NetSecGame/) -## Components of the NetSecGame Environment + +### Components of the NetSecGame Environment The architecture of the environment can be seen [here](docs/Architecture.md). The NetSecGame environment has several components in the following files: ``` -├── AIDojoGameCoordinator/ -| ├── game_coordinator.py -| ├── game_components.py -| ├── global_defender.py -| ├── worlds/ -| ├── NSGCoordinator.py -| ├── NSGRealWorldCoordinator.py -| ├── CYSTCoordinator.py -| ├── scenarios/ -| ├── tiny_scenario_configuration.py -| ├── smaller_scenario_configuration.py -| ├── scenario_configuration.py -| ├── three_net_configuration.py +├── netsecgame/ +| ├── agents/ +| ├── base_agent.py # Basic agent class. Defines the API for agent-server communication +| ├── game/ +| ├── scenarios/ +| ├── tiny_scenario_configuration.py +| ├── smaller_scenario_configuration.py +| ├── scenario_configuration.py +| ├── three_net_scenario.py +| ├── worlds/ +| ├── NetSecGame.py # (NSG) basic simulation +| ├── RealWorldNetSecGame.py # Extension of `NSG` - runs actions in the *network of the host computer* +| ├── CYSTCoordinator.py # Extension of `NSG` - runs simulation in CYST engine. +| ├── WhiteBoxNetSecGame.py # Extension of `NSG` - provides agents with full list of actions upon registration. +| ├── agent_server.py # Agent server implementation +| ├── config_parser.py # NSG task configuration parser +| ├── configuration_manager.py # Helper tool to collect and parse query configuration of the game. +| ├── coordinator.py # Core game server. Not to be run as stand-alone world (see worlds/) +| ├── global_defender.py # Stochastic (non-agentic defender) +| ├── game_components.py # contains basic building blocks of the environment | ├── utils/ | ├── utils.py | ├── log_parser.py | ├── gamaplay_graphs.py | ├── actions_parser.py -``` - -### Directory Details +``` +#### Directory Details - `coordinator.py`: Basic coordinator class. Handles agent communication and coordination. **Does not implement dynamics of the world** and must be extended (see examples in `worlds/`). -- `game_components.py`: Implements a library with objects used in the environment. See [detailed explanation](AIDojoCoordinator/docs/Components.md) of the game components. +- `game_components.py`: Implements a library with objects used in the environment. See [detailed explanation](./docs/game_components.md) of the game components. - `global_defender.py`: Implements a global (omnipresent) defender that can be used to stop agents. Simulation of SIEM. -#### **`worlds/`** +##### **`worlds/`** Modules for different world configurations: -- `NSGCoordinator.py`: Coordinator for the Network Security Game. -- `NSGRealWorldCoordinator.py`: Real-world NSG coordinator (actions are executed in the *real network*). +- `NetSecGame.py`: Coordinator for the Network Security Game. +- `RealWorldNetSecGame.py`: Real-world NSG coordinator (actions are executed in the *real network*). - `CYSTCoordinator.py`: Coordinator for CYST-based simulations (requires CYST running). -#### **`scenarios/`** +##### **`scenarios/`** Predefined scenario configurations: - `tiny_scenario_configuration.py`: A minimal example scenario. - `smaller_scenario_configuration.py`: A compact scenario configuration used for development and rapid testing. - `scenario_configuration.py`: The main scenario configuration. -- `three_net_configuration.py`: Configuration for a three-network scenario. Used for the evaluation of the model overfitting. +- `three_net_scenario.py`: Configuration for a three-network scenario. Used for the evaluation of the model overfitting. + Implements the network game's configuration of hosts, data, services, and connections. It is taken from [CYST](https://pypi.org/project/cyst/). -#### **`utils/`** +##### **`utils/`** Helper modules: - `utils.py`: General-purpose utilities. - `log_parser.py`: Tools for parsing game logs. @@ -188,9 +217,6 @@ Helper modules: The [scenarios](#definition-of-the-network-topology) define the **topology** of a network (number of hosts, connections, networks, services, data, users, firewall rules, etc.) while the [task-configuration](#task-configuration) is to be used for definition of the exact task for the agent in one of the scenarios (with fix topology). - Agents compatible with the NetSecGame are located in a separate repository [NetSecGameAgents](https://github.com/stratosphereips/NetSecGameAgents/tree/main) - - - ### Assumptions of the NetSecGame 1. NetSecGame works with the closed-world assumption. Only the defined entities exist in the simulation. 2. If the attacker does a successful action in the same step that the defender successfully detects the action, the priority goes to the defender. The reward is a penalty, and the game ends. @@ -284,208 +310,10 @@ The system monitors actions and maintains a history of recent ones within the ti This approach ensures that only repeated or excessive behavior is flagged, reducing false positives while maintaining a realistic monitoring system. -## Starting the game -The environment should be created before starting the agents. The properties of the game, the task and the topology can be either read from a local file or via REST request to the GameDashboard. - -#### To start the game with a local configuration file -```python3 -m AIDojoCoordinator.worlds.NSEGameCoordinator --task_config=``` - -#### To start the game with a remotely defined configuration -```python3 -m AIDojoCoordinator.worlds.CYSTCoordinator --service_host= --service_port= ``` - -When created, the environment: -1. reads the configuration file -2. loads the network configuration from the config file -3. reads the defender type from the configuration -4. creates starting position and goal position following the config file -5. starts the game server in a specified address and port ### Interaction with the Environment When the game server is created, [agents](https://github.com/stratosphereips/NetSecGameAgents/tree/main) connect to it and interact with the environment. In every step of the interaction, agents submits an [Action](./AIDojoCoordinator/docs/Components.md#actions) and receive [Observation](./AIDojoCoordinator/docs/Components.md#observations) with `next_state`, `reward`, `is_terminal`, `end`, and `info` values. Once the terminal state or timeout is reached, no more interaction is possible until the agent asks for a game reset. Each agent should extend the `BaseAgent` class in [agents](https://github.com/stratosphereips/NetSecGameAgents/tree/main). - -### Configuration -The NetSecEnv is highly configurable in terms of the properties of the world, tasks, and agent interaction. Modification of the world is done in the YAML configuration file in two main areas: -1. Environment (`env` section) controls the properties of the world (taxonomy of networks, maximum allowed steps per episode, probabilities of action success, etc.) -2. Task configuration defines the agents' properties (starting position, goal) - -#### Environment configuration -The environment part defines the properties of the environment for the task (see the example below). In particular: -- `random_seed` - sets the seed for any random processes in the environment -- `scenario` - sets the scenario (network topology) used in the task (currently, `scenario1_tiny`, `scenario1_small`, `scenario1` and `three_nets` are available) -- `save_tajectories` - if `True`, interaction of the agents is serialized and stored in a file -- `use_dynamic_addresses` - if `True`, the network and IP addresses defined in `scenario` are randomly changed at the beginning of **EVERY** episode (the network topology is kept as defined in the `scenario`. Relations between networks are kept, IPs inside networks are chosen at random based on the network IP and mask) -- `use_firewall` - if `True`, firewall rules defined in `scenario` are used when executing actions. When `False`, the firewall is ignored, and all connections are allowed (Default) -- `use_global_defender` - if `True`, enables global defender, which is part of the environment and can stop interaction of any playing agent. -- `required_players` - Minimum required players for the game to start (default 1) -- `rewards`: - - `success` - sets the reward when the agent reaches the goal (default 100) - - `fail` - sets the reward when the agent does not reach its objective (default -10) - - `step_reward` - sets the reward when the agent does each single step in the game (default -1) -- `actions` - defines the probability of success for every ActionType - -```YAML -env: - random_seed: 'random' - scenario: 'scenario1' - use_global_defender: False - use_dynamic_addresses: False - use_firewall: True - save_trajectories: False - rewards: - win: 100 - step: -1 - loss: -10 - actions: - scan_network: - prob_success: 1.0 - find_services: - prob_success: 1.0 - exploit_service: - prob_success: 1.0 - find_data: - prob_success: 1.0 - exfiltrate_data: - prob_success: 1.0 - block_ip: - prob_success: 1.0 -``` - -#### Task configuration -The task configuration part (section `coordinator[agents]`) defines the starting and goal position of the attacker and the type of defender that is used. - -##### Attacker configuration (`[coordinator][agents][Attacker]`) -Configuration of the attacking agents. Consists of three parts: -1. Goal definition (`goal`) which describes the `GameState` properties that must be fulfilled to award `win` reward to the attacker: - - `known_networks:`(list) - - `known_hosts`(list) - - `controlled_hosts`(list) - - `known_services`(dict) - - `known_data`(dict) - - `known_blocks`(dict) - - Each of the parts can be empty (not part of the goal, exactly defined (e.g., `known_networks: [192.168.1.0/24, 192.168.3.0/24]`) or include the keyword `random` (`controlled_hosts: [213.47.23.195, random]`, `known_data: {213.47.23.195: [random]}`. - Additionally, if `random` keyword is used in the goal definition, - `randomize_goal_every_episode`. If set to `True`, each keyword `random` is replaced with a randomly selected, valid option at the beginning of **EVERY** episode. If set to `False`, randomization is performed only **once** when the environment is -2. Definition of starting position (`start_position`), which describes the `GameState` in which the attacker starts. It consists of: - - `known_networks:`(list) - - `known_hosts`(list) - - `controlled_hosts`(list) - - `known_services`(dict) - - `known_data`(dict) - - `known_blocks`(dict) - - The initial network configuration must assign at least **one** controlled host to the attacker in the network. Any item in `controlled_hosts` is copied to `known_hosts`, so there is no need to include these in both sets. `known_networks` is also extended with a set of **all** networks accessible from the `controlled_hosts` -3. Definition of maximum allowed number of steps: - - `max_steps:`(int): defines the maximum allowed number of steps for attackers in **each** episode. - -Example attacker configuration: -```YAML -coordinator: - agents: - Attacker: - max_steps: 20 - goal: - randomize_goal_every_episode: False - known_networks: [] - known_hosts: [] - controlled_hosts: [] - known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]} - known_data: {213.47.23.195: ["random"]} - known_blocks: {'all_routers': 'all_attackers'} - - start_position: - known_networks: [] - known_hosts: [] - # The attacker must always at least control the CC if the goal is to exfiltrate there - # Example of fixing the starting point of the agent in a local host - controlled_hosts: [213.47.23.195, random] - # Services are defined as a target host where the service must be, and then a description in the form 'name, type, version, is_local' - known_services: {} - known_data: {} - known_blocks: {} -``` - -##### Defender configuration (`[coordinator][agents][Defender]`) -Currently, the defender **is** a separate agent. - -If you want a defender in the game, you must connect a defender agent. For playing without a defender, leave the section empty. - -Example of defender configuration: -```YAML - Defender: - goal: - description: "Block all attackers" - known_networks: [] - known_hosts: [] - controlled_hosts: [] - known_services: {} - known_data: {} - known_blocks: {} - - start_position: - known_networks: [] - known_hosts: [] - controlled_hosts: [all_local] - known_services: {} - known_data: {} - blocked_ips: {} - known_blocks: {} -``` -As in other agents, the description is only a text for the agent, so it can know what is supposed to do to win. In the current implementation, the *Defender* wins, if **NO ATTACKER** reaches their goal. - - -### Definition of the network topology -The network topology and rules are defined using a [CYST](https://pypi.org/project/cyst/) simulator configuration. Cyst defines a complex network configuration, and this environment does not use all Cyst features for now. CYST components currently used are: - -- Server hosts (are a NodeConf in CYST) - - Interfaces, each with one IP address - - Users who can log in to the host - - Active and passive services - - Data in the server - - To which network is connected -- Client host (are of type Node in CYST) - - Interfaces, each with one IP address - - To which network is connected - - Active and passive services, if any - - Data in the client -- Router (are of type RouterConf in CYST) - - Interfaces, each with one IP address - - Networks - - Allowed connections between hosts -- Internet host (as an external router) (are of type Node in RouterConf) - - Interfaces, each with one IP address - - Which host can connect -- Exploits - - Which service is the exploit linked to - -### Scenarios -In the current state, we support a single scenario: Data exfiltration to a remote C&C server. However, extensions can be made by modification of the task configuration. - -#### Data exfiltration to a remote C&C -For the data exfiltration, we support 3 variants. The full scenario contains 5 clients (where the attacker can start) and 5 servers, where the data that is supposed to be exfiltrated can be located. *scenario1_small* is a variant with a single client (the attacker always starts there) and all 5 servers. *scenario1_tiny* contains only a single server with data. The tiny scenario is trivial and intended only for debugging purposes. - - - - - - - -
Scenario 1Scenario 1 - smallScenario 1 -tiny
Scenario 1 - Data exfiltrationScenario 1 - smallScenario 1 - tiny
3-nets scenario
- Scenario 1 - Data exfiltration -
- -### Trajectory storing and analysis -The trajectory is a sequence of GameStates, Actions, and rewards in one run of a game. It contains the complete information of the actions played by the agent, the rewards observed and their effect on the state of the environment. Trajectory visualization and analysis tools are described in [Trajectory analysis tools](./docs/Trajectory_analysis.md) - -Trajectories performed by the agents can be stored in a file using the following configuration: -```YAML -env: - save_trajectories: True -``` -> [!CAUTION] -> Trajectory files can grow very fast. It is recommended to use this feature on evaluation/testing runs only. By default, this feature is not enabled. - ## Testing the environment It is advised that after every change, you test if the env is running correctly by doing diff --git a/docs/agent_server.md b/docs/agent_server.md new file mode 100644 index 00000000..0fba5532 --- /dev/null +++ b/docs/agent_server.md @@ -0,0 +1 @@ +::: netsecgame.game.agent_server.AgentServer \ No newline at end of file diff --git a/docs/architecture.md b/docs/architecture.md index ff2d05f3..1c1fb44e 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -1,11 +1,11 @@ -## NetSecGame Architecture -The Network Security Game(NSG) works as a game server - agents connect to it via TCP sockets and interact with the environment using the standard RL communication loop: Agent submits actinon and recieves new observation of the environment. The NSG supports real-time, highly customizable multi-agent simulations. +# NetSecGame Architecture +The NetSecGame (NSG) works as a game server - agents connect to it via TCP sockets and interact with the environment using the standard RL communication loop: Agent submits action and receives new observation of the environment. The NSG supports real-time, highly customizable multi-agent simulations. ## Game Components The following classes are used in the game to hold information about the state of the game. They are used both in the [Actions](#actions) and [GameState](#gamestate). See the API Reference for [GameComponents](game_components.md) ### Building blocks #### IP -IP is immutable object that represents an IPv4 object in the NetSecGame. It has a single parameter of the address in a dot-decimal notation (4 octet represeted as decimal value separeted by dots). +IP is immutable object that represents an IPv4 object in the NetSecGame. It has a single parameter of the address in a dot-decimal notation (4 octet represented as decimal value separated by dots). Example: ```python @@ -24,9 +24,9 @@ net = Network("192.168.1.0", 24) #### Service Service class holds information about services running in hosts. Each Service has four parameters: - `name`:str - Name of the service (e.g., "SSH") -- `type`:str - `passive` or `active`. Currently not being used. +- `type`:str - `passive`, `active` or `unknown`(default). - `version`:str - version of the service. -- `is_local`:bool - flag specifying if the service is local only. (if `True`, service is NOT visible without controlling the host). +- `is_local`:bool - flag specifying if the service is local only (default=True). (if `True`, service is NOT visible without controlling the host). Example: ```python @@ -52,7 +52,7 @@ d2 = Data("User1", "DatabaseData", size=42, type="txt", "SecretUserDatabase") GameState is an object that represents a view of the NetSecGame environment in a given state. It is constructed as a collection of 'assets' available to the agent. GameState has following parts: - `known_networks`: Set of [Network](#network) objects that the agent is aware of - `known_hosts`: Set of [IP](#ip) objects that the agent is aware of -- `controlled_hosts`: Set of [IP](#ip) objetcs that the agent has control over. Note that `controlled_hosts` is a subset of `known_hosts`. +- `controlled_hosts`: Set of [IP](#ip) objects that the agent has control over. Note that `controlled_hosts` is a subset of `known_hosts`. - `known_services`: Dictionary of services that the agent is aware of. The dictionary format: {`IP`: {`Service`}} where [IP](#ip) object is a key and the value is a set of [Service](#service) objects located in the `IP`. - `known_data`: Dictionary of data instances that the agent is aware of. The dictionary format: {`IP`: {`Data`}} where [IP](#ip) object is a key and the value is a set of [Data](#data) objects located in the `IP`. @@ -79,20 +79,21 @@ The Action consists of two parts - **ScanNetwork**, params{`source_host`:``, `target_network`:``}: Scans the given `` from a specified source host. Discovers ALL hosts in a network that are accessible from ``. If successful, returns set of discovered `` objects. - **FindServices**, params={`source_host`:``, `target_host`:``}: Used to discover ALL services running in the `target_host` if the host is accessible from `source_host`. If successful, returns a set of all discovered `` objects. - **FindData**, params={`source_host`:``, `target_host`:``}: Searches `target_host` for data. If `source_host` differs from `target_host`, success depends on accessability from the `source_host`. If successful, returns a set of all discovered `` objects. -- **ExploitService**, params={`source_host`:``, `target_host`:``, `taget_service`:``}: Exploits `target_service` in a specified `target_host`. If successful, the attacker gains control of the `target_host`. +- **ExploitService**, params={`source_host`:``, `target_host`:``, `target_service`:``}: Exploits `target_service` in a specified `target_host`. If successful, the attacker gains control of the `target_host`. - **ExfiltrateData**, params{`source_host`:``, `target_host`:``, `data`:``}: Copies `data` from the `source_host` to `target_host` IF both are controlled and `target_host` is accessible from `source_host`. +- **BlockIP**, params{`source_host`:``, `target_host`:``, `blocked_host`:``}: Blocks communication from/to `blocked_host` on `target_host`. Requires control of `target_host`. ### Action preconditions and effects In the following table, we describe the effects of selected actions and their preconditions. Note that if the preconditions are not satisfied, the actions's effects are not applied. | Action | Params | Preconditions | Effects | |----------------------|----------------------|----------------------|----------------------| -| ScanNetwork| `source_host`, `target_network`| `source_host` ∈ `controlled_hosts`| extends `known_networks`| -|FindServices| `source_host`, `target_host`| `source_host` ∈ `controlled_hosts`| extends `known_services` AND `known_hosts`| +| ScanNetwork| `source_host`, `target_network`| `source_host` ∈ `controlled_hosts`| extends `known_networks`| +|FindServices| `source_host`, `target_host`| `source_host` ∈ `controlled_hosts`| extends `known_services` AND `known_hosts`| |FindData| `source_host`, `target_host`| `source_host`, `target_host` ∈ `controlled_hosts`| extends `known_data`| -|Exploit Service | `source_host`, `target_host`, `target_service`|`source_host` ∈ `controlled_hosts`| extends `controlled_hosts` with `target_host`| -ExfiltrateData| `source_host`,`target_host`, `data` |`source_host`, `target_host` ∈ `controlled_hosts` AND `data` ∈ `known_data`| extends `known_data[target_host]` with `data`| -|BlockIP | `source_host`, `target_host`, `blockedIP`|`source_host` ∈ `controlled_hosts`| extends `known_blocks[target_host]` with `blockedIP`| +|Exploit Service | `source_host`, `target_host`, `target_service`|`source_host` ∈ `controlled_hosts`| extends `controlled_hosts` with `target_host`| +|ExfiltrateData| `source_host`,`target_host`, `data` |`source_host`, `target_host` ∈ `controlled_hosts` AND `data` ∈ `known_data`| extends `known_data[target_host]` with `data`| +|BlockIP | `source_host`, `target_host`, `blocked_host`|`source_host` ∈ `controlled_hosts`| extends `known_blocks[target_host]` with `blocked_host`| #### Assumption and Conditions for Actions 1. When playing the `ExploitService` action, it is expected that the agent has discovered this service before (by playing `FindServices` in the `target_host` before this action) @@ -100,7 +101,7 @@ ExfiltrateData| `source_host`,`target_host`, `data` |`source_host`, `target_host 3. The `Find Data` action requires ownership of the target host. 4. Playing `ExfiltrateData` requires controlling **BOTH** source and target hosts 5. Playing `Find Services` can be used to discover hosts (if those have any active services) -6. Parameters of `ScanNetwork` and `FindServices` can be chosen arbitrarily (they don't have to be listed in `known_newtworks`/`known_hosts`) +6. Parameters of `ScanNetwork` and `FindServices` can be chosen arbitrarily (they don't have to be listed in `known_networks`/`known_hosts`) ### Observations After submitting Action `a` to the environment, agents receive an `Observation` in return. Each observation consists of 4 parts: diff --git a/docs/configuration.md b/docs/configuration.md index f4289bf6..ae40be5c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -15,6 +15,44 @@ The environment part defines the properties of the environment for the task (see - `save_tajectories` - if `True`, interaction of the agents is serialized and stored in a file - `use_dynamic_addresses` - if `True`, the network and IP addresses defined in `scenario` are randomly changed at the beginning of an episode (the network topology is kept as defined in the `scenario`. Relations between networks are kept, IPs inside networks are chosen at random based on the network IP and mask). The change also depend on the input from the agents: +### Available topologies +There are 5 topologies available in NSG: +
+ +
+ +
+ One network topology +
+ +
+ +
+ Two networks tiny topology +
+ +
+ +
+ Two networks small topology +
+ +
+ +
+ Two networks topology +
+ +
+ +
+ Three networks topology +
+ +
+ +
+ |Task configuration| Agent reset request | Result| |----------------------|----------------------|----------------------| |`use_dynamic_ips = True` | `randomize_topology = True`| Changed topology | @@ -25,39 +63,41 @@ The environment part defines the properties of the environment for the task (see In summary, the topology change (IP randomization) can't change without allowing it in the task configuration. If allowed in the task config YAML, it can still be rejected by the agents. - `use_firewall` - if `True` firewall rules defined in `scenario` are used when executing actions. When `False`, the firewall is ignored, and all connections are allowed (Default) -- `use_global_defender` - if `True`, enables global defendr which is part of the environment and can stop interaction of any playing agent. +- `use_global_defender` - if `True`, enables global defender which is part of the environment and can stop interaction of any playing agent. - `required_players` - Minimum required players for the game to start (default 1) - `rewards`: - `success` - sets reward which agent gets when it reaches the goal (default 100) - `fail` - sets the reward that which agent does not reach it's objective (default -10) - - `step_reward` - sets reward which agent gets for every step taken (default -1) + - `step` - sets reward which agent gets for every step taken (default -1) + - `false_positive` - sets reward for a false positive action (default -5) - `actions` - defines the probability of success for every ActionType ```YAML env: random_seed: 'random' - scenario: 'scenario1' + scenario: 'two_networks_tiny' use_global_defender: False use_dynamic_addresses: False use_firewall: True save_trajectories: False rewards: - win: 100 + success: 100 step: -1 - loss: -10 + fail: -10 + false_positive: -5 actions: scan_network: - prob_success: 1.0 + prob_success: 1.0 find_services: - prob_success: 1.0 + prob_success: 1.0 exploit_service: - prob_success: 1.0 + prob_success: 1.0 find_data: - prob_success: 1.0 + prob_success: 1.0 exfiltrate_data: - prob_success: 1.0 + prob_success: 1.0 block_ip: - prob_success: 1.0 + prob_success: 1.0 ``` ### Definition of the network topology The network topology and rules are defined using a [CYST](https://pypi.org/project/cyst/) simulator configuration. Cyst defines a complex network configuration, and this environment does not use all Cyst features for now. CYST components currently used are: @@ -118,24 +158,25 @@ coordinator: Attacker: max_steps: 20 goal: - randomize_goal_every_episode: False - known_networks: [] - known_hosts: [] - controlled_hosts: [] - known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]} - known_data: {213.47.23.195: ["random"]} - known_blocks: {'all_routers': 'all_attackers'} + description: "Exfiltrate data from Samba server to remote C&C server." + is_any_part_of_goal_random: True + known_networks: [] + known_hosts: [] + controlled_hosts: [] + known_services: {} + known_data: {213.47.23.195: [[User1,DataFromServer1]]} + known_blocks: {} start_position: - known_networks: [] - known_hosts: [] - # The attacker must always at least control the CC if the goal is to exfiltrate there - # Example of fixing the starting point of the agent in a local host - controlled_hosts: [213.47.23.195, random] - # Services are defined as a target host where the service must be, and then a description in the form 'name, type, version, is_local' - known_services: {} - known_data: {} - known_blocks: {} + known_networks: [] + known_hosts: [] + # The attacker must always at least control the CC if the goal is to exfiltrate there + # Example of fixing the starting point of the agent in a local host + controlled_hosts: [213.47.23.195, random] + # Services are defined as a target host where the service must be, and then a description in the form 'name, type, version, is_local' + known_services: {} + known_data: {} + known_blocks: {} ``` ### Defender configuration @@ -164,4 +205,18 @@ Example of defender configuration: blocked_ips: {} known_blocks: {} ``` -As in other agents, the description is only a text for the agent, so it can know what is supposed to do to win. In the curent implementation, the *Defender* wins, if **NO ATTACKER** reaches their goal. \ No newline at end of file +As in other agents, the description is only a text for the agent, so it can know what is supposed to do to win. In the curent implementation, the *Defender* wins, if **NO ATTACKER** reaches their goal. + +### Trajectory storing and analysis + +The trajectory is a sequence of GameStates, Actions, and rewards in one run of a game. It contains the complete information of the actions played by the agent, the rewards observed and their effect on the state of the environment. + +Trajectories performed by the agents can be stored in a file using the following configuration (on the server side): + +```yaml +env: + save_trajectories: True +``` +!!! warning "Caution" + Trajectory files can grow very fast. It is recommended to use this feature on evaluation/testing runs only. By default, this feature is not enabled. + diff --git a/docs/configuration_manager.md b/docs/configuration_manager.md new file mode 100644 index 00000000..98c22770 --- /dev/null +++ b/docs/configuration_manager.md @@ -0,0 +1,11 @@ +## Configuration Manager + +Configuration manager is a component of the game coordinator that handles the configuration of the game. It is responsible for loading the configuration from the YAML file and providing it to the game coordinator. + +::: netsecgame.game.configuration_manager.ConfigurationManager + +## ConfigParser + +ConfigParser is a class that is responsible for parsing the YAML configuration file and providing it to the game coordinator. + +::: netsecgame.game.config_parser.ConfigParser \ No newline at end of file diff --git a/AIDojoCoordinator/docs/figures/scenarios/scenario_1.png b/docs/figures/scenarios/scenario_1.png similarity index 100% rename from AIDojoCoordinator/docs/figures/scenarios/scenario_1.png rename to docs/figures/scenarios/scenario_1.png diff --git a/AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png b/docs/figures/scenarios/scenario_1_small.png similarity index 100% rename from AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png rename to docs/figures/scenarios/scenario_1_small.png diff --git a/AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png b/docs/figures/scenarios/scenario_1_tiny.png similarity index 100% rename from AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png rename to docs/figures/scenarios/scenario_1_tiny.png diff --git a/AIDojoCoordinator/docs/figures/scenarios/three_nets.png b/docs/figures/scenarios/three_nets.png similarity index 100% rename from AIDojoCoordinator/docs/figures/scenarios/three_nets.png rename to docs/figures/scenarios/three_nets.png diff --git a/docs/game_components.md b/docs/game_components.md index 109af552..c54cdb5a 100644 --- a/docs/game_components.md +++ b/docs/game_components.md @@ -1,2 +1 @@ -# Game Components -::: AIDojoCoordinator.game_components \ No newline at end of file +::: netsecgame.game_components diff --git a/docs/game_coordinator.md b/docs/game_coordinator.md index 40b92a3b..2c639340 100644 --- a/docs/game_coordinator.md +++ b/docs/game_coordinator.md @@ -3,15 +3,26 @@ Coordinator is the centerpiece of the game orchestration. It provides an interfa In detail it handles: -1. World initialiazation +1. World initialization 2. Registration of new agents in the game 3. Agent-World communication (message verification and forwarding) 4. Recording (and storing) trajectories of agents (optional) -4. Detection of episode ends (either by reaching timout or agents reaching their respective goals) -5. Assigning rewards for each action and at the end of each episode -6. Removing agents from the game -7. Registering the GameReset requests and handelling the game resets. +5. Detection of episode ends (either by reaching timeout or agents reaching their respective goals) +6. Assigning rewards for each action and at the end of each episode +7. Removing agents from the game +8. Registering the GameReset requests and handling the game resets. To facilitate the communication the coordinator uses a TCP server to which agents connect. The communication is asynchronous and depends of the -::: AIDojoCoordinator.coordinator.AgentServer -::: AIDojoCoordinator.coordinator.GameCoordinator \ No newline at end of file + +## Connection to other game components +Coordinator, having the role of the middle man in all communication between the agent and the world uses several queues for message passing and handling. + +1. `Action queue` is a queue in which the agents submit their actions. It provides N:1 communication channel in which the coordinator receives the inputs. +2. `Answer queues` is a separate queue **per agent** in which the results of the actions are send to the agent. + +## Episode +The episode starts with sufficient amount of agents registering in the game. Each agent role has a maximum allowed number of steps defined in the task configuration. An episode ends if all agents reach the goal + + +::: netsecgame.game.coordinator.GameCoordinator +::: netsecgame.game.worlds.NetSecGame.NetSecGame \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index cf64a337..c9ab1a13 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,55 +1,60 @@ -# Network Security Game - -The NetSecGame (Network Security Game) is a framework for training and evaluation of AI agents in the network security tasks (both offensive and defensive). It is build with [CYST](https://pypi.org/project/cyst/) network simulator and enables rapid development and testing of AI agents in highly configurable scenarios. Examples of implemented agents can be seen in the submodule [NetSecGameAgents](https://github.com/stratosphereips/NetSecGameAgents/tree/main). +# NetSecGame +The NetSecGame (Network Security Game) is a framework for training and evaluation of AI agents in network security tasks (both offensive and defensive). It is built with [CYST](https://pypi.org/project/cyst/) network simulator and enables rapid development and testing of AI agents in highly configurable scenarios. Examples of implemented agents can be seen in the submodule [NetSecGameAgents](https://github.com/stratosphereips/NetSecGameAgents/tree/main). ## Installation Guide -It is recommended to install the NetSecGame in a virual environement: -### Python venv -1. +It is recommended to run the environment in the Docker container. The up-to-date image can be found in [Dockerhub](https://hub.docker.com/r/stratosphereips/netsecgame). +```bash +docker pull stratosphereips/netsecgame +``` +#### Building the image locally +Optionally, you can build the image locally with: +```bash +docker build -t netsecgame:local . +``` + +### Installing from source +In case you need to modify the envirment and run directly, we recommed to insall it in a virtual environemnt (Python vevn or Conda): +#### Python venv +1. Create new virtual environment ```bash -python -m venv +python -m venv ``` -2. +2. Activate newly created venv ```bash source /bin/activate ``` -### Conda -1. +#### Conda +1. Create new conda environment ```bash -conda create --name aidojo python==3.12.10 +conda create --name aidojo python==3.12 ``` -2. +2. Activate newly created conda env ```bash conda activate aidojo ``` -After the virtual environment is activated, install using pip: + +### After preparing virutual environment, install using pip: ```bash pip install -e . ``` -### With Docker -The NetSecGame can be run in a Docker container. You can build the image locally with: -```bash -docker build -t aidojo-nsg-coordinator:latest . -``` -or use the available image from [Dockerhub](https://hub.docker.com/r/stratosphereips/netsecgame). -```bash -docker pull stratosphereips/netsecgame -``` + ## Quick Start -A task configuration needs to be specified to start the NetSecGame (see [Configuration](configuration.md)). For the first step, the example task configuration is recommended: +A task configuration YAML file is required for starting the NetSecGame environment. For the first step, the example task configuration is recommended: + +### Example Configuration ```yaml # Example of the task configuration for NetSecGame # The objective of the Attacker in this task is to locate specific data # and exfiltrate it to a remote C&C server. -# The scenario starts AFTER initial breach of the local network +# The scenario starts AFTER the initial breach of the local network # (the attacker controls 1 local device + the remote C&C server). -coordinator: - agents: +coordinator: + agents: Attacker: # Configuration of 'Attacker' agents - max_steps: 25 - goal: + max_steps: 25 # timout set for the role `Attacker` + goal: # Definition of the goal state description: "Exfiltrate data from Samba server to remote C&C server." is_any_part_of_goal_random: True known_networks: [] @@ -58,10 +63,10 @@ coordinator: known_services: {} known_data: {213.47.23.195: [[User1,DataFromServer1]]} # winning condition known_blocks: {} - start_position: # Defined starting position of the attacker + start_position: # Definition of the starting state (keywords "random" and "all" can be used) known_networks: [] known_hosts: [] - controlled_hosts: [213.47.23.195, random] # + controlled_hosts: [213.47.23.195, random] # keyword 'random' will be replaced by randomly selected IP during initilization known_services: {} known_data: {} known_blocks: {} @@ -86,62 +91,89 @@ coordinator: blocked_ips: {} known_blocks: {} -env: +env: # Environment configuraion scenario: 'two_networks_tiny' # use the smallest topology for this example use_global_defender: False # Do not use global SIEM Defender use_dynamic_addresses: False # Do not randomize IP addresses use_firewall: True # Use firewall save_trajectories: False # Do not store trajectories - required_players: 1 + required_players: 1 # Minimal amount of agents requiered to start the game rewards: # Configurable reward function success: 100 step: -1 fail: -10 false_positive: -5 ``` - -The game can be started with: -```bash -python3 -m AIDojoCoordinator.worlds.NSEGameCoordinator \ - --task_config=./examples/example_config.yaml \ - --game_port=9000 -``` -Upon which the game server is created on `localhost:9000` to which the agents can connect to interact in the NetSecGame. -### Docker Container -When running in the Docker container, the NetSecGame can be started with: +### Starting the NetSecGame +With the configuration ready the environment can be started in selected port +#### In Docker container ```bash -docker run -it --rm \ - -v $(pwd)/examples/example_task_configuration.yaml:/aidojo/netsecenv_conf.yaml \ - -v $(pwd)/logs:/aidojo/logs \ +docker run -d --rm --name nsg-server\ + -v $(pwd)/examples/example_task_configuration.yaml:/netsecgame/netsecenv_conf.yaml \ + -v $(pwd)/logs:/netsecgame/logs \ -p 9000:9000 stratosphereips/netsecgame + --debug_level="INFO" ``` -optionally, you can set the logging level with `--debug_level=["DEBUG", "INFO", "WARNING", "CRITICAL"]` (defaul=`"INFO"`): +`--name nsg-server`: specifies the name of the container +`-v :/netsecgame/netsecenv_conf.yaml` : Mapping of the configuration file + +`-v $(pwd)/logs:/netsecgame/logs`: Mapping of the folder where logs are stored + +` -p :9000`: Mapping of the port in which the server runs + +`--debug_level` is an optional parameter to control the logging level `--debug_level=["DEBUG", "INFO", "WARNING", "CRITICAL"]` (defaul=`"INFO"`): +##### Running on Windows (with Docker desktop) +When running on Windows, Docker desktop is required. +```batch +docker run -d --rm --name netsecgame-server ^ + -p 9000:9000 ^ + -v "%cd%\examples\example_task_configuration.yaml:/netsecgame/netsecenv_conf.yaml" ^ + -v "%cd%\logs:/netsecgame/logs" ^ + stratosphereips/netsecgame:latest + --debug_level="INFO" +``` + +#### Locally +The environment can be started locally with from the root folder of the repository with following command: ```bash -docker run -it --rm \ - -v $(pwd)/examples/example_task_configuration.yaml:/aidojo/netsecenv_conf.yaml \ - -v $(pwd)/logs:/aidojo/logs \ - -p 9000:9000 stratosphereips/netsecgame \ - --debug_level="WARNING" +python3 -m netsecgame.game.worlds.NetSecGame \ + --task_config=./examples/example_task_configuration.yaml \ + --game_port=9000 + --debug_level="INFO" ``` +Upon which the game server is created on `localhost:9000` to which the agents can connect to interact in the NetSecGame. -## Documentation -The NetSecGame environment has several components in the following files: +### Components of the NetSecGame Environment +The NetSecGame has several components in the following files: ``` -├── AIDojoGameCoordinator/ -| ├── game_coordinator.py -| ├── game_components.py -| ├── global_defender.py -| ├── worlds/ -| ├── NSGCoordinator.py -| ├── NSGRealWorldCoordinator.py -| ├── CYSTCoordinator.py -| ├── scenarios/ +├── netsecgame/ +| ├── agents/ +| ├── base_agent.py # Basic agent class. Defines the API for agent-server communication +| ├── game/ +| ├── scenarios/ +| ├── three_net_scenario.py +| ├── two_nets.py +| ├── two_nets_tiny.py +| ├── two_nets_small.py +| ├── one_net.py +| ├── worlds/ +| ├── NetSecGame.py # (NSG) basic simulation +| ├── RealWorldNetSecGame.py # Extension of `NSG` - runs actions in the *network of the host computer* +| ├── CYSTCoordinator.py # Extension of `NSG` - runs simulation in CYST engine. +| ├── WhiteBoxNetSecGame.py # Extension of `NSG` - provides agents with full list of actions upon registration. +| ├── config_parser.py # NSG task configuration parser +| ├── configuration_manager.py # Manages the loading and access of game configuration. +| ├── coordinator.py # Core game server. Not to be run as stand-alone world (see worlds/) +| ├── agent_server.py # Class used for serving the agents when connecting to the game run by the GameCoordinator. +| ├── global_defender.py # Stochastic (non-agentic defender) +| ├── game_components.py # contains basic building blocks of the environment | ├── utils/ | ├── utils.py -| ├── log_parser.py +| ├── trajectory_recorder.py +| ├── trajectory_analysis.py +| ├── aidojo_log_colorizer.py | ├── gamaplay_graphs.py -| ├── actions_parser.py ``` Some compoments are described in detail in following sections: diff --git a/mkdocs.yml b/mkdocs.yml index e5fdccab..58303545 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,6 +9,8 @@ nav: - API Reference: - game_components.md - game_coordinator.md + - agent_server.md + - configuration_manager.md plugins: - mkdocstrings: @@ -26,6 +28,8 @@ plugins: markdown_extensions: - pymdownx.arithmatex - pymdownx.superfences + - pymdownx.highlight + - admonition extra_javascript: - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js \ No newline at end of file diff --git a/netsecgame/__init__.py b/netsecgame/__init__.py new file mode 100644 index 00000000..3f4df292 --- /dev/null +++ b/netsecgame/__init__.py @@ -0,0 +1,71 @@ +# add imports so that they are available when importing the package NetSecGame +# e.g., from NetSecGame import GameState + +__version__ = "0.1.0" + +# Game components +from .game_components import ( + Action, + ActionType, + AgentInfo, + AgentRole, + Data, + GameState, + GameStatus, + IP, + Network, + Observation, + ProtocolConfig, + Service +) +# Base agent +from .agents.base_agent import BaseAgent + +# Selected util functions +from .utils.utils import ( + get_file_hash, + state_as_ordered_string, + store_trajectories_to_jsonl, + read_trajectories_from_jsonl, + observation_as_dict, + observation_to_str, + observation_from_str, + observation_from_dict, + generate_valid_actions +) + +# Trajectory Recorder +from .utils.trajectory_recorder import TrajectoryRecorder + +# Define the public API of the package +__all__ = [ + # Metadata + "__version__", + # Game components + "Action", + "ActionType", + "AgentInfo", + "AgentRole", + "Data", + "GameState", + "GameStatus", + "IP", + "Network", + "Observation", + "ProtocolConfig", + "Service", + # Base agent + "BaseAgent", + # Utils + "get_file_hash", + "state_as_ordered_string", + "store_trajectories_to_jsonl", + "read_trajectories_from_jsonl", + "observation_as_dict", + "observation_to_str", + "observation_from_str", + "observation_from_dict", + "generate_valid_actions", + # Trajectory recorder + "TrajectoryRecorder" +] \ No newline at end of file diff --git a/AIDojoCoordinator/__init__.py b/netsecgame/agents/__init__.py similarity index 100% rename from AIDojoCoordinator/__init__.py rename to netsecgame/agents/__init__.py diff --git a/netsecgame/agents/base_agent.py b/netsecgame/agents/base_agent.py new file mode 100644 index 00000000..0c5ad189 --- /dev/null +++ b/netsecgame/agents/base_agent.py @@ -0,0 +1,193 @@ +# Author: Ondrej Lukas, ondrej.lukas@aic.cvut.cz +# Basic agent class that is to be extended in each agent classes +import logging +import socket +import json +from abc import ABC + +from netsecgame.game_components import Action, GameState, Observation, ActionType, GameStatus, AgentInfo, ProtocolConfig, AgentRole + +class BaseAgent(ABC): + """ + Author: Ondrej Lukas, ondrej.lukas@aic.cvut.cz + Basic agent for the network based NetSecGame environment. Implemenets communication with the game server. + """ + + def __init__(self, host, port, role:str)->None: + self._connection_details = (host, port) + self._logger = logging.getLogger(self.__class__.__name__) + self._role = role + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((host, port)) + except socket.error as e: + self._logger.error(f"Socket error: {e}") + self.sock = None + self._logger.info("Agent created") + + def __del__(self): + "In case the extending class did not close the connection, terminate the socket when the object is deleted." + if self._socket: + try: + self._socket.close() + self._logger.info("Socket closed") + except socket.error as e: + print(f"Error closing socket: {e}") + + def terminate_connection(self)->None: + """Method for graceful termination of connection. Should be used by any class extending the BaseAgent.""" + if self._socket: + try: + self._socket.close() + self._socket = None + self._logger.info("Socket closed") + except socket.error as e: + print(f"Error closing socket: {e}") + @property + def socket(self)->socket.socket | None: + return self._socket + + @property + def role(self)->str: + return self._role + + @property + def logger(self)->logging.Logger: + return self._logger + + def make_step(self, action: Action) -> Observation | None: + """ + Executes a single step in the environment by sending the agent's action to the server and receiving the resulting observation. + + Args: + action (Action): The action to be performed by the agent. + + Returns: + Observation: The new observation received from the server, containing the updated game state, reward, end flag, and additional info. + None: If no observation is received from the server. + + Raises: + Any exceptions raised by the `communicate` method are propagated. + """ + _, observation_dict, _ = self.communicate(action) + if observation_dict: + return Observation(GameState.from_dict(observation_dict["state"]), observation_dict["reward"], observation_dict["end"], observation_dict["info"]) + else: + return None + + def communicate(self, data:Action)-> tuple: + """ + Exchanges data with the server and returns the server's response. + This method sends an `Action` object to the server and waits for a response. + The response is expected to be a JSON-encoded string containing status, observation, and message fields. + The method returns a tuple containing the parsed status, observation, and message. + Args: + data (Action): The action to send to the server. Must be an instance of `Action`. + Returns: + tuple: A tuple containing: + - status (GameStatus): The status object parsed from the server response. + - observation (dict): The observation data from the server. + - message (str or None): An optional message from the server. + Raises: + ValueError: If `data` is not of type `Action`. + ConnectionError: If the server response is incomplete or missing the end-of-message marker. + Exception: If there is an error sending data to the server. + """ + + def _send_data(socket, msg:str)->None: + try: + self._logger.debug(f'Sending: {msg}') + socket.sendall(msg.encode()) + except Exception as e: + self._logger.error(f'Exception in _send_data(): {e}') + raise e + + def _receive_data(socket)->tuple: + """ + Receive data from server + """ + # Receive data from the server + data = b"" # Initialize an empty byte string + + while True: + chunk = socket.recv(ProtocolConfig.BUFFER_SIZE) # Receive a chunk + if not chunk: # If no more data, break (connection closed) + break + data += chunk + if ProtocolConfig.END_OF_MESSAGE in data: # Check if EOF marker is present + break + if ProtocolConfig.END_OF_MESSAGE not in data: + raise ConnectionError("Unfinished connection.") + data = data.replace(ProtocolConfig.END_OF_MESSAGE, b"") # Remove EOF marker + data = data.decode() + self._logger.debug(f"Data received from env: {data}") + # extract data from string representation + data_dict = json.loads(data) + # Add default values if dict keys are missing + status = data_dict["status"] if "status" in data_dict else "" + observation = data_dict["observation"] if "observation" in data_dict else {} + message = data_dict["message"] if "message" in data_dict else None + + return GameStatus.from_string(str(status)), observation, message + + if isinstance(data, Action): + data = data.to_json() + else: + raise ValueError("Incorrect data type! Data should be ONLY of type Action") + + _send_data(self._socket, data) + return _receive_data(self._socket) + + def register(self)->Observation | None: + """ + Method for registering agent to the game server. + Classname is used as agent name and the role is based on the 'role' argument. + Returns initial observation if registration was successful, None otherwise. + + Args: + role (str): Role of the agent, either 'attacker' or 'defender'. + Returns: + Observation: Initial observation if registration was successful, None otherwise. + """ + try: + self._logger.info(f'Registering agent as {self.role}') + status, observation_dict, message = self.communicate(Action(ActionType.JoinGame, + parameters={"agent_info":AgentInfo(self.__class__.__name__,self.role.value)})) + if status is GameStatus.CREATED: + self._logger.info(f"\tRegistration successful! {message}") + return Observation(GameState.from_dict(observation_dict["state"]), observation_dict["reward"], observation_dict["end"], message) + else: + self._logger.error(f'\tRegistration failed! (status: {status}, msg:{message}') + return None + except Exception as e: + self._logger.error(f'Exception in register(): {e}') + + def request_game_reset(self, request_trajectory=False, randomize_topology=True, randomize_topology_seed=None) -> Observation|None: + """ + Requests a game reset from the server. Optionally requests a trajectory and/or topology randomization. + Args: + request_trajectory (bool): If True, requests the server to provide a trajectory of the last episode. + randomize_topology (bool): If True, requests the server to randomize the network topology for the next episode. Defaults to True. + randomize_topology_seed (int): If provided, requests the server to use this seed for randomizing the network topology. Defaults to None. + Returns: + Observation: The initial observation after the reset if successful, None otherwise. + """ + self._logger.debug("Requesting game reset") + status, observation_dict, message = self.communicate(Action(ActionType.ResetGame, parameters={"request_trajectory": request_trajectory, "randomize_topology": randomize_topology})) + if status: + self._logger.debug('\tReset successful') + return Observation(GameState.from_dict(observation_dict["state"]), observation_dict["reward"], observation_dict["end"], message) + else: + self._logger.error(f'\rReset failed! (status: {status}, msg:{message}') + return None + +if __name__ == "__main__": + # Example usage of BaseAgent + GAME_PORT = 5000 # Change to the appropriate port + agent = BaseAgent("localhost", GAME_PORT, AgentRole.Attacker) + # Register the agent + observation = agent.register() + if observation: + print("Initial Observation:", observation) + # Gracefully terminate the connection + agent.terminate_connection() \ No newline at end of file diff --git a/netsecgame/game/__init__.py b/netsecgame/game/__init__.py new file mode 100644 index 00000000..88c7adad --- /dev/null +++ b/netsecgame/game/__init__.py @@ -0,0 +1,16 @@ +try: + # Attempt to import server-specific dependencies + # disable ruff error F401 for unused imports (used for dependency checking) + import cyst # noqa: F401 + import aiohttp # noqa: F401 + import faker # noqa: F401 + import numpy as np # noqa: F401 + import requests # noqa: F401 +except ImportError as e: + # If any server-specific dependency is missing, raise an informative error + # Surpress the context of the original ImportError + raise ImportError( + f"Failed to import 'netsecgame.game'. This module requires server dependencies.\n" + f"Missing dependency: {e.name}\n" + f"Please install them using: pip install 'netsecgame[server]'" + ) from None \ No newline at end of file diff --git a/netsecgame/game/agent_server.py b/netsecgame/game/agent_server.py new file mode 100644 index 00000000..d26b2778 --- /dev/null +++ b/netsecgame/game/agent_server.py @@ -0,0 +1,124 @@ +import logging +import asyncio +from netsecgame.game_components import Action, ActionType, ProtocolConfig + +class AgentServer(asyncio.Protocol): + """ + Class used for serving the agents when connecting to the game run by the GameCoordinator. + + Attributes: + actions_queue (asyncio.Queue): Queue for actions from agents. + answers_queues (dict): Mapping of agent addresses to their response queues. + max_connections (int): Maximum allowed concurrent agent connections. + current_connections (int): Current number of connected agents. + logger (logging.Logger): Logger for the AgentServer. + """ + def __init__(self, actions_queue, agent_response_queues, max_connections): + """ + Initialize the AgentServer. + + Args: + actions_queue (asyncio.Queue): Queue for actions from agents. + agent_response_queues (dict): Mapping of agent addresses to their response queues. + max_connections (int): Maximum allowed concurrent agent connections. + """ + self.actions_queue = actions_queue + self.answers_queues = agent_response_queues + self.max_connections = max_connections + self.current_connections = 0 + self.logger = logging.getLogger("AgentServer") + + async def handle_agent_quit(self, peername:tuple): + """ + Helper function to handle agent disconnection. + + Args: + peername (tuple): The address of the disconnecting agent. + """ + # Send a quit message to the Coordinator + self.logger.info(f"\tHandling agent quit for {peername}.") + quit_message = Action(ActionType.QuitGame, parameters={}).to_json() + await self.actions_queue.put((peername, quit_message)) + + async def handle_new_agent(self, reader, writer): + """ + Handle a new agent connection. + + Args: + reader (asyncio.StreamReader): Stream reader for the agent. + writer (asyncio.StreamWriter): Stream writer for the agent. + """ + # get the peername of the writer + peername = writer.get_extra_info("peername") + queue_created = False + try: + self.logger.info(f"New connection from {peername}") + # Check if the maximum number of connections has been reached + if self.current_connections < self.max_connections: + # increment the count of current connections + self.current_connections += 1 + self.logger.info(f"New agent connected: {peername}. Current connections: {self.current_connections}") + # Ensure a queue exists for this agent + if peername not in self.answers_queues: + self.answers_queues[peername] = asyncio.Queue(maxsize=2) + queue_created = True + self.logger.info(f"Created queue for agent {peername}") + # Handle the new agent + while True: + # Step 1: Read data from the agent + data = await reader.read(ProtocolConfig.BUFFER_SIZE) + if not data: + self.logger.info(f"Agent {peername} disconnected.") + await self.handle_agent_quit(peername) + break + + raw_message = data.decode().strip() + self.logger.debug(f"Handler received from {peername}: {raw_message}") + + # Step 2: Forward the message to the Coordinator + await self.actions_queue.put((peername, raw_message)) + + # Step 3: Get a matching response from the answers queue + response_queue = self.answers_queues[peername] + response = await response_queue.get() + self.logger.info(f"Sending response to agent {peername}: {response}") + + # Step 4: Send the response to the agent + response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE + writer.write(response) + await writer.drain() + else: + self.logger.warning(f"Queue for agent {peername} already exists. Closing connection.") + else: + self.logger.info(f"Max connections reached. Rejecting new connection from {writer.get_extra_info('peername')}") + except ConnectionResetError: + self.logger.warning(f"Connection reset by {peername}") + await self.handle_agent_quit(peername) + except asyncio.CancelledError: + self.logger.debug("Connection handling cancelled.") + raise # Ensure the exception propagates + except Exception as e: + self.logger.error(f"Unexpected error with client {peername}: {e}") + raise + finally: + try: + if peername in self.answers_queues: + # If the queue was created, remove it + if queue_created: + self.answers_queues.pop(peername) + self.logger.info(f"Removed queue for agent {peername}") + self.current_connections = max(0, self.current_connections - 1) + writer.close() + await writer.wait_closed() + except Exception: + # swallow exceptions on close to avoid crash on cleanup + pass + async def __call__(self, reader, writer): + """ + Allow the server instance to be called as a coroutine. + + Args: + reader (asyncio.StreamReader): Stream reader for the agent. + writer (asyncio.StreamWriter): Stream writer for the agent. + """ + await self.handle_new_agent(reader, writer) diff --git a/AIDojoCoordinator/utils/utils.py b/netsecgame/game/config_parser.py similarity index 65% rename from AIDojoCoordinator/utils/utils.py rename to netsecgame/game/config_parser.py index face37cf..8aace5ec 100644 --- a/AIDojoCoordinator/utils/utils.py +++ b/netsecgame/game/config_parser.py @@ -1,144 +1,28 @@ -# Utility functions for then env and for the agents +# Config parser for NetSecGame Coordinator # Author: Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz # Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz import yaml # This is used so the agent can see the environment and game components import importlib -from AIDojoCoordinator.game_components import IP, Data, Network, Service, GameState, Action, Observation, ActionType +from netsecgame.game_components import IP, Data, Network, Service import netaddr import logging -import csv -import os -import jsonlines from random import randint -import json -import hashlib -from cyst.api.configuration.network.node import NodeConfig from typing import Optional -def get_file_hash(filepath, hash_func='sha256', chunk_size=4096): - """ - Computes hash of a given file. - """ - hash_algorithm = hashlib.new(hash_func) - with open(filepath, 'rb') as file: - chunk = file.read(chunk_size) - while chunk: - hash_algorithm.update(chunk) - chunk = file.read(chunk_size) - return hash_algorithm.hexdigest() - -def get_str_hash(string, hash_func='sha256', chunk_size=4096): - """ - Computes hash of a given file. - """ - hash_algorithm = hashlib.new(hash_func) - hash_algorithm.update(string.encode('utf-8')) - return hash_algorithm.hexdigest() - -def read_replay_buffer_from_csv(csvfile:str)->list: - """ - Function to read steps from a CSV file - and restore the objects in the replay buffer. - - expected colums in the csv: - state_t0, action_t0, reward_t1, state_t1, done_t1 - """ - buffer = [] - try: - with open(csvfile, 'r') as f_object: - csv_reader = csv.reader(f_object, delimiter=';') - for [s_t, a_t, r, s_t1 , done] in csv_reader: - buffer.append((GameState.from_json(s_t), Action.from_json(a_t), r, GameState.from_json(s_t1), done)) - except FileNotFoundError: - # There was no buffer - pass - return buffer - -def store_replay_buffer_in_csv(replay_buffer:list, filename:str, delimiter:str=";")->None: - """ - Function to store steps from a replay buffer in CSV file. - Expected format of replay buffer items: - (state_t0:GameState, action_t0:Action, reward_t1:float, state_t1:GameState, done_t1:bool) - """ - with open(filename, 'a') as f_object: - writer_object = csv.writer(f_object, delimiter=delimiter) - for (s_t, a_t, r, s_t1, done) in replay_buffer: - writer_object.writerow([s_t.as_json(), a_t.as_json(), r, s_t1.as_json(), done]) - -def state_as_ordered_string(state:GameState)->str: - ret = "" - ret += f"nets:[{','.join([str(x) for x in sorted(state.known_networks)])}]," - ret += f"hosts:[{','.join([str(x) for x in sorted(state.known_hosts)])}]," - ret += f"controlled:[{','.join([str(x) for x in sorted(state.controlled_hosts)])}]," - ret += "services:{" - for host in sorted(state.known_services.keys()): - ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_services[host])])}]" - ret += "},data:{" - for host in sorted(state.known_data.keys()): - ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_data[host])])}]" - ret += "}, blocks:{" - for host in sorted(state.known_blocks.keys()): - ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_blocks[host])])}]" - ret += "}" - return ret - -def observation_to_str(observation:Observation)-> str: - """ - Generates JSON string representation of a given Observation object. - """ - state_str = observation.state.as_json() - observation_dict = { - 'state': state_str, - 'reward': observation.reward, - 'end': observation.end, - 'info': dict(observation.info) - } - try: - observation_str = json.dumps(observation_dict) - return observation_str - except Exception as e: - print(f"Error in encoding observation '{observation}' to JSON string: {e}") - raise e - -def observation_as_dict(observation:Observation)->dict: - """ - Generates dict string representation of a given Observation object. - """ - observation_dict = { - 'state': observation.state.as_dict, - 'reward': observation.reward, - 'end': observation.end, - 'info': observation.info - } - return observation_dict - -def parse_log_content(log_content:str)->Optional[list]: - try: - logs = [] - data = json.loads(log_content) - for item in data: - ip = IP(item["source_host"]) - action_type = ActionType.from_string(item["action_type"]) - logs.append({"source_host":ip, "action_type":action_type}) - return logs - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") - return None - except TypeError as e: - print(f"Error decoding JSON: {e}") - return None - class ConfigParser(): """ - Class to deal with the configuration file + Class to deal with the configuration file of NetSecGame Coordinator + Args: + task_config_file (str|None): Path to the configuration file + config_dict (dict|None): Dictionary with configuration data """ - def __init__(self, task_config_file:str=None, config_dict:dict=None): + def __init__(self, task_config_file:str|None=None, config_dict:dict|None=None): """ Initializes the configuration parser. Required either path to a confgiuration file or a dict with configuraitons. """ - self.logger = logging.getLogger('configparser') + self.logger = logging.getLogger('ConfigParser') if task_config_file: self.read_config_file(task_config_file) elif config_dict: @@ -444,40 +328,40 @@ def get_rewards(self, reward_names:list, default_value=0)->dict: rewards[name] = default_value return rewards - def get_use_dynamic_addresses(self)->bool: + def get_use_dynamic_addresses(self, default_value: bool = False)->bool: """ Reads if the IP and Network addresses should be dynamically changed. """ try: use_dynamic_addresses = self.config['env']['use_dynamic_addresses'] except KeyError: - use_dynamic_addresses = False + use_dynamic_addresses = default_value return bool(use_dynamic_addresses) - def get_store_trajectories(self): + def get_store_trajectories(self, default_value: bool = False): """ Read if the replay buffer should be stored in file """ try: - store_rb = self.config['env']['save_trajectories'] + store_trajectories = self.config['env']['save_trajectories'] except KeyError: # Option is not in the configuration - default to FALSE - store_rb = False - return store_rb + store_trajectories = default_value + return store_trajectories def get_scenario(self): """ Get the scenario config objects based on the configuration. Only import objects that are selected via importlib. """ allowed_names = { - "scenario1" : "AIDojoCoordinator.scenarios.scenario_configuration", - "scenario1_small" : "AIDojoCoordinator.scenarios.smaller_scenario_configuration", - "scenario1_tiny" : "AIDojoCoordinator.scenarios.tiny_scenario_configuration", - "one_network": "AIDojoCoordinator.scenarios.one_net", - "three_net_scenario": "AIDojoCoordinator.scenarios.three_net_scenario", - "two_networks": "AIDojoCoordinator.scenarios.two_nets", # same as scenario1 - "two_networks_small": "AIDojoCoordinator.scenarios.two_nets_small", # same as scenario1_small - "two_networks_tiny": "AIDojoCoordinator.scenarios.two_nets_tiny", # same as scenario1_small + "scenario1" : "netsecgame.game.scenarios.scenario_configuration", + "scenario1_small" : "netsecgame.game.scenarios.smaller_scenario_configuration", + "scenario1_tiny" : "netsecgame.game.scenarios.tiny_scenario_configuration", + "one_network": "netsecgame.game.scenarios.one_net", + "three_net_scenario": "netsecgame.game.scenarios.three_net_scenario", + "two_networks": "netsecgame.game.scenarios.two_nets", # same as scenario1 + "two_networks_small": "netsecgame.game.scenarios.two_nets_small", # same as scenario1_small + "two_networks_tiny": "netsecgame.game.scenarios.two_nets_tiny", # same as scenario1_small } scenario_name = self.config['env']['scenario'] @@ -498,7 +382,7 @@ def get_seed(self, whom): seed = randint(0,100) return seed - def get_randomize_goal_every_episode(self) -> bool: + def get_randomize_goal_every_episode(self, default_value: bool = False) -> bool: """ Get if the randomization should be done only once or at the beginning of every episode """ @@ -506,89 +390,32 @@ def get_randomize_goal_every_episode(self) -> bool: randomize_goal_every_episode = self.config["coordinator"]["agents"]["attackers"]["goal"]["is_any_part_of_goal_random"] except KeyError: # Option is not in the configuration - default to FALSE - randomize_goal_every_episode = False + randomize_goal_every_episode = default_value return randomize_goal_every_episode - def get_use_firewall(self)->bool: + def get_use_firewall(self, default_value: bool = False)->bool: """ Retrieves if the firewall functionality is allowed for netsecgame. Default: False """ try: - use_firewall = self.config['env']['use_firewall'] + use_firewall = self.config['env']['use_firewall'] except KeyError: - use_firewall = False + use_firewall = default_value return use_firewall - def get_use_global_defender(self)->bool: + def get_use_global_defender(self, default_value: bool = False)->bool: try: use_global_defender = self.config['env']['use_global_defender'] except KeyError: - use_global_defender = False + use_global_defender = default_value return use_global_defender - def get_required_num_players(self)->int: + def get_required_num_players(self, default_value: int = 1)->int: try: required_players = int(self.config['env']['required_players']) except KeyError: - required_players = 1 + required_players = default_value except ValueError: - required_players = 1 - return required_players - -def get_logging_level(debug_level): - """ - Configure logging level based on the provided debug_level string. - """ - log_levels = { - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL - } - - level = log_levels.get(debug_level.upper(), logging.ERROR) - return level - -def get_starting_position_from_cyst_config(cyst_objects): - starting_positions = {} - for obj in cyst_objects: - if isinstance(obj, NodeConfig): - for active_service in obj.active_services: - if active_service.type == "netsecenv_agent": - print(f"starting processing {obj.id}.{active_service.name}") - hosts = set() - networks = set() - for interface in obj.interfaces: - hosts.add(IP(str(interface.ip))) - net_ip, net_mask = str(interface.net).split("/") - networks.add(Network(net_ip,int(net_mask))) - starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks} - return starting_positions - -def store_trajectories_to_jsonl(trajectories:list, dir:str, filename:str)->None: - """ - Store trajectories to a JSONL file. - Args: - trajectories (list): List of trajectory data to store. - dir (str): Directory where the file will be stored. - filename (str): Name of the file (without extension). - """ - # make sure the directory exists - if not os.path.exists(dir): - os.makedirs(dir) - # construct the full file name - filename = os.path.join(dir, f"{filename.rstrip('jsonl')}.jsonl") - # store the trajectories - with jsonlines.open(filename, "a") as writer: - writer.write(trajectories) - -if __name__ == "__main__": - state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, - known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, - known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, - known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, - IP("192.168.1.2"):{Data("McGiver", "data2")}}) - - print(state_as_ordered_string(state)) \ No newline at end of file + required_players = default_value + return required_players \ No newline at end of file diff --git a/netsecgame/game/configuration_manager.py b/netsecgame/game/configuration_manager.py new file mode 100644 index 00000000..4a5ff752 --- /dev/null +++ b/netsecgame/game/configuration_manager.py @@ -0,0 +1,204 @@ +import logging +from typing import Optional, Dict, Any, List +from aiohttp import ClientSession + +from netsecgame.game.config_parser import ConfigParser +from netsecgame.utils.utils import get_str_hash +from cyst.api.environment.environment import Environment +from netsecgame.game_components import AgentRole + +class ConfigurationManager: + """ + Manages the loading and access of game configuration. + + Handles fetching configuration from efficient sources (local file or remote service) + and provides structured access to configuration data. + """ + + def __init__(self, task_config_file: Optional[str] = None, service_host: Optional[str] = None, service_port: Optional[int] = None): + self.logger = logging.getLogger("ConfigurationManager") + self._task_config_file = task_config_file + self._service_host = service_host + self._service_port = service_port + + self._parser: Optional[ConfigParser] = None + self._cyst_objects = None + self._config_file_hash: Optional[str] = None + + # Cache for parsed values + self._starting_positions: Dict[str, Any] = {} + self._win_conditions: Dict[str, Any] = {} + self._goal_descriptions: Dict[str, str] = {} + self._max_steps: Dict[str, Optional[int]] = {} + + async def load(self) -> None: + """ + Determines the source and loads the configuration. + Prioritizes remote service if configured, otherwise falls back to local file. + """ + if self._service_host and self._service_port: + self.logger.info(f"Fetching task configuration from {self._service_host}:{self._service_port}") + await self._fetch_remote_configuration() + elif self._task_config_file: + self.logger.info(f"Loading task configuration from file: {self._task_config_file}") + self._load_local_configuration() + else: + raise ValueError("Task configuration source not specified (neither file nor service)") + + async def _fetch_remote_configuration(self) -> None: + """Fetches initialization objects from the remote service.""" + url = f"http://{self._service_host}:{self._service_port}/cyst_init_objects" + async with ClientSession() as session: + try: + async with session.get(url) as response: + if response.status == 200: + config_data = await response.json() + self.logger.debug(f"Received config data: {config_data}") + + # Initialize CYST environment + env = Environment.create() + self._config_file_hash = get_str_hash(config_data) + self._cyst_objects = env.configuration.general.load_configuration(config_data) + self.logger.debug(f"Initialization objects received: {self._cyst_objects}") + + # Initialize parser with the fetched dict (assuming it contains task_configuration or similar structure) + # Note: The original coordinator code for remote fetch commented out creating ConfigParser: + # #self.task_config = ConfigParser(config_dict=response["task_configuration"]) + # usage of self.task_config in original code fell back to loading from file even if remote fetch happened? + # "Temporary fix" comment in original code suggests fallback. + # For this implementation, we should try to use the fetched config if possible. + # If the API returns the same structure as the YAML file, we can pass it to ConfigParser(config_dict=...) + # If not, we might need to rely on the file as the original code did for the parser part. + + # Let's assume for now we try to use the dictionary if available, otherwise fallback logic might be needed + # derived from how the response is structured. + # Looking at original code: response seems to be the full config. + self._parser = ConfigParser(config_dict=config_data) + + else: + self.logger.error(f"Failed to fetch initialization objects. Status: {response.status}") + raise RuntimeError(f"Remote configuration fetch failed with status {response.status}") + except Exception as e: + self.logger.error(f"Error fetching initialization objects: {e}") + # Fallback to local file if remote fails? The original code did: + # self.task_config = ConfigParser(self._task_config_file) + # We can implement similar fallback behavior here if desired, or just raise. + if self._task_config_file: + self.logger.warning("Falling back to local configuration file.") + self._load_local_configuration() + else: + raise e + + def _load_local_configuration(self) -> None: + """Loads configuration from the local file.""" + self._parser = ConfigParser(task_config_file=self._task_config_file) + self._cyst_objects = self._parser.get_scenario() + # Original code does str(self._cyst_objects) for hash + self._config_file_hash = get_str_hash(str(self._cyst_objects)) + + # ------------------------------------------------------------------------- + # Accessors + # ------------------------------------------------------------------------- + + def get_cyst_objects(self): + return self._cyst_objects + + def get_config_hash(self) -> Optional[str]: + return self._config_file_hash + + def get_starting_position(self, role: str) -> dict: + """Returns the starting position configuration for a specific role.""" + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_start_position(agent_role=role) + + def get_win_conditions(self, role: str) -> dict: + """Returns the win conditions for a specific role.""" + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_win_conditions(agent_role=role) + + def get_goal_description(self, role: str) -> str: + """Returns the goal description for a specific role.""" + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_goal_description(agent_role=role) + + def get_max_steps(self, role: str) -> Optional[int]: + """Returns the max steps for a specific role.""" + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_max_steps(role) + + def get_rewards(self, reward_names: List[str] = ["step", "success", "fail", "false_positive"], default_value: int = 0) -> dict: + """Returns the rewards configuration.""" + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_rewards(reward_names, default_value) + + def get_use_dynamic_ips(self, default_value: bool = False) -> bool: + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_use_dynamic_addresses(default_value) + + def get_use_global_defender(self, default_value: bool = False) -> bool: + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_use_global_defender(default_value) + + def get_required_num_players(self, default_value: int = 1) -> int: + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_required_num_players(default_value) + + def get_use_firewall(self, default_value: bool = True) -> bool: + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_use_firewall(default_value) + + def get_all_starting_positions(self) -> Dict[str, Any]: + """Returns starting positions for all roles.""" + starting_positions = {} + for agent_role in AgentRole: + try: + starting_positions[agent_role] = self.get_starting_position(role=agent_role) + self.logger.info(f"Starting position for role '{agent_role}': {starting_positions[agent_role]}") + except KeyError: + starting_positions[agent_role] = {} + return starting_positions + + def get_all_win_conditions(self) -> Dict[str, Any]: + """Returns win conditions for all roles.""" + win_conditions = {} + for agent_role in AgentRole: + try: + win_conditions[agent_role] = self.get_win_conditions(role=agent_role) + except KeyError: + win_conditions[agent_role] = {} + self.logger.info(f"Win condition for role '{agent_role}': {win_conditions[agent_role]}") + return win_conditions + + def get_all_goal_descriptions(self) -> Dict[str, str]: + """Returns goal descriptions for all roles.""" + goal_descriptions = {} + for agent_role in AgentRole: + try: + goal_descriptions[agent_role] = self.get_goal_description(role=agent_role) + except KeyError: + goal_descriptions[agent_role] = "" + self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}") + return goal_descriptions + + def get_all_max_steps(self) -> Dict[str, Optional[int]]: + """Returns max steps for all roles.""" + # Using self.get_max_steps might raise RuntimeError if checks are there, + # but simpler to just call parser directly or the single accessor since we are inside the class. + # However, the single accessor has the check. + # But wait, self.get_max_steps(role) does `self._parser.get_max_steps(role)` already. + # Iterating over AgentRole is correct. + return {role: self.get_max_steps(role) for role in AgentRole} + + def get_store_trajectories(self, default_value: bool = False) -> bool: + if not self._parser: + raise RuntimeError("Configuration not loaded.") + return self._parser.get_store_trajectories(default_value) diff --git a/AIDojoCoordinator/coordinator.py b/netsecgame/game/coordinator.py similarity index 70% rename from AIDojoCoordinator/coordinator.py rename to netsecgame/game/coordinator.py index 40e771c2..904e3444 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/netsecgame/game/coordinator.py @@ -2,135 +2,29 @@ import json import asyncio from datetime import datetime +from typing import Optional import signal - -from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig -from AIDojoCoordinator.global_defender import GlobalDefender -from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser, store_trajectories_to_jsonl import os -from aiohttp import ClientSession -from cyst.api.environment.environment import Environment - -class AgentServer(asyncio.Protocol): - """ - Class used for serving the agents when connecting to the game run by the GameCoordinator. - Attributes: - actions_queue (asyncio.Queue): Queue for actions from agents. - answers_queues (dict): Mapping of agent addresses to their response queues. - max_connections (int): Maximum allowed concurrent agent connections. - current_connections (int): Current number of connected agents. - logger (logging.Logger): Logger for the AgentServer. - """ - def __init__(self, actions_queue, agent_response_queues, max_connections): - """ - Initialize the AgentServer. +from netsecgame.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, AgentRole +from netsecgame.game.global_defender import GlobalDefender +from netsecgame.utils.utils import observation_as_dict,store_trajectories_to_jsonl +from netsecgame.game.agent_server import AgentServer +from netsecgame.game.configuration_manager import ConfigurationManager - Args: - actions_queue (asyncio.Queue): Queue for actions from agents. - agent_response_queues (dict): Mapping of agent addresses to their response queues. - max_connections (int): Maximum allowed concurrent agent connections. - """ - self.actions_queue = actions_queue - self.answers_queues = agent_response_queues - self.max_connections = max_connections - self.current_connections = 0 - self.logger = logging.getLogger("AIDojo-AgentServer") - - async def handle_agent_quit(self, peername:tuple): - """ - Helper function to handle agent disconnection. - Args: - peername (tuple): The address of the disconnecting agent. - """ - # Send a quit message to the Coordinator - self.logger.info(f"\tHandling agent quit for {peername}.") - quit_message = Action(ActionType.QuitGame, parameters={}).to_json() - await self.actions_queue.put((peername, quit_message)) - - async def handle_new_agent(self, reader, writer): - """ - Handle a new agent connection. - - Args: - reader (asyncio.StreamReader): Stream reader for the agent. - writer (asyncio.StreamWriter): Stream writer for the agent. - """ - # get the peername of the writer - peername = writer.get_extra_info("peername") - queue_created = False - try: - self.logger.info(f"New connection from {peername}") - # Check if the maximum number of connections has been reached - if self.current_connections < self.max_connections: - # increment the count of current connections - self.current_connections += 1 - self.logger.info(f"New agent connected: {peername}. Current connections: {self.current_connections}") - # Ensure a queue exists for this agent - if peername not in self.answers_queues: - self.answers_queues[peername] = asyncio.Queue(maxsize=2) - queue_created = True - self.logger.info(f"Created queue for agent {peername}") - # Handle the new agent - while True: - # Step 1: Read data from the agent - data = await reader.read(ProtocolConfig.BUFFER_SIZE) - if not data: - self.logger.info(f"Agent {peername} disconnected.") - await self.handle_agent_quit(peername) - break - - raw_message = data.decode().strip() - self.logger.debug(f"Handler received from {peername}: {raw_message}") - - # Step 2: Forward the message to the Coordinator - await self.actions_queue.put((peername, raw_message)) - - # Step 3: Get a matching response from the answers queue - response_queue = self.answers_queues[peername] - response = await response_queue.get() - self.logger.info(f"Sending response to agent {peername}: {response}") - - # Step 4: Send the response to the agent - response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE - writer.write(response) - await writer.drain() - else: - self.logger.warning(f"Queue for agent {peername} already exists. Closing connection.") - else: - self.logger.info(f"Max connections reached. Rejecting new connection from {writer.get_extra_info('peername')}") - except ConnectionResetError: - self.logger.warning(f"Connection reset by {peername}") - await self.handle_agent_quit(peername) - except asyncio.CancelledError: - self.logger.debug("Connection handling cancelled.") - raise # Ensure the exception propagates - except Exception as e: - self.logger.error(f"Unexpected error with client {peername}: {e}") - raise - finally: - try: - if peername in self.answers_queues: - # If the queue was created, remove it - if queue_created: - self.answers_queues.pop(peername) - self.logger.info(f"Removed queue for agent {peername}") - self.current_connections = max(0, self.current_connections - 1) - writer.close() - await writer.wait_closed() - except Exception: - # swallow exceptions on close to avoid crash on cleanup - pass - async def __call__(self, reader, writer): - """ - Allow the server instance to be called as a coroutine. +def convert_msg_dict_to_json(msg_dict: dict) -> str: + """ + Helper function to create text-base messge from a dictionary. Used in the Agent-Game communication. + """ + try: + # Convert message into string representation + output_message = json.dumps(msg_dict) + except Exception as e: + # Let the caller handle logging if needed, or re-raise with context + raise TypeError(f"Error when converting msg to JSON:{e}") from e + return output_message - Args: - reader (asyncio.StreamReader): Stream reader for the agent. - writer (asyncio.StreamWriter): Stream writer for the agent. - """ - await self.handle_new_agent(reader, writer) class GameCoordinator: """ @@ -140,6 +34,7 @@ class GameCoordinator: host (str): Host address for the game server. port (int): Port number for the game server. logger (logging.Logger): Logger for the GameCoordinator. + config_manager (ConfigurationManager): Manager for game configuration. _tasks (set): Set of active asyncio tasks. shutdown_flag (asyncio.Event): Event to signal shutdown. _reset_event (asyncio.Event): Event to signal game reset. @@ -149,10 +44,6 @@ class GameCoordinator: _reset_done_condition (asyncio.Condition): Condition for reset completion. _reset_lock (asyncio.Lock): Lock for reset operations. _agents_lock (asyncio.Lock): Lock for agent operations. - _service_host (str): Host for remote configuration service. - _service_port (int): Port for remote configuration service. - _task_config_file (str): Path to local task configuration file. - ALLOWED_ROLES (list): List of allowed agent roles. _cyst_objects: CYST simulator initialization objects. _cyst_object_string: String representation of CYST objects. _agent_action_queue (asyncio.Queue): Queue for agent actions. @@ -172,10 +63,10 @@ class GameCoordinator: _agent_rewards (dict): Rewards per agent address. _agent_trajectories (dict): Trajectories per agent address. """ - def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, task_config_file:str,allowed_roles=["Attacker", "Defender", "Benign"]) -> None: + def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, task_config_file:str) -> None: self.host = game_host self.port = game_port - self.logger = logging.getLogger("AIDojo-GameCoordinator") + self.logger = logging.getLogger("GameCoordinator") self._tasks = set() self.shutdown_flag = asyncio.Event() @@ -186,17 +77,10 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._reset_done_condition = asyncio.Condition() self._reset_lock = asyncio.Lock() self._agents_lock = asyncio.Lock() - - # for accessing configuration remotely - self._service_host = service_host - self._service_port = service_port - # for reading configuration locally - self._task_config_file = task_config_file - self.logger = logging.getLogger("AIDojo-GameCoordinator") - self.ALLOWED_ROLES = allowed_roles - self._cyst_objects = None - self._cyst_object_string = None - + + # Configuration Manager + self.config_manager = ConfigurationManager(task_config_file, service_host, service_port) + # prepare agent communication self._agent_action_queue = asyncio.Queue() self._agent_response_queues = {} @@ -227,7 +111,17 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._agent_trajectories = {} def _spawn_task(self, coroutine, *args, **kwargs)->asyncio.Task: - "Helper function to make sure all tasks are registered for proper termination" + """ + Helper function to make sure all tasks are registered for proper termination. + + Args: + coroutine: The coroutine function to schedule. + *args: Positional arguments to pass to the coroutine. + **kwargs: Keyword arguments to pass to the coroutine. + + Returns: + asyncio.Task: The created task object. + """ task = asyncio.create_task(coroutine(*args, **kwargs)) self._tasks.add(task) def remove_task(t): @@ -235,31 +129,23 @@ def remove_task(t): task.add_done_callback(remove_task) # Remove task when done return task - async def shutdown_signal_handler(self): - """Handle shutdown signals.""" + async def shutdown_signal_handler(self)->None: + """ + Logs the signal reception and sets the shutdown flag to initiate graceful termination. + """ self.logger.info("Shutdown signal received. Setting shutdown flag.") self.shutdown_flag.set() async def create_agent_queue(self, agent_addr:tuple)->None: """ Creates a queue for the given agent address if it doesn't already exist. + + Args: + agent_addr (tuple): The agent address to create a queue for. """ if agent_addr not in self._agent_response_queues: self._agent_response_queues[agent_addr] = asyncio.Queue() self.logger.info(f"Created queue for agent {agent_addr}. {len(self._agent_response_queues)} queues in total.") - - def convert_msg_dict_to_json(self, msg_dict:dict)->str: - """ - Helper function to create text-base messge from a dictionary. Used in the Agent-Game communication. - """ - try: - # Convert message into string representation - output_message = json.dumps(msg_dict) - except Exception as e: - self.logger.error(f"Error when converting msg to JSON:{e}") - raise e - # Send to anwer_queue - return output_message def run(self)->None: """ @@ -272,80 +158,6 @@ def run(self)->None: finally: self.logger.info(f"{__class__.__name__} has exited.") - async def _fetch_initialization_objects(self): - """Send a REST request to MAIN and fetch initialization objects of CYST simulator.""" - async with ClientSession() as session: - try: - async with session.get(f"http://{self._service_host}:{self._service_port}/cyst_init_objects") as response: - if response.status == 200: - response = await response.json() - self.logger.debug(response) - env = Environment.create() - self._CONFIG_FILE_HASH = get_str_hash(response) - self._cyst_objects = env.configuration.general.load_configuration(response) - self.logger.debug(f"Initialization objects received:{self._cyst_objects}") - #self.task_config = ConfigParser(config_dict=response["task_configuration"]) - else: - self.logger.error(f"Failed to fetch initialization objects. Status: {response.status}") - except Exception as e: - self.logger.error(f"Error fetching initialization objects: {e}") - # Temporary fix - self.task_config = ConfigParser(self._task_config_file) - - def _load_initialization_objects(self)->None: - """ - Loads task configuration from a local file. - """ - self.task_config = ConfigParser(self._task_config_file) - self._cyst_objects = self.task_config.get_scenario() - self._CONFIG_FILE_HASH = get_str_hash(str(self._cyst_objects)) - - def _get_starting_position_per_role(self)->dict: - """ - Method for finding starting position for each agent role in the game. - """ - starting_positions = {} - for agent_role in self.ALLOWED_ROLES: - try: - starting_positions[agent_role] = self.task_config.get_start_position(agent_role=agent_role) - self.logger.info(f"Starting position for role '{agent_role}': {starting_positions[agent_role]}") - except KeyError: - starting_positions[agent_role] = {} - return starting_positions - - def _get_win_condition_per_role(self)-> dict: - """ - Method for finding wininng conditions for each agent role in the game. - """ - win_conditions = {} - for agent_role in self.ALLOWED_ROLES: - try: - win_conditions[agent_role] = self.task_config.get_win_conditions(agent_role=agent_role) - except KeyError: - win_conditions[agent_role] = {} - self.logger.info(f"Win condition for role '{agent_role}': {win_conditions[agent_role]}") - return win_conditions - - def _get_goal_description_per_role(self)->dict: - """ - Method for finding goal description for each agent role in the game. - """ - goal_descriptions ={} - for agent_role in self.ALLOWED_ROLES: - try: - goal_descriptions[agent_role] = self.task_config.get_goal_description(agent_role=agent_role) - except KeyError: - goal_descriptions[agent_role] = "" - self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}") - return goal_descriptions - - def _get_max_steps_per_role(self)->dict: - """ - Method for finding max amount of steps in 1 episode for each agent role in the game. - """ - max_steps = {role:self.task_config.get_max_steps(role) for role in self.ALLOWED_ROLES} - return max_steps - async def start_tcp_server(self): """ Starts TPC sever for the agent communication. @@ -395,32 +207,29 @@ async def start_tasks(self): ) - # initialize the game objects - if self._service_host: #get the task config using REST API - self.logger.info(f"Fetching task configuration from {self._service_host}:{self._service_port}") - await self._fetch_initialization_objects() - elif self._task_config_file: # load task config locally from a file - self.logger.info(f"Loading task configuration from file: {self._task_config_file}") - self._load_initialization_objects() - else: - raise ValueError("Task configuration not specified") + # Initialize configuration manager and load the configuration + await self.config_manager.load() + self._cyst_objects = self.config_manager.get_cyst_objects() + + if self.config_manager.get_config_hash(): + self._CONFIG_FILE_HASH = self.config_manager.get_config_hash() - # Read configuration - self._starting_positions_per_role = self._get_starting_position_per_role() - self._win_conditions_per_role = self._get_win_condition_per_role() - self._goal_description_per_role = self._get_goal_description_per_role() - self._steps_limit_per_role = self._get_max_steps_per_role() + self._starting_positions_per_role = self.config_manager.get_all_starting_positions() + self._win_conditions_per_role = self.config_manager.get_all_win_conditions() + self._goal_description_per_role = self.config_manager.get_all_goal_descriptions() + self._steps_limit_per_role = self.config_manager.get_all_max_steps() + self.logger.debug(f"Timeouts set to:{self._steps_limit_per_role}") - if self.task_config.get_use_global_defender(): + if self.config_manager.get_use_global_defender(): self._global_defender = GlobalDefender() else: self._global_defender = None - self._use_dynamic_ips = self.task_config.get_use_dynamic_addresses() + self._use_dynamic_ips = self.config_manager.get_use_dynamic_ips() self.logger.info(f"Change IP every episode set to: {self._use_dynamic_ips}") - self._rewards = self.task_config.get_rewards(["step", "success", "fail", "false_positive"]) + self._rewards = self.config_manager.get_rewards(["step", "success", "fail", "false_positive"]) self.logger.info(f"Rewards set to:{self._rewards}") - self._min_required_players = self.task_config.get_required_num_players() + self._min_required_players = self.config_manager.get_required_num_players() self.logger.info(f"Min player requirement set to:{self._min_required_players}") # run self initialization self._initialize() @@ -447,42 +256,64 @@ async def start_tasks(self): await asyncio.gather(*self._tasks, return_exceptions=True) # Wait for all tasks to finish self.logger.info("All tasks shut down.") + def _parse_action_message(self, agent_addr: tuple, message: str) -> Optional[Action]: + """ + Parses a JSON message from an agent into an Action object. + + Args: + agent_addr (tuple): The address of the agent sending the message (used for logging context). + message (str): The raw JSON string message received from the agent. + + Returns: + Optional[Action]: The parsed Action object if successful, None otherwise. + """ + try: + action = Action.from_json(message) + return action + except Exception as e: + self.logger.error(f"Error when converting msg from {agent_addr} to Action using Action.from_json():{e}, {message}") + return None + + def _dispatch_action(self, agent_addr: tuple, action: Action) -> None: + """ + Dispatches an Action to the appropriate processing method based on its type. + + Args: + agent_addr (tuple): The address of the agent performing the action. + action (Action): The Action object to be processed. + """ + match action.type: + case ActionType.JoinGame: + self.logger.debug(f"[{agent_addr}] Start processing of ActionType.JoinGame") + self._spawn_task(self._process_join_game_action, agent_addr, action) + case ActionType.QuitGame: + self.logger.debug(f"[{agent_addr}] Start processing of ActionType.QuitGame") + self._spawn_task(self._process_quit_game_action, agent_addr) + case ActionType.ResetGame: + self.logger.debug(f"[{agent_addr}] Start processing of ActionType.ResetGame") + self._spawn_task(self._process_reset_game_action, agent_addr, action) + case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService | ActionType.BlockIP: + self.logger.debug(f"[{agent_addr}] Start processing of {action.type}") + self._spawn_task(self._process_game_action, agent_addr, action) + case _: + self.logger.warning(f"[{agent_addr}] Unsupported action type: {action}!") + async def run_game(self): """ - Task responsible for reading messages from the agent queue and processing them based on the ActionType. + Main game loop task. + + Responsible for reading messages from the agent queue, parsing them using `_parse_action_message`, + and dispatching them to the appropriate handler using `_dispatch_action`. """ while not self.shutdown_flag.is_set(): # Read message from the queue agent_addr, message = await self._agent_action_queue.get() if message is not None: self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.") - - try: # Convert message to Action - action = Action.from_json(message) - self.logger.debug(f"\tConverted to: {action}.") - match action.type: # process action based on its type - case ActionType.JoinGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}") - self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}") - self._spawn_task(self._process_join_game_action, agent_addr, action) - case ActionType.QuitGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}") - self._spawn_task(self._process_quit_game_action, agent_addr) - case ActionType.ResetGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}") - self._spawn_task(self._process_reset_game_action, agent_addr, action) - case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: - self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") - self._spawn_task(self._process_game_action, agent_addr, action) - case ActionType.BlockIP: - self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") - self._spawn_task(self._process_game_action, agent_addr, action) - case _: - self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!") - except Exception as e: - self.logger.error( - f"Error when converting msg to Action using Action.from_json():{e}, {message}" - ) + + action = self._parse_action_message(agent_addr, message) + if action: + self._dispatch_action(agent_addr, action) self.logger.info("\tAction processing task stopped.") async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: @@ -498,7 +329,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No if agent_addr not in self.agents: agent_name = action.parameters["agent_info"].name agent_role = action.parameters["agent_info"].role - if agent_role in self.ALLOWED_ROLES: + if agent_role in AgentRole: # add agent to the world new_agent_game_state, new_agent_goal_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role]) if new_agent_game_state: # successful registration @@ -530,7 +361,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No if hasattr(self, "_registration_info"): for key, value in self._registration_info.items(): output_message_dict["message"][key] = value - await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict)) + await self._agent_response_queues[agent_addr].put(convert_msg_dict_to_json(output_message_dict)) else: self.logger.info( f"\tError in registration, unknown agent role: {agent_role}!" @@ -540,7 +371,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No "status": str(GameStatus.BAD_REQUEST), "message": f"Incorrect agent_role {agent_role}", } - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) + response_msg_json = convert_msg_dict_to_json(output_message_dict) await self._agent_response_queues[agent_addr].put(response_msg_json) else: self.logger.info("\tError in registration, agent already exists!") @@ -549,7 +380,7 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No "status": str(GameStatus.BAD_REQUEST), "message": "Agent already exists.", } - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) + response_msg_json = convert_msg_dict_to_json(output_message_dict) await self._agent_response_queues[agent_addr].put(response_msg_json) except asyncio.CancelledError: self.logger.debug(f"Proccessing JoinAction of agent {agent_addr} interrupted") @@ -616,7 +447,7 @@ async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Actio if "request_trajectory" in reset_action.parameters and reset_action.parameters["request_trajectory"]: output_message_dict["message"]["last_trajectory"] = self._agent_trajectories[agent_addr] self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) + response_msg_json = convert_msg_dict_to_json(output_message_dict) await self._agent_response_queues[agent_addr].put(response_msg_json) async def _process_game_action(self, agent_addr: tuple, action:Action)->None: @@ -687,7 +518,7 @@ async def _process_game_action(self, agent_addr: tuple, action:Action)->None: "observation": observation_as_dict(new_observation), "status": str(GameStatus.OK), } - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) + response_msg_json = convert_msg_dict_to_json(output_message_dict) await self._agent_response_queues[agent_addr].put(response_msg_json) async def _assign_rewards_episode_end(self): @@ -755,7 +586,7 @@ async def _reset_game(self): self.logger.info("Resetting game to initial state.") await self.reset() for agent in self.agents: - if self.task_config.get_store_trajectories(): + if self.config_manager.get_store_trajectories(): async with self._agents_lock: self._store_trajectory_to_file(agent) self.logger.debug(f"Resetting agent {agent}") @@ -799,7 +630,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState self._agent_last_action[agent_addr] = None self._agent_rewards[agent_addr] = 0 self._agent_false_positives[agent_addr] = 0 - if agent_role.lower() == "attacker": + if agent_role in [AgentRole.Attacker]: self._agent_status[agent_addr] = AgentStatus.PlayingWithTimeout else: self._agent_status[agent_addr] = AgentStatus.Playing @@ -808,7 +639,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState # create initial observation return Observation(self._agent_states[agent_addr], 0, False, {}) - async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: + async def register_agent(self, agent_id:tuple, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: """ Domain specific method of the environment. Creates the initial state of the agent. """ @@ -820,7 +651,7 @@ async def remove_agent(self, agent_id:tuple, agent_state:GameState)->bool: """ raise NotImplementedError - async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: + async def reset_agent(self, agent_id:tuple, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: raise NotImplementedError async def _remove_agent_from_game(self, agent_addr): @@ -859,16 +690,25 @@ async def _remove_agent_from_game(self, agent_addr): return agent_info async def step(self, agent_id:tuple, agent_state:GameState, action:Action): + """ + Domain specific method of the environment. Creates the initial state of the agent. + Must be implemented by the domain specific environment. + """ raise NotImplementedError async def reset(self)->bool: - return NotImplemented + """ + Domain specific method of the environment. Creates the initial state of the agent. + Must be implemented by the domain specific environment. + """ + raise NotImplementedError def _initialize(self): """ Initialize the game state and other necessary components. This is called at the start of the game after the configuration is loaded. + Must be implemented by the domain specific environment. """ - return NotImplemented + raise NotImplementedError def goal_check(self, agent_addr:tuple)->bool: """ @@ -924,6 +764,8 @@ def is_timeout(self, agent:tuple)->bool: def add_false_positive(self, agent:tuple)->None: """ Method for adding false positive to the agent. + Args: + agent (tuple): The agent to add false positive to. """ self.logger.debug(f"Adding false positive to {agent}") if agent in self._agent_false_positives: @@ -935,6 +777,10 @@ def add_false_positive(self, agent:tuple)->None: def _update_agent_status(self, agent:tuple)->AgentStatus: """ Update the status of an agent based on reaching the goal, timeout or detection. + Args: + agent (tuple): The agent to update the status of. + Returns: + AgentStatus: The new status of the agent. """ # read current status of the agent next_status = self._agent_status[agent] @@ -953,6 +799,13 @@ def _update_agent_status(self, agent:tuple)->AgentStatus: return next_status def _update_agent_episode_end(self, agent:tuple)->bool: + """ + Update the episode end status of an agent. + Args: + agent (tuple): The agent to update the episode end status of. + Returns: + bool: True if the episode has ended, False otherwise. + """ episode_end = False if self._agent_status[agent] in [AgentStatus.Success, AgentStatus.Fail, AgentStatus.TimeoutReached]: # agent reached goal, timeout or was detected @@ -1012,4 +865,5 @@ def is_agent_benign(self, agent_addr:tuple)->bool: """ if agent_addr not in self.agents: return False + #TODO: change to use AgentRole return self.agents[agent_addr][1].lower() in ["defender", "benign"] \ No newline at end of file diff --git a/AIDojoCoordinator/global_defender.py b/netsecgame/game/global_defender.py similarity index 98% rename from AIDojoCoordinator/global_defender.py rename to netsecgame/game/global_defender.py index c784b11a..28cf3c0c 100644 --- a/AIDojoCoordinator/global_defender.py +++ b/netsecgame/game/global_defender.py @@ -1,6 +1,6 @@ # Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz from itertools import groupby -from AIDojoCoordinator.game_components import ActionType, Action +from netsecgame.game_components import ActionType, Action from random import random diff --git a/AIDojoCoordinator/netsecenv_conf.yaml b/netsecgame/game/netsecenv_conf.yaml similarity index 100% rename from AIDojoCoordinator/netsecenv_conf.yaml rename to netsecgame/game/netsecenv_conf.yaml diff --git a/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml b/netsecgame/game/netsecevn_conf_cyst_integration.yaml similarity index 100% rename from AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml rename to netsecgame/game/netsecevn_conf_cyst_integration.yaml diff --git a/AIDojoCoordinator/scenarios/__init__.py b/netsecgame/game/scenarios/__init__.py similarity index 100% rename from AIDojoCoordinator/scenarios/__init__.py rename to netsecgame/game/scenarios/__init__.py diff --git a/AIDojoCoordinator/scenarios/one_net.py b/netsecgame/game/scenarios/one_net.py similarity index 100% rename from AIDojoCoordinator/scenarios/one_net.py rename to netsecgame/game/scenarios/one_net.py diff --git a/AIDojoCoordinator/scenarios/scenario_configuration.py b/netsecgame/game/scenarios/scenario_configuration.py similarity index 100% rename from AIDojoCoordinator/scenarios/scenario_configuration.py rename to netsecgame/game/scenarios/scenario_configuration.py diff --git a/AIDojoCoordinator/scenarios/smaller_scenario_configuration.py b/netsecgame/game/scenarios/smaller_scenario_configuration.py similarity index 100% rename from AIDojoCoordinator/scenarios/smaller_scenario_configuration.py rename to netsecgame/game/scenarios/smaller_scenario_configuration.py diff --git a/AIDojoCoordinator/scenarios/test_scenario_configuration.py b/netsecgame/game/scenarios/test_scenario_configuration.py similarity index 100% rename from AIDojoCoordinator/scenarios/test_scenario_configuration.py rename to netsecgame/game/scenarios/test_scenario_configuration.py diff --git a/AIDojoCoordinator/scenarios/three_net_scenario.py b/netsecgame/game/scenarios/three_net_scenario.py similarity index 100% rename from AIDojoCoordinator/scenarios/three_net_scenario.py rename to netsecgame/game/scenarios/three_net_scenario.py diff --git a/AIDojoCoordinator/scenarios/tiny_scenario_configuration.py b/netsecgame/game/scenarios/tiny_scenario_configuration.py similarity index 100% rename from AIDojoCoordinator/scenarios/tiny_scenario_configuration.py rename to netsecgame/game/scenarios/tiny_scenario_configuration.py diff --git a/AIDojoCoordinator/scenarios/two_nets.py b/netsecgame/game/scenarios/two_nets.py similarity index 100% rename from AIDojoCoordinator/scenarios/two_nets.py rename to netsecgame/game/scenarios/two_nets.py diff --git a/AIDojoCoordinator/scenarios/two_nets_small.py b/netsecgame/game/scenarios/two_nets_small.py similarity index 100% rename from AIDojoCoordinator/scenarios/two_nets_small.py rename to netsecgame/game/scenarios/two_nets_small.py diff --git a/AIDojoCoordinator/scenarios/two_nets_tiny.py b/netsecgame/game/scenarios/two_nets_tiny.py similarity index 100% rename from AIDojoCoordinator/scenarios/two_nets_tiny.py rename to netsecgame/game/scenarios/two_nets_tiny.py diff --git a/AIDojoCoordinator/worlds/CYSTCoordinator.py b/netsecgame/game/worlds/CYSTCoordinator.py similarity index 88% rename from AIDojoCoordinator/worlds/CYSTCoordinator.py rename to netsecgame/game/worlds/CYSTCoordinator.py index f9d65c74..0e023ba6 100644 --- a/AIDojoCoordinator/worlds/CYSTCoordinator.py +++ b/netsecgame/game/worlds/CYSTCoordinator.py @@ -8,15 +8,40 @@ import logging import argparse from pathlib import Path -from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Service -from AIDojoCoordinator.coordinator import GameCoordinator +from netsecgame.game_components import GameState, Action, ActionType, IP, Service ,Network +from netsecgame.game.coordinator import GameCoordinator +from cyst.api.configuration.network.node import NodeConfig -from AIDojoCoordinator.utils.utils import get_starting_position_from_cyst_config, get_logging_level +from netsecgame.utils.utils import get_logging_level + +def get_starting_position_from_cyst_config(cyst_objects): + """ + Extracts starting positions from CYST configuration objects. + + Args: + cyst_objects (list): List of CYST configuration objects. + Returns: + dict: A dictionary mapping agent identifiers to their starting known hosts and networks. + """ + starting_positions = {} + for obj in cyst_objects: + if isinstance(obj, NodeConfig): + for active_service in obj.active_services: + if active_service.type == "netsecenv_agent": + print(f"starting processing {obj.id}.{active_service.name}") + hosts = set() + networks = set() + for interface in obj.interfaces: + hosts.add(IP(str(interface.ip))) + net_ip, net_mask = str(interface.net).split("/") + networks.add(Network(net_ip,int(net_mask))) + starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks} + return starting_positions class CYSTCoordinator(GameCoordinator): def __init__(self, game_host:str, game_port:int, service_host:str, service_port:int, allowed_roles=["Attacker", "Defender", "Benign"]): - super().__init__(game_host, game_port, service_host, service_port, allowed_roles) + super().__init__(game_host, game_port, service_host, service_port, task_config_file=None) self._id_to_cystid = {} self._cystid_to_id = {} self._known_agent_roles = {} diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/netsecgame/game/worlds/NetSecGame.py similarity index 97% rename from AIDojoCoordinator/worlds/NSEGameCoordinator.py rename to netsecgame/game/worlds/NetSecGame.py index 76222313..94b4c9d1 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/netsecgame/game/worlds/NetSecGame.py @@ -9,19 +9,19 @@ import json from faker import Faker from pathlib import Path -from typing import Iterable +from typing import Iterable, Any from collections import defaultdict -from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service -from AIDojoCoordinator.coordinator import GameCoordinator +from netsecgame.game_components import GameState, Action, ActionType, IP, Network, Data, Service, AgentRole +from netsecgame.game.coordinator import GameCoordinator from cyst.api.configuration import NodeConfig, RouterConfig, ConnectionConfig, ExploitConfig, FirewallPolicy -from AIDojoCoordinator.utils.utils import get_logging_level +from netsecgame.utils.utils import get_logging_level -class NSGCoordinator(GameCoordinator): +class NetSecGame(GameCoordinator): - def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=None): - super().__init__(game_host, game_port, service_host=None, service_port=None, allowed_roles=allowed_roles, task_config_file=task_config) + def __init__(self, game_host, game_port, task_config:str, seed=None): + super().__init__(game_host, game_port, service_host=None, service_port=None, task_config_file=task_config) # Internal data structure of the NSG self._ip_to_hostname = {} # Mapping of `IP`:`host_name`(str) of all nodes in the environment @@ -57,17 +57,20 @@ def _initialize(self): None """ # Load CYST configuration - self._process_cyst_config(self._cyst_objects) - # Check if dynamic network and ip adddresses are required - if self._use_dynamic_ips: - self.logger.info("Dynamic change of the IP and network addresses enabled") - self._faker_object = Faker() - Faker.seed(self._seed) - # store initial values for parts which are modified during the game - self._data_original = copy.deepcopy(self._data) - self._data_content_original = copy.deepcopy(self._data_content) - self._firewall_original = copy.deepcopy(self._firewall) - self.logger.info("Environment initialization finished") + if self._cyst_objects is not None: + self._process_cyst_config(self._cyst_objects) + # Check if dynamic network and ip adddresses are required + if self._use_dynamic_ips: + self.logger.info("Dynamic change of the IP and network addresses enabled") + self._faker_object = Faker() + Faker.seed(self._seed) + # store initial values for parts which are modified during the game + self._data_original = copy.deepcopy(self._data) + self._data_content_original = copy.deepcopy(self._data_content) + self._firewall_original = copy.deepcopy(self._firewall) + self.logger.info("Environment initialization finished") + else: + self.logger.error("CYST configuration not loaded, cannot initialize the environment!") def _get_hosts_from_view(self, view_hosts:Iterable, allowed_hosts=None)->set[IP]: """ @@ -295,7 +298,7 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga self.logger.info(f"Generated GameState:{game_state}") return game_state - def _process_cyst_config(self, configuration_objects:list)-> None: + def _process_cyst_config(self, configuration_objects:list[Any])-> None: """ Process the cyst configuration file """ @@ -403,7 +406,7 @@ def process_firewall()->dict: for ips in self._networks.values(): all_ips.update(ips) firewall = {ip:set() for ip in all_ips} - if self.task_config.get_use_firewall(): + if self.config_manager.get_use_firewall(): self.logger.info("Firewall enabled - processing FW rules") # LOCAL NETWORKS for net, ips in self._networks.items(): @@ -665,7 +668,7 @@ def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> set: """ Returns set of Service tuples from given hostIP """ - found_services = {} + found_services = set() if host_ip in self._ip_to_hostname: #is it existing IP? if self._ip_to_hostname[host_ip] in self._services: #does it have any services? if host_ip in controlled_hosts: # Should local services be included ? @@ -712,7 +715,7 @@ def _get_known_blocks_in_host(self, host_ip:str, controlled_hosts:set)->set: self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.") return known_blocks - def _get_data_content(self, host_ip:str, data_id:str)->str: + def _get_data_content(self, host_ip:str, data_id:str)->str|None: """ Returns content of data identified by a host_ip and data_ip. """ @@ -756,7 +759,7 @@ def _execute_action(self, current_state:GameState, action:Action, agent_id:tuple raise ValueError(f"Unknown Action type or other error: '{action.type}'") return next_state - def _record_false_positive(self, src_ip:IP, dst_ip:IP, agent_id:tuple)->bool: + def _record_false_positive(self, src_ip:IP, dst_ip:IP, agent_id:tuple)-> None: # only record false positive if the agent is benign if self.is_agent_benign(agent_id): # find agent(s) that created the rule @@ -1057,7 +1060,7 @@ def update_log_file(self, known_data:set, action, target_host:IP): new_content = json.dumps(new_content) self._data[hostaname].add(Data(owner="system", id="logfile", type="log", size=len(new_content) , content= new_content)) - async def register_agent(self, agent_id, agent_role, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: + async def register_agent(self, agent_id, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: start_game_state = self._create_state_from_view(agent_initial_view) goal_state = self._create_goal_state_from_view(agent_win_condition_view) return start_game_state, goal_state @@ -1069,7 +1072,7 @@ async def remove_agent(self, agent_id, agent_state)->bool: async def step(self, agent_id, agent_state, action)->GameState: return self._execute_action(agent_state, action, agent_id) - async def reset_agent(self, agent_id, agent_role, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: + async def reset_agent(self, agent_id, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: game_state = self._create_state_from_view(agent_initial_view) goal_state = self._create_goal_state_from_view(agent_win_condition_view) return game_state, goal_state @@ -1083,7 +1086,7 @@ async def reset(self)->bool: self.logger.info('--- Reseting NSG Environment to its initial state ---') # change IPs if needed # This is done ONLY if it is (i) enabled in the task config and (ii) all agents requested it - if self.task_config.get_use_dynamic_addresses(): + if self.config_manager.get_use_dynamic_ips(): if all(self._randomize_topology_requests.values()): self.logger.info("All agents requested reset with randomized topology.") self._dynamic_ip_change() @@ -1173,6 +1176,6 @@ async def reset(self)->bool: level=pass_level, ) - game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config, seed=args.seed) + game_server = NetSecGame(args.game_host, args.game_port, args.task_config, seed=args.seed) # Run it! game_server.run() \ No newline at end of file diff --git a/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py b/netsecgame/game/worlds/RealWorldNetSecGame.py similarity index 95% rename from AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py rename to netsecgame/game/worlds/RealWorldNetSecGame.py index 763326ab..08a4fb50 100644 --- a/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py +++ b/netsecgame/game/worlds/RealWorldNetSecGame.py @@ -9,11 +9,11 @@ import os from pathlib import Path -from AIDojoCoordinator.utils.utils import get_logging_level -from AIDojoCoordinator.game_components import GameState, Action, ActionType, Service,IP -from AIDojoCoordinator.worlds.NSEGameCoordinator import NSGCoordinator +from netsecgame.utils.utils import get_logging_level +from netsecgame.game_components import GameState, Action, ActionType, Service,IP +from netsecgame.game.worlds.NetSecGame import NetSecGame -class NSERealWorldGameCoordinator(NSGCoordinator): +class RealWorldNetSecGame(NetSecGame): def _execute_action(self, current_state:GameState, action:Action)-> GameState: """ @@ -188,6 +188,6 @@ def _execute_find_services_action_real_world(self, current_state:GameState, acti level=pass_level, ) - game_server = NSERealWorldGameCoordinator(args.game_host, args.game_port, args.task_config) + game_server = RealWorldNetSecGame(args.game_host, args.game_port, args.task_config) # Run it! game_server.run() \ No newline at end of file diff --git a/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py b/netsecgame/game/worlds/WhiteBoxNetSecGame.py similarity index 90% rename from AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py rename to netsecgame/game/worlds/WhiteBoxNetSecGame.py index d9ac3285..b6135fd6 100644 --- a/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py +++ b/netsecgame/game/worlds/WhiteBoxNetSecGame.py @@ -1,19 +1,18 @@ +# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import itertools import argparse import logging import os import json from pathlib import Path -from AIDojoCoordinator.utils.utils import get_logging_level -from AIDojoCoordinator.game_components import Action, ActionType -from AIDojoCoordinator.worlds.NSEGameCoordinator import NSGCoordinator +from netsecgame.utils.utils import get_logging_level +from netsecgame.game_components import Action, ActionType +from netsecgame.game.worlds.NetSecGame import NetSecGame - - -class WhiteBoxNSGCoordinator(NSGCoordinator): +class WhiteBoxNetSecGame(NetSecGame): """ - WhiteBoxNSGCoordinator is an extension for the NetSecGame environment + WhiteBoxNetSecGame is an extension for the NetSecGame environment that provides list of all possible actions to each agent that registers in the game. """ def __init__(self, game_host, game_port, task_config, allowed_roles=["Attacker", "Defender", "Benign"], seed=42, include_block_action=False): @@ -26,15 +25,17 @@ def _initialize(self): super()._initialize() # All components are initialized, now we can set the action mapping self.logger.debug("Creating action mapping for the game.") - self._generate_all_actions() + self._all_actions = self._generate_all_actions() self._registration_info = { "all_actions": json.dumps([v.as_dict for v in self._all_actions]), } if self._all_actions is not None else {} - def _generate_all_actions(self)-> list: + def _generate_all_actions(self)-> list[Action]: """ Generate a list of all possible actions for the game. + Returns: + list[Action]: List of all possible actions. """ actions = [] all_ips = [self._ip_mapping[ip] for ip in self._ip_to_hostname.keys()] @@ -119,7 +120,7 @@ def _generate_all_actions(self)-> list: self.logger.info(f"Created action mapping with {len(actions)} actions.") for action in actions: self.logger.debug(action) - self._all_actions = actions + return actions def _create_state_from_view(self, view, add_neighboring_nets = True): @@ -199,6 +200,6 @@ def _create_state_from_view(self, view, add_neighboring_nets = True): level=pass_level, ) - game_server = WhiteBoxNSGCoordinator(args.game_host, args.game_port, args.task_config, seed=args.seed) + game_server = WhiteBoxNetSecGame(args.game_host, args.game_port, args.task_config, seed=args.seed) # Run it! game_server.run() \ No newline at end of file diff --git a/AIDojoCoordinator/utils/__init__.py b/netsecgame/game/worlds/__init__.py similarity index 100% rename from AIDojoCoordinator/utils/__init__.py rename to netsecgame/game/worlds/__init__.py diff --git a/AIDojoCoordinator/game_components.py b/netsecgame/game_components.py similarity index 73% rename from AIDojoCoordinator/game_components.py rename to netsecgame/game_components.py index 5ec6a505..4b3d3ec1 100755 --- a/AIDojoCoordinator/game_components.py +++ b/netsecgame/game_components.py @@ -1,9 +1,9 @@ +from __future__ import annotations # Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz # Library of helpful functions and objects to play the net sec game from dataclasses import dataclass, field, asdict -from typing import Dict, Any +from typing import Dict, Any, List, Set, Tuple, NamedTuple import dataclasses -from collections import namedtuple import json import enum import sys @@ -11,17 +11,16 @@ import ipaddress import ast - -@dataclass(frozen=True, eq=True, order=True) +@dataclass(frozen=True, eq=True, order=True, slots=True) class Service(): """ Represents a service in the NetSecGame. Attributes: name (str): Name of the service. - type (str): Type of the service. Default `uknown` - version (str): Version of the service. Default `uknown` - is_local (bool): Whether the service is local. Default True + type (str): Type of the service. Default `unknown`. + version (str): Version of the service. Default `unknown`. + is_local (bool): Whether the service is local. Default True. """ name: str type: str = "unknown" @@ -29,12 +28,12 @@ class Service(): is_local: bool = True @classmethod - def from_dict(cls, data: dict)->"Service": + def from_dict(cls, data: Dict[str, Any]) -> Service: """ Create a Service object from a dictionary. Args: - data (dict): Dictionary with service attributes. + data (Dict[str, Any]): Dictionary with service attributes. Returns: Service: The created Service object. @@ -42,7 +41,7 @@ def from_dict(cls, data: dict)->"Service": return cls(**data) -@dataclass(frozen=True, eq=True, order=True) +@dataclass(frozen=True, eq=True, order=True, slots=True) class IP(): """ Immutable object representing an IPv4 address in the NetSecGame. @@ -64,7 +63,7 @@ def __post_init__(self): except ValueError: raise ValueError(f"Invalid IP address provided: {self.ip}") - def __repr__(self)->str: + def __repr__(self) -> str: """ Return the string representation of the IP. @@ -73,26 +72,26 @@ def __repr__(self)->str: """ return self.ip - def __eq__(self, other)->bool: + def __eq__(self, other: object) -> bool: """ Check equality with another IP object. Args: - other (IP): Another IP object. + other (object): Another object to compare with. Returns: - is_equal: True if equal, False otherwise. + bool: True if equal, False otherwise. """ if not isinstance(other, IP): return NotImplemented return self.ip == other.ip - def is_private(self)->bool: + def is_private(self) -> bool: """ Check if the IP address is private. Uses ipaddress module. Returns: - is_private: True if the IP is private, False otherwise. + bool: True if the IP is private, False otherwise. """ try: return ipaddress.IPv4Network(self.ip).is_private @@ -104,12 +103,12 @@ def is_private(self)->bool: return False @classmethod - def from_dict(cls, data: dict)->"IP": + def from_dict(cls, data: Dict[str, Any]) -> IP: """ Build the IP object from a dictionary representation. Args: - data (dict): Dictionary with IP attributes. + data (Dict[str, Any]): Dictionary with IP attributes. Returns: IP: The created IP object. @@ -125,7 +124,7 @@ def __hash__(self)->int: """ return hash(self.ip) -@dataclass(frozen=True, eq=True) +@dataclass(frozen=True, eq=True, slots=True) class Network(): """ Immutable object representing an IPv4 network in the NetSecGame. @@ -137,7 +136,7 @@ class Network(): ip: str mask: int - def __repr__(self)->str: + def __repr__(self) -> str: """ Return the string representation of the network. @@ -146,7 +145,7 @@ def __repr__(self)->str: """ return f"{self.ip}/{self.mask}" - def __str__(self)->str: + def __str__(self) -> str: """ Return the string representation of the network. @@ -155,7 +154,7 @@ def __str__(self)->str: """ return f"{self.ip}/{self.mask}" - def __lt__(self, other)->bool: + def __lt__(self, other: Network) -> bool: """ Less-than comparison for networks. @@ -167,10 +166,10 @@ def __lt__(self, other)->bool: """ try: return netaddr.IPNetwork(str(self)) < netaddr.IPNetwork(str(other)) - except netaddr.core.AddrFormatError: + except netaddr.AddrFormatError: return str(self.ip) < str(other.ip) - def __le__(self, other)->bool: + def __le__(self, other: Network) -> bool: """ Less-than-or-equal comparison for networks. @@ -182,10 +181,10 @@ def __le__(self, other)->bool: """ try: return netaddr.IPNetwork(str(self)) <= netaddr.IPNetwork(str(other)) - except netaddr.core.AddrFormatError: + except netaddr.AddrFormatError: return str(self.ip) <= str(other.ip) - def __gt__(self, other)->bool: + def __gt__(self, other: Network) -> bool: """ Greater-than comparison for networks. @@ -197,10 +196,10 @@ def __gt__(self, other)->bool: """ try: return netaddr.IPNetwork(str(self)) > netaddr.IPNetwork(str(other)) - except netaddr.core.AddrFormatError: + except netaddr.AddrFormatError: return str(self.ip) > str(other.ip) - def is_private(self)->bool: + def is_private(self) -> bool: """ Check if the network is private. Uses ipaddress module. @@ -214,19 +213,20 @@ def is_private(self)->bool: return True @classmethod - def from_dict(cls, data: dict)->"Network": + def from_dict(cls, data: Dict[str, Any]) -> Network: """ Build the Network object from a dictionary. Args: - data (dict): Dictionary with network attributes. + data (Dict[str, Any]): Dictionary with network attributes. Returns: Network: The created Network object. """ return cls(**data) -@dataclass(frozen=True, eq=True, order=True) + +@dataclass(frozen=True, eq=True, order=True, slots=True) class Data(): """ Represents a data object in the NetSecGame. @@ -252,13 +252,14 @@ def __hash__(self) -> int: int: The hash value. """ return hash((self.owner, self.id, self.type)) + @classmethod - def from_dict(cls, data: dict)->"Data": + def from_dict(cls, data: Dict[str, Any]) -> Data: """ Build the Data object from a dictionary. Args: - data (dict): Dictionary with data attributes. + data (Dict[str, Any]): Dictionary with data attributes. Returns: Data: The created Data object. @@ -280,7 +281,7 @@ class ActionType(enum.Enum): QuitGame = "QuitGame" ResetGame = "ResetGame" - def to_string(self)->str: + def to_string(self) -> str: """ Convert the ActionType enum to string. @@ -289,12 +290,12 @@ def to_string(self)->str: """ return self.value - def __eq__(self, other)->bool: + def __eq__(self, other: object) -> bool: """ Compare ActionType with another ActionType or string. Args: - other (ActionType or str): The object to compare. + other (object): The object to compare. Returns: bool: True if equal, False otherwise. @@ -307,7 +308,7 @@ def __eq__(self, other)->bool: return self.value == other.replace("ActionType.", "") return False - def __hash__(self)->int: + def __hash__(self) -> int: """ Compute the hash of the ActionType. @@ -318,7 +319,7 @@ def __hash__(self)->int: return hash(self.value) @classmethod - def from_string(cls, name)->"ActionType": + def from_string(cls, name: str) -> ActionType: """ Convert a string to an ActionType enum. Strips 'ActionType.' if present. @@ -350,7 +351,7 @@ class AgentInfo(): name: str role: str - def __repr__(self)->str: + def __repr__(self) -> str: """ Return the string representation of the AgentInfo. @@ -361,19 +362,22 @@ def __repr__(self)->str: @classmethod - def from_dict(cls, data: dict)->"AgentInfo": + def from_dict(cls, data: Dict[str, Any]) -> AgentInfo: """ Build the AgentInfo object from a dictionary. Args: - data (dict): Dictionary with agent info attributes. + data (Dict[str, Any]): Dictionary with agent info attributes. Returns: AgentInfo: The created AgentInfo object. """ - return cls(**data) + if isinstance(data, str): + data = ast.literal_eval(data) + processed = {"name": data["name"], "role": AgentRole.from_string(data["role"])} + return cls(**processed) -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Action: """ Immutable dataclass representing an Action. @@ -395,16 +399,20 @@ def as_dict(self) -> Dict[str, Any]: """ params = {} for k, v in self.parameters.items(): - if hasattr(v, '__dict__'): # Handle custom objects like Service, Data, AgentInfo + # Check if v is a dataclass AND ensuring it is an instance, not the class itself + if dataclasses.is_dataclass(v) and not isinstance(v, type): params[k] = asdict(v) + elif isinstance(v, dict): + params[k] = v elif isinstance(v, bool): # Handle boolean values params[k] = v else: params[k] = str(v) return {"action_type": str(self.action_type), "parameters": params} + @property - def type(self)->ActionType: + def type(self) -> ActionType: """ Return the action type. @@ -423,7 +431,7 @@ def to_json(self) -> str: return json.dumps(self.as_dict) @classmethod - def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": + def from_dict(cls, data_dict: Dict[str, Any]) -> Action: """ Create an Action from a dictionary. @@ -437,7 +445,7 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": ValueError: If an unsupported parameter is encountered. """ action_type = ActionType.from_string(data_dict["action_type"]) - params = {} + params: Dict[str, Any] = {} for k, v in data_dict["parameters"].items(): match k: case "source_host" | "target_host" | "blocked_host": @@ -460,7 +468,7 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": return cls(action_type=action_type, parameters=params) @classmethod - def from_json(cls, json_string: str) -> "Action": + def from_json(cls, json_string: str) -> Action: """ Create an Action from a JSON string. @@ -519,33 +527,34 @@ def __hash__(self) -> int: sorted_params = tuple(sorted((k, hash(v)) for k, v in self.parameters.items())) return hash((self.action_type, sorted_params)) -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class GameState(): """ Represents the state of the game. Attributes: - controlled_hosts (set): Controlled hosts. - known_hosts (set): Known hosts. - known_services (dict): Known services. - known_data (dict): Known data. - known_networks (set): Known networks. - known_blocks (dict): Known blocks. + controlled_hosts (Set[IP]): Controlled hosts. + known_hosts (Set[IP]): Known hosts. + known_services (Dict[IP, Set[Service]]): Known services. + known_data (Dict[IP, Set[Data]]): Known data. + known_networks (Set[Network]): Known networks. + known_blocks (Dict[IP, Set[IP]]): Known blocks. """ - controlled_hosts: set = field(default_factory=set, hash=True) - known_hosts: set = field(default_factory=set, hash=True) - known_services: dict = field(default_factory=dict, hash=True) - known_data: dict = field(default_factory=dict, hash=True) - known_networks: set = field(default_factory=set, hash=True) - known_blocks: dict = field(default_factory=dict, hash=True) + controlled_hosts: Set[IP] = field(default_factory=set, hash=True) + known_hosts: Set[IP] = field(default_factory=set, hash=True) + known_services: Dict[IP, Set[Service]] = field(default_factory=dict, hash=True) + known_data: Dict[IP, Set[Data]] = field(default_factory=dict, hash=True) + known_networks: Set[Network] = field(default_factory=set, hash=True) + known_blocks: Dict[IP, Set[IP]] = field(default_factory=dict, hash=True) @property - def as_graph(self)->tuple: + def as_graph(self) -> Tuple[List[int], List[int], List[Tuple[int, int]], Dict[Any, int]]: """ Build a graph representation of the game state. Returns: - tuple: (node_features, controlled, edges, node_index_map) + Tuple[List[int], List[int], List[Tuple[int, int]], Dict[Any, int]]: + (node_features, controlled, edges, node_index_map) """ node_types = {"network":0, "host":1, "service":2, "datapoint":3, "blocks": 4} graph_nodes = {} @@ -573,7 +582,7 @@ def as_graph(self)->tuple: if str(host) in netaddr.IPNetwork(str(net)): edges.append((graph_nodes[net], graph_nodes[host])) edges.append((graph_nodes[host], graph_nodes[net])) - except netaddr.core.AddrFormatError as error: + except netaddr.AddrFormatError as error: print(host, self.known_networks, self.known_hosts) print("Error:") print(error) @@ -627,35 +636,44 @@ def as_json(self) -> str: return json.dumps(ret_dict) @property - def as_dict(self)->dict: + def as_dict(self) -> Dict[str, Any]: """ Return the dictionary representation of the GameState. Returns: - dict: The game state as a dictionary. + Dict[str, Any]: The game state as a dictionary. """ - ret_dict = {"known_networks":[dataclasses.asdict(x) for x in self.known_networks], + ret_dict = { + "known_networks":[dataclasses.asdict(x) for x in self.known_networks], "known_hosts":[dataclasses.asdict(x) for x in self.known_hosts], "controlled_hosts":[dataclasses.asdict(x) for x in self.controlled_hosts], "known_services": {str(host):[dataclasses.asdict(s) for s in services] for host,services in self.known_services.items()}, "known_data":{str(host):[dataclasses.asdict(d) for d in data] for host,data in self.known_data.items()}, "known_blocks":{str(target_host):[dataclasses.asdict(blocked_host) for blocked_host in blocked_hosts] for target_host, blocked_hosts in self.known_blocks.items()} - } + } return ret_dict @classmethod - def from_dict(cls, data_dict:dict)->"GameState": + def from_dict(cls, data_dict: Dict[str, Any]) -> GameState: """ Create a GameState from a dictionary. Args: - data_dict (dict): The game state as a dictionary. + data_dict (Dict[str, Any]): The game state as a dictionary. Returns: GameState: The created GameState object. """ if "known_blocks" in data_dict: - known_blocks = {IP(target_host):{IP(blocked_host["ip"]) for blocked_host in blocked_hosts} for target_host, blocked_hosts in data_dict["known_blocks"].items()} + known_blocks = {} + for target_host, blocked_hosts in data_dict["known_blocks"].items(): + blocked_ips = set() + for blocked_host in blocked_hosts: + ip_val = blocked_host["ip"] + if isinstance(ip_val, dict): + ip_val = ip_val["ip"] + blocked_ips.add(IP(ip_val)) + known_blocks[IP(target_host)] = blocked_ips else: known_blocks = {} state = GameState( @@ -670,7 +688,7 @@ def from_dict(cls, data_dict:dict)->"GameState": return state @classmethod - def from_json(cls, json_string)->"GameState": + def from_json(cls, json_string: str) -> GameState: """ Create a GameState from a JSON string. @@ -680,29 +698,24 @@ def from_json(cls, json_string)->"GameState": Returns: GameState: The created GameState object. """ - json_data = json.loads(json_string) - state = GameState( - known_networks = {Network(x["ip"], x["mask"]) for x in json_data["known_networks"]}, - known_hosts = {IP(x["ip"]) for x in json_data["known_hosts"]}, - controlled_hosts = {IP(x["ip"]) for x in json_data["controlled_hosts"]}, - known_services = {IP(k):{Service(s["name"], s["type"], s["version"], s["is_local"]) - for s in services} for k,services in json_data["known_services"].items()}, - known_data = {IP(k):{Data(v["owner"], v["id"], v["size"], v["type"], v["content"]) for v in values} for k,values in json_data["known_data"].items()}, - known_blocks = {IP(target_host):{IP(blocked_host) for blocked_host in blocked_hosts} for target_host, blocked_hosts in json_data["known_blocks"].items()} - ) - return state + data_dict = json.loads(json_string) + return cls.from_dict(data_dict) -# Observation - given to agent after taking an action -""" -Observations are given when making a step in the environment. - - observation: current state of the environment - - reward: float value with immediate reward for last step - - end: boolean, True if the game ended. - No further interaction is possible (either terminal state or because of timeout) - - info: dict, can contain additional information about the reason for ending -""" -Observation = namedtuple("Observation", ["state", "reward", "end", "info"]) +class Observation(NamedTuple): + """ + Observations are given when making a step in the environment. + + Attributes: + state (GameState): Current state of the environment. + reward (float): Value with immediate reward for last step. + end (bool): True if the game ended. + info (Dict[str, Any]): Dictionary with additional information about the reason for ending. + """ + state: GameState + reward: float + end: bool + info: Dict[str, Any] @enum.unique class GameStatus(enum.Enum): @@ -717,7 +730,7 @@ class GameStatus(enum.Enum): FORBIDDEN = 403 @classmethod - def from_string(cls, string:str)->"GameStatus": + def from_string(cls, string: str) -> GameStatus: """ Convert a string to a GameStatus enum. @@ -740,6 +753,7 @@ def from_string(cls, string:str)->"GameStatus": return GameStatus.RESET_DONE case _: raise ValueError(f"Invalid GameStatus string: {string}") + def __repr__(self) -> str: """ Return the string representation of the GameStatus. @@ -762,7 +776,7 @@ class AgentStatus(enum.Enum): Success = "Success" Fail = "Fail" - def to_string(self)->str: + def to_string(self) -> str: """ Convert the AgentStatus enum to string. @@ -771,12 +785,12 @@ def to_string(self)->str: """ return self.value - def __eq__(self, other)->bool: + def __eq__(self, other: object) -> bool: """ Compare AgentStatus with another AgentStatus or string. Args: - other (AgentStatus or str): The object to compare. + other (object): The object to compare. Returns: bool: True if equal, False otherwise. @@ -789,7 +803,7 @@ def __eq__(self, other)->bool: return self.value == other.replace("AgentStatus.", "") return False - def __hash__(self)->int: + def __hash__(self) -> int: """ Compute the hash of the AgentStatus. @@ -800,7 +814,7 @@ def __hash__(self)->int: return hash(self.value) @classmethod - def from_string(cls, name)->"AgentStatus": + def from_string(cls, name: str) -> AgentStatus: """ Convert a string to an AgentStatus enum. @@ -820,6 +834,82 @@ def from_string(cls, name)->"AgentStatus": except KeyError: raise ValueError(f"Invalid AgentStatus: {name}") +@enum.unique +class AgentRole(str, enum.Enum): + """ + Enum representing possible roles of agents. + """ + Attacker = "Attacker" + Defender = "Defender" + Benign = "Benign" + + def __repr__(self) -> str: + """ + Return the string representation of the AgentRole. + + Returns: + str: The agent role as a string. + """ + return self.value + + def to_string(self) -> str: + """ + Convert the AgentRole enum to string. + + Returns: + str: The string representation. + """ + return self.value + + def __eq__(self, other: object) -> bool: + """ + Compare AgentRole with another AgentRole or string. + + Args: + other (object): The object to compare. + + Returns: + bool: True if equal, False otherwise. + """ + if isinstance(other, AgentRole): + return self.value == other.value + elif isinstance(other, str): + return self.value.lower() == other.lower().replace("agentrole.", "") + return False + + def __hash__(self) -> int: + """ + Compute the hash of the AgentRole. + + Returns: + int: The hash value. + """ + return hash(self.value) + + @classmethod + def from_string(cls, name: str) -> AgentRole: + """ + Convert a string to an AgentRole enum. + + Args: + name (str): The string representation. + + Returns: + AgentRole: The corresponding AgentRole. + + Raises: + ValueError: If the string does not match any AgentRole. + """ + # Clean up input string + name = name.split(".")[-1] # Remove prefix if present + + # Try case-insensitive matching + for role in cls: + if role.value.lower() == name.lower(): + return role + + raise ValueError(f"Invalid AgentRole: {name}") + @dataclass(frozen=True) class ProtocolConfig: """ @@ -829,5 +919,13 @@ class ProtocolConfig: END_OF_MESSAGE (bytes): End-of-message marker. BUFFER_SIZE (int): Buffer size for messages. """ - END_OF_MESSAGE = b"EOF" - BUFFER_SIZE = 8192 \ No newline at end of file + END_OF_MESSAGE: bytes = b"EOF" + BUFFER_SIZE: int = 8192 + +if __name__ == "__main__": + role_str = AgentRole.Attacker.to_string() + role = AgentRole.from_string(role_str) + action = Action(ActionType.JoinGame, parameters={"agent_info": {"role": role, "name": "TestAgent"}}) + print(action) + print(action.to_json()) + print(action.from_json(action.to_json())) \ No newline at end of file diff --git a/AIDojoCoordinator/utils/action_plots.r b/netsecgame/utils/action_plots.r similarity index 100% rename from AIDojoCoordinator/utils/action_plots.r rename to netsecgame/utils/action_plots.r diff --git a/AIDojoCoordinator/utils/action_plots_readme.md b/netsecgame/utils/action_plots_readme.md similarity index 100% rename from AIDojoCoordinator/utils/action_plots_readme.md rename to netsecgame/utils/action_plots_readme.md diff --git a/AIDojoCoordinator/utils/actions_parser.py b/netsecgame/utils/actions_parser.py similarity index 100% rename from AIDojoCoordinator/utils/actions_parser.py rename to netsecgame/utils/actions_parser.py diff --git a/AIDojoCoordinator/utils/aidojo_log_colorizer.py b/netsecgame/utils/aidojo_log_colorizer.py similarity index 100% rename from AIDojoCoordinator/utils/aidojo_log_colorizer.py rename to netsecgame/utils/aidojo_log_colorizer.py diff --git a/AIDojoCoordinator/utils/gamaplay_graphs.py b/netsecgame/utils/gamaplay_graphs.py similarity index 99% rename from AIDojoCoordinator/utils/gamaplay_graphs.py rename to netsecgame/utils/gamaplay_graphs.py index a2f6f80e..3f862bed 100644 --- a/AIDojoCoordinator/utils/gamaplay_graphs.py +++ b/netsecgame/utils/gamaplay_graphs.py @@ -5,7 +5,7 @@ import argparse import matplotlib.pyplot as plt -from AIDojoCoordinator.game_components import GameState, Action +from netsecgame.game_components import GameState, Action class TrajectoryGraph: def __init__(self)->None: diff --git a/AIDojoCoordinator/utils/log_parser.py b/netsecgame/utils/log_parser.py similarity index 100% rename from AIDojoCoordinator/utils/log_parser.py rename to netsecgame/utils/log_parser.py diff --git a/AIDojoCoordinator/utils/trajectory_analysis.py b/netsecgame/utils/trajectory_analysis.py similarity index 99% rename from AIDojoCoordinator/utils/trajectory_analysis.py rename to netsecgame/utils/trajectory_analysis.py index 1bed8d3d..673a1042 100644 --- a/AIDojoCoordinator/utils/trajectory_analysis.py +++ b/netsecgame/utils/trajectory_analysis.py @@ -9,7 +9,7 @@ import plotly.graph_objects as go from sklearn.preprocessing import StandardScaler -from AIDojoCoordinator.game_components import GameState, Action, ActionType +from netsecgame.game_components import GameState, Action, ActionType diff --git a/netsecgame/utils/trajectory_recorder.py b/netsecgame/utils/trajectory_recorder.py new file mode 100644 index 00000000..bf0b4231 --- /dev/null +++ b/netsecgame/utils/trajectory_recorder.py @@ -0,0 +1,82 @@ +# trajectory_recorder.py +# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import os +import logging +from datetime import datetime +from typing import Optional, Dict, Any +from netsecgame.game_components import Action, GameState +from netsecgame.utils.utils import store_trajectories_to_jsonl + +class TrajectoryRecorder: + """ + Manages the recording and storage of agent trajectories. + """ + def __init__(self, agent_name: str, agent_role: str): + self.agent_name = agent_name + self.agent_role = agent_role + self.logger = logging.getLogger(f"TrajectoryRecorder-{agent_name}") + self._data: Dict[str, Any] = {} + self.reset() + + def reset(self) -> None: + """ + Resets the trajectory data for a new episode. + """ + self.logger.debug(f"Resetting trajectory for {self.agent_name}") + self._data = { + "trajectory": { + "states": [], + "actions": [], + "rewards": [], + }, + "end_reason": None, + "agent_role": self.agent_role, + "agent_name": self.agent_name + } + + def add_step(self, action: Action, reward: float, next_state: GameState, end_reason: Optional[str] = None) -> None: + """ + Adds a single step to the trajectory. + + Args: + action (Action): The action taken. + reward (float): The reward received. + next_state (GameState): The resulting state. + end_reason (Optional[str]): Reason for episode end, if applicable. + """ + self.logger.debug(f"Adding step to trajectory for {self.agent_name}") + # Assuming Action and GameState have .as_dict property or method as in original code + # In original code: action.as_dict, next_state.as_dict + self._data["trajectory"]["actions"].append(action.as_dict) + self._data["trajectory"]["rewards"].append(reward) + self._data["trajectory"]["states"].append(next_state.as_dict) + + if end_reason: + self._data["end_reason"] = end_reason + + def add_initial_state(self, state: GameState) -> None: + """ + Adds the initial state to the trajectory (optional, depending on how you want to track s_0). + The original code initialized trajectory with states=[agent_state.as_dict]. + """ + self._data["trajectory"]["states"].append(state.as_dict) + + def get_trajectory(self) -> Dict[str, Any]: + """ + Returns the current trajectory data. + """ + return self._data + + def save_to_file(self, location: str = "./logs/trajectories") -> None: + """ + Saves the recorded trajectory to a JSONL file. + + Args: + location (str): Directory to save the file. + """ + filename = f"{datetime.now():%Y-%m-%d}_{self.agent_name}_{self.agent_role}" + try: + store_trajectories_to_jsonl(self._data, location, filename) + self.logger.info(f"Trajectory stored in {os.path.join(location, filename)}.jsonl") + except Exception as e: + self.logger.error(f"Failed to store trajectory: {e}") diff --git a/netsecgame/utils/utils.py b/netsecgame/utils/utils.py new file mode 100644 index 00000000..e0190eee --- /dev/null +++ b/netsecgame/utils/utils.py @@ -0,0 +1,314 @@ +# Utility functions for then env and for the agents +# Author: Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz +# Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz +# --- Standard Library Imports --- +import csv +import hashlib +import json +import logging +import os +from typing import Optional + +# --- Third-Party Imports --- +import jsonlines + +# --- Local Imports --- +from netsecgame.game_components import ( + Action, + ActionType, + Data, + GameState, + IP, + Network, + Observation, + Service, +) + +def get_file_hash(filepath, hash_func='sha256', chunk_size=4096): + """ + Computes hash of a given file. + Args: + filepath (str): The path to the file to hash. + hash_func (str): The hash function to use (default is 'sha256'). + chunk_size (int): The size of each chunk to read from the file (default is 4096 bytes). + Returns: + str: The hexadecimal hash of the file. + """ + hash_algorithm = hashlib.new(hash_func) + with open(filepath, 'rb') as file: + chunk = file.read(chunk_size) + while chunk: + hash_algorithm.update(chunk) + chunk = file.read(chunk_size) + return hash_algorithm.hexdigest() + +def get_str_hash(string, hash_func='sha256'): + """ + Computes hash of a given string. + Args: + string (str): The input string to hash. + hash_func (str): The hash function to use (default is 'sha256'). + Returns: + str: The hexadecimal hash of the input string. + """ + hash_algorithm = hashlib.new(hash_func) + hash_algorithm.update(string.encode('utf-8')) + return hash_algorithm.hexdigest() + +def read_replay_buffer_from_csv(csvfile:str)->list: + """ + Function to read steps from a CSV file + and restore the objects in the replay buffer. + + expected colums in the csv: + state_t0, action_t0, reward_t1, state_t1, done_t1 + """ + raise DeprecationWarning("This function is deprecated and will be removed in future versions.") + buffer = [] + try: + with open(csvfile, 'r') as f_object: + csv_reader = csv.reader(f_object, delimiter=';') + for [s_t, a_t, r, s_t1 , done] in csv_reader: + buffer.append((GameState.from_json(s_t), Action.from_json(a_t), r, GameState.from_json(s_t1), done)) + except FileNotFoundError: + # There was no buffer + pass + return buffer + +def store_replay_buffer_in_csv(replay_buffer:list, filename:str, delimiter:str=";")->None: + """ + Function to store steps from a replay buffer in CSV file. + Expected format of replay buffer items: + (state_t0:GameState, action_t0:Action, reward_t1:float, state_t1:GameState, done_t1:bool) + """ + raise DeprecationWarning("This function is deprecated and will be removed in future versions.") + with open(filename, 'a') as f_object: + writer_object = csv.writer(f_object, delimiter=delimiter) + for (s_t, a_t, r, s_t1, done) in replay_buffer: + writer_object.writerow([s_t.as_json(), a_t.as_json(), r, s_t1.as_json(), done]) + +def state_as_ordered_string(state:GameState)->str: + ret = "" + ret += f"nets:[{','.join([str(x) for x in sorted(state.known_networks)])}]," + ret += f"hosts:[{','.join([str(x) for x in sorted(state.known_hosts)])}]," + ret += f"controlled:[{','.join([str(x) for x in sorted(state.controlled_hosts)])}]," + ret += "services:{" + for host in sorted(state.known_services.keys()): + ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_services[host])])}]" + ret += "},data:{" + for host in sorted(state.known_data.keys()): + ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_data[host])])}]" + ret += "}, blocks:{" + for host in sorted(state.known_blocks.keys()): + ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_blocks[host])])}]" + ret += "}" + return ret + +def observation_as_dict(observation: Observation) -> dict: + """ + Generates dict representation of a given Observation object. + Acts as the single source of truth for the structure. + """ + return { + 'state': observation.state.as_dict, + 'reward': observation.reward, + 'end': observation.end, + # Using dict() ensures safety if info is a namedtuple or other mapping + 'info': dict(observation.info) + } + +def observation_to_str(observation: Observation) -> str: + """ + Generates JSON string representation of a given Observation object. + Relies on observation_as_dict to define the structure. + """ + try: + # Clean JSON structure: {"state": {...}, "reward": 0, ...} + # No more escaped JSON strings inside the JSON. + return json.dumps(observation_as_dict(observation)) + except Exception as e: + logging.getLogger(__name__).error(f"Error in encoding observation '{observation}' to JSON string: {e}") + raise e + +def observation_from_dict(data: dict) -> Observation: + """ + Reconstructs an Observation object from a dictionary representation. + + Args: + data (dict): The dictionary containing observation data. + + Returns: + Observation: The reconstructed Observation namedtuple. + """ + try: + # Since we refactored serialization, 'state' is now a dictionary + state_data = data.get("state") + + # Robustness check: Ensure we have a dict before converting + if isinstance(state_data, dict): + state = GameState.from_dict(state_data) + else: + raise ValueError(f"Expected dictionary for 'state', got {type(state_data)}") + + return Observation( + state=state, + reward=float(data.get("reward", 0.0)), + end=bool(data.get("end", False)), + info=data.get("info", {}) + ) + except Exception as e: + logging.getLogger(__name__).error(f"Error in creating Observation from dict: {e}") + raise e + +def observation_from_str(json_str: str) -> Observation: + """ + Reconstructs an Observation object from a JSON string representation. + + Args: + json_str (str): The JSON string representation of the observation. + + Returns: + Observation: The reconstructed Observation namedtuple. + """ + try: + # 1. Parse the main JSON string -> returns a dict + data = json.loads(json_str) + + # 2. Pass that dict to our existing from_dict method + # This keeps the logic DRY (Don't Repeat Yourself) + return observation_from_dict(data) + + except Exception as e: + logging.getLogger(__name__).error(f"Error in creating Observation from string: {e}") + raise e + +def parse_log_content(log_content:str)->Optional[list]: + try: + logs = [] + data = json.loads(log_content) + for item in data: + ip = IP(item["source_host"]) + action_type = ActionType.from_string(item["action_type"]) + logs.append({"source_host":ip, "action_type":action_type}) + return logs + except json.JSONDecodeError as e: + logging.getLogger(__name__).error(f"Error decoding JSON: {e}") + return None + except TypeError as e: + logging.getLogger(__name__).error(f"Error decoding JSON: {e}") + return None + +def get_logging_level(debug_level): + """ + Configure logging level based on the provided debug_level string. + """ + log_levels = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL + } + + level = log_levels.get(debug_level.upper(), logging.ERROR) + return level + +def store_trajectories_to_jsonl(trajectories:list, dir:str, filename:str)->None: + """ + Store trajectories to a JSONL file. + Args: + trajectories (list): List of trajectory data to store. + dir (str): Directory where the file will be stored. + filename (str): Name of the file (without extension). + """ + # make sure the directory exists + if not os.path.exists(dir): + os.makedirs(dir) + # construct the full file name + filename = os.path.join(dir, f"{filename.rstrip('jsonl')}.jsonl") + # store the trajectories + with jsonlines.open(filename, "a") as writer: + writer.write(trajectories) + +def read_trajectories_from_jsonl(filepath:str)->list: + """ + Read trajectories from a JSONL file. + Args: + filepath (str): Path to the JSONL file. + Returns: + list: List of trajectories read from the file. + """ + raise NotImplementedError("This function is not yet implemented.") + +def generate_valid_actions(state: GameState, include_blocks=False)->list: + """Function that generates a list of all valid actions in a given GameState + Args: + state (GameState): The current game state. + include_blocks (bool): Whether to include BlockIP actions. Defaults to False. + Returns: + list: A list of valid Action objects. + """ + valid_actions = set() + def is_fw_blocked(state, src_ip, dst_ip)->bool: + blocked = False + try: + blocked = dst_ip in state.known_blocks[src_ip] + except KeyError: + pass #this src ip has no known blocks + return blocked + + for source_host in state.controlled_hosts: + #Network Scans + for network in state.known_networks: + # TODO ADD neighbouring networks + valid_actions.add(Action(ActionType.ScanNetwork, parameters={"target_network": network, "source_host": source_host,})) + + # Service Scans + for blocked_host in state.known_hosts: + if not is_fw_blocked(state, source_host, blocked_host): + valid_actions.add(Action(ActionType.FindServices, parameters={"target_host": blocked_host, "source_host": source_host,})) + + # Service Exploits + for blocked_host, service_list in state.known_services.items(): + if not is_fw_blocked(state, source_host,blocked_host): + for service in service_list: + valid_actions.add(Action(ActionType.ExploitService, parameters={"target_host": blocked_host,"target_service": service,"source_host": source_host,})) + # Data Scans + for blocked_host in state.controlled_hosts: + if not is_fw_blocked(state, source_host,blocked_host): + valid_actions.add(Action(ActionType.FindData, parameters={"target_host": blocked_host, "source_host": blocked_host})) + + # Data Exfiltration + for source_host, data_list in state.known_data.items(): + for data in data_list: + for trg_host in state.controlled_hosts: + if trg_host != source_host: + if not is_fw_blocked(state, source_host,trg_host): + valid_actions.add(Action(ActionType.ExfiltrateData, parameters={"target_host": trg_host, "source_host": source_host, "data": data})) + + # BlockIP + if include_blocks: + for source_host in state.controlled_hosts: + for target_host in state.controlled_hosts: + if not is_fw_blocked(state, source_host,target_host): + for blocked_ip in state.known_hosts: + valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":source_host, "blocked_host":blocked_ip})) + return list(valid_actions) + +if __name__ == "__main__": + state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, + known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, + known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, + known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, + IP("192.168.1.2"):{Data("McGiver", "data2")}}) + + print(state_as_ordered_string(state)) + obs = Observation(state=state, reward=10.0, end=False, info={"info1":"value1"}) + obs_str = observation_to_str(obs) + print(obs_str) + obs_restored = observation_from_str(obs_str) + print(obs_restored) + print(observation_as_dict(obs_restored)) + actions = generate_valid_actions(state, include_blocks=True) + for action in actions: + print(action) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 318382db..1761c00a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,55 +3,47 @@ requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "AIDojoGameCoordinator" -version = "0.1.0" +name = "netsecgame" +dynamic = ["version"] description = "A package for coordinating AI-driven network simulation games." readme = "README.md" -license = { file = "LICENSE" } +license = "GPL-3.0-or-later" +license-files = ["LICENSE"] authors = [ { name = "Ondrej Lukas", email = "ondrej.lukas@aic.fel.cvut.cz" }, { name = "Sebastian Garcia", email = "sebastian.garcia@agents.fel.cvut.cz" }, { name = "Maria Rigaki", email = "maria.rigaki@aic.fel.cvut.cz" } ] +# light-weight version (allows running and development of agents) dependencies = [ - "aiohttp==3.11.8", - "attrs==23.2.0", - "beartype==0.19.0", - "cachetools==5.5.0", - "casefy==0.1.7", - "cyst==0.3.4", - "dictionaries==0.0.2", - "Faker==23.2.1", - "Jinja2==3.1.4", - "jsonlines==4.0.0", - "jsonpickle==3.3.0", - "kaleido==0.2.1", - "MarkupSafe==3.0.2", - "matplotlib==3.9.1", + "jsonlines>=4.0.0", "netaddr==0.9.0", - "networkx==3.4.2", - "numpy==1.26.4", - "pandas==2.2.2", - "plotly==5.22.0", - "pyserde==0.21.0", - "python-dateutil==2.8.2", - "PyYAML==6.0.1", - "redis==3.5.3", - "requests==2.32.3", - "scikit-learn==1.5.1", - "scipy==1.14.0", - "tenacity==8.5.0", - "typing-inspect==0.9.0", - "typing_extensions==4.12.2", - "cyst-core>=0.5.0" +] +classifiers = [ + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ] requires-python = ">=3.12" [project.optional-dependencies] + +# dependencies allowing to run the game server and the simulation +server = [ + "aiohttp>=3.11", + "cyst-core>=0.5.0", + "Faker>=23.2", + "numpy>=1.26", + "PyYAML>=6.0", + "requests>=2.32", + "pyserde==0.21.0" +] + dev = [ + "netsecgame[server]", "pytest", "ruff", - "pytest-asyncio" + "pytest-asyncio", + "twine" ] docs = [ @@ -61,6 +53,9 @@ docs = [ "pymdown-extensions" ] +[tool.setuptools.dynamic] +version = { attr = "netsecgame.__version__" } + [project.urls] Homepage = "https://github.com/stratosphereips/NetSecGame" Repository = "https://github.com/stratosphereips/NetSecGame" @@ -69,7 +64,19 @@ Issues = "https://github.com/stratosphereips/NetSecGame/issues" [tool.setuptools.packages.find] where = ["."] -exclude = ["tests*"] +include = ["netsecgame*"] +exclude = [ + "tests*", + "notebooks*", + "site*", + "docs*", + "logs*", + "mkdocs.yml", + "Dockerfile", + "NetSecGameAgents*" +] + + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/tests/OLD_test_actions.py b/tests/OLD_test_actions.py index 443c80bf..281b674d 100644 --- a/tests/OLD_test_actions.py +++ b/tests/OLD_test_actions.py @@ -5,8 +5,8 @@ import sys from os import path sys.path.append( path.dirname(path.dirname( path.abspath(__file__) ) )) -from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment -import AIDojoCoordinator.game_components as components +from netsecgame.worlds.network_security_game import NetworkSecurityEnvironment +import netsecgame.game_components as components import pytest # Fixture are used to hold the current state and the environment diff --git a/tests/components/test_action.py b/tests/components/test_action.py index b1c75526..5fa036b0 100644 --- a/tests/components/test_action.py +++ b/tests/components/test_action.py @@ -1,7 +1,8 @@ # Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import json -from AIDojoCoordinator.game_components import Action, ActionType, IP, Network, Data, Service, AgentInfo +import pytest +from netsecgame.game_components import Action, ActionType, IP, Network, Data, Service, AgentInfo class TestComponentActionType: """ @@ -447,4 +448,46 @@ def test_action_to_dict_quit_game(self): new_action = Action.from_dict(action_dict) assert action == new_action assert action_dict["action_type"] == str(action.type) - assert len(action_dict["parameters"]) == 0 \ No newline at end of file + assert len(action_dict["parameters"]) == 0 + + def test_action_type_eq_unsupported(self): + """Test ActionType equality with unsupported type""" + assert (ActionType.FindData == 123) is False + + def test_action_type_from_string_invalid(self): + """Test ActionType.from_string with invalid string""" + with pytest.raises(ValueError): + ActionType.from_string("InvalidAction") + + def test_action_eq_unsupported(self): + """Test Action equality with unsupported type""" + action = Action(action_type=ActionType.FindData) + assert (action == "some_string") is False + + def test_action_from_dict_invalid_parameter(self): + """Test Action.from_dict with invalid parameter key""" + data = { + "action_type": "ActionType.FindData", + "parameters": {"unknown_param": "value"} + } + with pytest.raises(ValueError): + Action.from_dict(data) + + def test_action_to_dict_bool_parameter(self): + """Test handling of boolean parameters in as_dict""" + action = Action( + action_type=ActionType.ResetGame, + parameters={"request_trajectory": True} + ) + d = action.as_dict + assert d["parameters"]["request_trajectory"] is True + + def test_action_to_dict_str_parameter(self): + """Test handling of string parameters in as_dict""" + # Inject a parameter that is just a string (not a dataclass) + # We need a new ActionType or reuse one that accepts arbitrary params? + # The existing code mainly expects specific params. + # But we can force it for testing as_dict logic. + action = Action(ActionType.FindData, parameters={"simple_param": "simple_value"}) + d = action.as_dict + assert d["parameters"]["simple_param"] == "simple_value" \ No newline at end of file diff --git a/tests/components/test_data.py b/tests/components/test_data.py index b3946f94..2ed0498b 100644 --- a/tests/components/test_data.py +++ b/tests/components/test_data.py @@ -2,7 +2,7 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import pytest import dataclasses -from AIDojoCoordinator.game_components import Data +from netsecgame.game_components import Data @pytest.fixture def sample_data_minimal(): diff --git a/tests/components/test_enums.py b/tests/components/test_enums.py new file mode 100644 index 00000000..688057ed --- /dev/null +++ b/tests/components/test_enums.py @@ -0,0 +1,121 @@ +import pytest +from netsecgame.game_components import GameStatus, AgentStatus, Observation, GameState, ProtocolConfig, AgentRole +import json + +class TestGameStatus: + def test_from_string_valid(self): + """Test valid from_string conversions""" + assert GameStatus.from_string("GameStatus.OK") == GameStatus.OK + assert GameStatus.from_string("GameStatus.CREATED") == GameStatus.CREATED + assert GameStatus.from_string("GameStatus.RESET_DONE") == GameStatus.RESET_DONE + assert GameStatus.from_string("GameStatus.BAD_REQUEST") == GameStatus.BAD_REQUEST + assert GameStatus.from_string("GameStatus.FORBIDDEN") == GameStatus.FORBIDDEN + + def test_from_string_invalid(self): + """Test invalid from_string conversion""" + with pytest.raises(ValueError): + GameStatus.from_string("GameStatus.INVALID") + + def test_repr(self): + """Test string representation""" + assert str(GameStatus.OK) == "GameStatus.OK" + assert repr(GameStatus.OK) == "GameStatus.OK" + +class TestAgentStatus: + def test_to_string(self): + """Test to_string method""" + assert AgentStatus.Playing.to_string() == "Playing" + assert AgentStatus.Success.to_string() == "Success" + + def test_eq_string(self): + """Test equality with string""" + assert AgentStatus.Playing == "Playing" + assert AgentStatus.Playing == "AgentStatus.Playing" + assert (AgentStatus.Playing == "Other") is False + + def test_eq_self(self): + """Test equality with self""" + assert AgentStatus.Playing == AgentStatus.Playing + assert (AgentStatus.Playing == AgentStatus.Success) is False + + def test_eq_other(self): + """Test equality with other types""" + assert (AgentStatus.Playing == 123) is False + + def test_hash(self): + """Test hash consistency""" + assert hash(AgentStatus.Playing) == hash(AgentStatus.Playing.value) + + def test_from_string(self): + """Test from_string method""" + assert AgentStatus.from_string("AgentStatus.Playing") == AgentStatus.Playing + assert AgentStatus.from_string("Playing") == AgentStatus.Playing + + with pytest.raises(ValueError): + AgentStatus.from_string("Invalid") + +class TestObservation: + def test_creation(self): + """Test creation of Observation named tuple""" + state = GameState() + obs = Observation(state=state, reward=1.0, end=False, info={}) + + assert obs.state == state + assert obs.reward == 1.0 + assert obs.end is False + assert obs.info == {} + +class TestProtocolConfig: + def test_constants(self): + """Test protocol constants""" + conf = ProtocolConfig() + assert conf.END_OF_MESSAGE == b"EOF" + +class TestAgentRole: + def test_values(self): + """Test enum values""" + assert AgentRole.Attacker.value == "Attacker" + assert AgentRole.Defender.value == "Defender" + assert AgentRole.Benign.value == "Benign" + + def test_to_string(self): + """Test to_string method""" + assert AgentRole.Attacker.to_string() == "Attacker" + assert AgentRole.Defender.to_string() == "Defender" + + def test_from_string(self): + """Test from_string method""" + assert AgentRole.from_string("Attacker") == AgentRole.Attacker + assert AgentRole.from_string("attacker") == AgentRole.Attacker + assert AgentRole.from_string("AgentRole.Attacker") == AgentRole.Attacker + + with pytest.raises(ValueError): + AgentRole.from_string("InvalidRole") + + def test_equality(self): + """Test equality comparison""" + # Compare with Enum + assert AgentRole.Attacker == AgentRole.Attacker + assert AgentRole.Attacker != AgentRole.Defender + + # Compare with String + assert AgentRole.Attacker == "Attacker" + assert AgentRole.Attacker == "attacker" # Case insensitive + assert AgentRole.Attacker != "Defender" + + def test_hashability(self): + """Test usage as dictionary key""" + d = {AgentRole.Attacker: 1, AgentRole.Defender: 2} + assert d[AgentRole.Attacker] == 1 + assert d["Attacker"] == 1 # Matches string equivalent + assert d[AgentRole.Defender] == 2 + + def test_json_serialization(self): + """Test native JSON serialization""" + data = {"role": AgentRole.Attacker} + json_str = json.dumps(data) + assert json_str == '{"role": "Attacker"}' + + # Round trip + decoded = json.loads(json_str) + assert decoded["role"] == "Attacker" diff --git a/tests/components/test_game_state.py b/tests/components/test_game_state.py index 770efb43..6eb748a9 100644 --- a/tests/components/test_game_state.py +++ b/tests/components/test_game_state.py @@ -2,7 +2,7 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import json import pytest -from AIDojoCoordinator.game_components import GameState, IP, Network, Data, Service +from netsecgame.game_components import GameState, IP, Network, Data, Service # pytest fixtures for creating sample objects @pytest.fixture @@ -300,4 +300,85 @@ def test_game_state_from_dict(sample_ip, sample_ip2, sample_network, sample_serv game_dict = game_state.as_dict deserialized_state = GameState.from_dict(game_dict) assert game_state is not deserialized_state - assert game_state == deserialized_state \ No newline at end of file + assert game_state == deserialized_state + +def test_game_state_as_graph(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """Test as_graph method""" + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip: {sample_service}}, + known_data={sample_ip: {sample_data}}, + known_networks={sample_network} + ) + + node_features, controlled, edges, node_index_map = game_state.as_graph + + # Check basic structure + assert isinstance(node_features, list) + assert isinstance(controlled, list) + assert isinstance(edges, list) + assert isinstance(node_index_map, dict) + + # Check node types mapping: network:0, host:1, service:2, datapoint:3, blocks:4 + # We expect: 1 network, 2 hosts, 1 service, 1 data point = 5 nodes + assert len(node_features) == 5 + assert len(controlled) == 5 + assert len(node_index_map) == 5 + + # Check specific nodes are present + assert sample_network in node_index_map.values() + assert sample_ip in node_index_map.values() + assert sample_ip2 in node_index_map.values() + assert sample_service in node_index_map.values() + assert sample_data in node_index_map.values() + + # Invert map for index lookup + obj_to_idx = {v: k for k, v in node_index_map.items()} + + # Check controlled status + ip1_idx = obj_to_idx[sample_ip] + ip2_idx = obj_to_idx[sample_ip2] + assert controlled[ip1_idx] == 1 + assert controlled[ip2_idx] == 0 + + # Check edges + # Host IP1 should be connected to Network (192.168.1.1 in 192.168.1.0/24) + # Host IP2 should be connected to Network + # Service on IP1 should be connected to IP1 + # Data on IP1 should be connected to IP1 + + net_idx = obj_to_idx[sample_network] + svc_idx = obj_to_idx[sample_service] + data_idx = obj_to_idx[sample_data] + + # Check edge existence (undirected, so double entries) + edge_set = set(edges) + assert (net_idx, ip1_idx) in edge_set + assert (ip1_idx, net_idx) in edge_set + assert (net_idx, ip2_idx) in edge_set + assert (ip1_idx, svc_idx) in edge_set + assert (ip1_idx, data_idx) in edge_set + +def test_game_state_known_blocks(sample_ip, sample_ip2): + """Test known_blocks handling""" + # Create state with blocks + blocks = {sample_ip: {sample_ip2}} + game_state = GameState(known_blocks=blocks) + + assert game_state.known_blocks == blocks + + # Test to dict + d = game_state.as_dict + assert "known_blocks" in d + # Expect: {'192.168.1.1': [{'ip': '192.168.1.2'}]} + assert d["known_blocks"][str(sample_ip)][0]["ip"] == str(sample_ip2) + + # Test from dict + new_state = GameState.from_dict(d) + assert new_state == game_state + + # Test to/from json + j = game_state.as_json() + new_state_json = GameState.from_json(j) + assert new_state_json == game_state \ No newline at end of file diff --git a/tests/components/test_ip.py b/tests/components/test_ip.py index 58fc8870..b72dda5f 100644 --- a/tests/components/test_ip.py +++ b/tests/components/test_ip.py @@ -2,7 +2,7 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import pytest import dataclasses -from AIDojoCoordinator.game_components import IP +from netsecgame.game_components import IP # Pytest fixtures for creating sample IP objects @pytest.fixture @@ -85,4 +85,10 @@ def test_ip_hash(sample_private_ip1, sample_private_ip1_copy): """Test that the hash of two IP objects with the same IP is equal""" ip_1, _ = sample_private_ip1 ip_2, _ = sample_private_ip1_copy - assert hash(ip_1) == hash(ip_2) \ No newline at end of file + assert hash(ip_1) == hash(ip_2) + +def test_ip_eq_other_type(sample_private_ip1): + """Test equality with non-IP object""" + ip_1, _ = sample_private_ip1 + assert ip_1.__eq__("some_string") is NotImplemented + assert (ip_1 == "some_string") is False # Python fallback \ No newline at end of file diff --git a/tests/components/test_network.py b/tests/components/test_network.py index 5042e3c8..d12cfdc1 100644 --- a/tests/components/test_network.py +++ b/tests/components/test_network.py @@ -2,7 +2,7 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import pytest import dataclasses -from AIDojoCoordinator.game_components import Network +from netsecgame.game_components import Network # Pytest fixture for creating a sample Network object @pytest.fixture @@ -80,4 +80,21 @@ def test_net_from_dict(sample_private_network1): assert net.ip == "192.168.1.0" assert net.mask == 24 assert net == sample_private_network1 - assert net is not sample_private_network1 \ No newline at end of file + assert net is not sample_private_network1 + +def test_net_less_than(sample_private_network1, sample_private_network2): + """Test __lt__ operator""" + # 192.168.1.0/24 < 192.168.2.0/24 + assert sample_private_network1 < sample_private_network2 + assert not (sample_private_network2 < sample_private_network1) + +def test_net_less_than_equal(sample_private_network1, sample_private_network2, sample_private_network1_copy): + """Test __le__ operator""" + assert sample_private_network1 <= sample_private_network2 + assert sample_private_network1 <= sample_private_network1_copy + assert not (sample_private_network2 <= sample_private_network1) + +def test_net_greater_than(sample_private_network1, sample_private_network2): + """Test __gt__ operator""" + assert sample_private_network2 > sample_private_network1 + assert not (sample_private_network1 > sample_private_network2) \ No newline at end of file diff --git a/tests/components/test_service.py b/tests/components/test_service.py index 69f87e12..2c3cdd2d 100644 --- a/tests/components/test_service.py +++ b/tests/components/test_service.py @@ -2,7 +2,7 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import pytest import dataclasses -from AIDojoCoordinator.game_components import Service +from netsecgame.game_components import Service # Fixtures for Service objects @pytest.fixture diff --git a/tests/coordinator/test_agent_server.py b/tests/game/test_agent_server.py similarity index 93% rename from tests/coordinator/test_agent_server.py rename to tests/game/test_agent_server.py index 99c2f651..233e3d3d 100644 --- a/tests/coordinator/test_agent_server.py +++ b/tests/game/test_agent_server.py @@ -3,8 +3,8 @@ import pytest from unittest.mock import AsyncMock, MagicMock from contextlib import suppress -from AIDojoCoordinator.coordinator import AgentServer -from AIDojoCoordinator.game_components import Action, ActionType, ProtocolConfig +from netsecgame.game.coordinator import AgentServer +from netsecgame.game_components import Action, ActionType, ProtocolConfig # ----------------------- # Fixtures @@ -16,7 +16,8 @@ def mock_writer(): writer.get_extra_info = MagicMock(return_value=('127.0.0.1', 12345)) # ✅ Sync method writer.write = MagicMock() # ✅ Sync method writer.drain = AsyncMock() # ✅ Async method - writer.close = AsyncMock() # ✅ Async method + writer.close = MagicMock() # ✅ Sync method, wait_closed is separate + writer.wait_closed = AsyncMock() # ✅ Async method return writer @pytest.fixture @@ -54,7 +55,8 @@ def _make(ip: str, port: int): writer.get_extra_info = MagicMock(return_value=(ip, port)) # get_extra_info is sync writer.write = MagicMock() # write is sync writer.drain = AsyncMock() # drain is async - writer.close = AsyncMock() # close is async + writer.close = MagicMock() # close is sync + writer.wait_closed = AsyncMock() # wait_closed is async return writer return _make @@ -268,7 +270,9 @@ async def test_answer_queue_response_is_sent_to_agent(agent_server, mock_writer) async def test_cancelled_error_cleanup(agent_server, mock_writer): peername = ('127.0.0.1', 12345) mock_writer.get_extra_info = MagicMock(return_value=peername) - mock_writer.close = AsyncMock() + # mock_writer comes with correct mocks from fixture, but let's ensure wait_closed is tracked + # close is already MagicMock from fixture update, but if we want to be explicit: + mock_writer.close = MagicMock() mock_writer.wait_closed = AsyncMock() reader = AsyncMock() reader.read = AsyncMock(side_effect=asyncio.CancelledError()) diff --git a/tests/coordinator/test_coordinator_core.py b/tests/game/test_coordinator_core.py similarity index 62% rename from tests/coordinator/test_coordinator_core.py rename to tests/game/test_coordinator_core.py index 0853693a..9058a6a0 100644 --- a/tests/coordinator/test_coordinator_core.py +++ b/tests/game/test_coordinator_core.py @@ -5,8 +5,9 @@ from unittest.mock import AsyncMock, MagicMock, patch from types import SimpleNamespace -from AIDojoCoordinator.coordinator import GameCoordinator -from AIDojoCoordinator.game_components import ActionType, Action, AgentStatus, GameState, Observation, GameStatus +from netsecgame.game.coordinator import GameCoordinator +from netsecgame.game_components import ActionType, Action, AgentStatus, GameState, Observation, GameStatus, AgentRole +from netsecgame.game.coordinator import convert_msg_dict_to_json # ----------------------- # Fixtures @@ -44,7 +45,6 @@ def gc_with_test_config(test_config_file_path): game_port=9999, service_host=None, # force local config loading service_port=0, - allowed_roles=["Attacker", "Defender", "Benign"], task_config_file=test_config_file_path, ) @@ -98,15 +98,15 @@ def _make(ip: str, port: int): @pytest.mark.asyncio async def test_load_initialization_objects_loads_config(gc_with_test_config): - """Test that loading initialization objects sets up config and cyst objects.""" - gc_with_test_config._load_initialization_objects() - assert gc_with_test_config._cyst_objects is not None - assert hasattr(gc_with_test_config, "_CONFIG_FILE_HASH") + """Test that loading initialization objects sets up config using manager.""" + await gc_with_test_config.config_manager.load() + assert gc_with_test_config.config_manager.get_cyst_objects() is not None + assert gc_with_test_config.config_manager.get_config_hash() is not None def test_convert_msg_dict_to_json_success(gc_with_test_config): """Test that convert_msg_dict_to_json correctly serializes a dictionary.""" msg = {"foo": "bar"} - json_str = gc_with_test_config.convert_msg_dict_to_json(msg) + json_str = convert_msg_dict_to_json(msg) assert json_str == '{"foo": "bar"}' @@ -116,64 +116,84 @@ class Unserializable: pass with pytest.raises(TypeError): - gc_with_test_config.convert_msg_dict_to_json({"bad": Unserializable()}) + convert_msg_dict_to_json({"bad": Unserializable()}) @pytest.mark.asyncio async def test_create_agent_queue_adds_new_queue(gc_with_test_config): - """Test that create_agent_queue adds a new queue for the agent.""" - agent = ("127.0.0.1", 12345) - await gc_with_test_config.create_agent_queue(agent) - assert agent in gc_with_test_config._agent_response_queues - assert isinstance(gc_with_test_config._agent_response_queues[agent], asyncio.Queue) + """Test that create_agent_queue adds a new queue for an unknown agent.""" + addr = ("127.0.0.1", 12345) + await gc_with_test_config.create_agent_queue(addr) + assert addr in gc_with_test_config._agent_response_queues + assert isinstance(gc_with_test_config._agent_response_queues[addr], asyncio.Queue) @pytest.mark.asyncio async def test_create_agent_queue_idempotent(gc_with_test_config): - """Test that create_agent_queue does not create a new queue if it already exists.""" - agent = ("127.0.0.1", 12345) - await gc_with_test_config.create_agent_queue(agent) - q1 = gc_with_test_config._agent_response_queues[agent] - await gc_with_test_config.create_agent_queue(agent) - q2 = gc_with_test_config._agent_response_queues[agent] - assert q1 is q2 + """Test that create_agent_queue doesn't recreate existing queues.""" + addr = ("127.0.0.1", 12345) + await gc_with_test_config.create_agent_queue(addr) + first_queue = gc_with_test_config._agent_response_queues[addr] + + await gc_with_test_config.create_agent_queue(addr) + second_queue = gc_with_test_config._agent_response_queues[addr] + + assert first_queue is second_queue -def test_load_initialization_objects(gc_with_test_config): - """Test that _load_initialization_objects initializes config and cyst objects.""" - gc_with_test_config._load_initialization_objects() - assert gc_with_test_config._cyst_objects is not None - assert hasattr(gc_with_test_config, "_CONFIG_FILE_HASH") +@pytest.mark.asyncio +async def test_load_initialization_objects(gc_with_test_config): + """Test that config_manager.load initializes config and cyst objects.""" + await gc_with_test_config.config_manager.load() + # Check that cyst objects are loaded via manager + assert gc_with_test_config.config_manager.get_cyst_objects() is not None + # Check that hash is set + assert gc_with_test_config.config_manager.get_config_hash() is not None -def test_get_starting_position_per_role(gc_with_test_config): - """Test that _get_starting_position_per_role returns positions for all roles.""" - gc_with_test_config._load_initialization_objects() - positions = gc_with_test_config._get_starting_position_per_role() - assert set(positions.keys()) == set(gc_with_test_config.ALLOWED_ROLES) +@pytest.mark.asyncio +async def test_start_tasks_initializes_config(gc_with_test_config): + """Test that start_tasks initializes configuration attributes via manager.""" + # We can't easily run full start_tasks because it starts a server loop. + # But we can verify that config loading logic works if we extract it or partial mock. + # Alternatively, we can test that calling config_manager.load() and then accessing properties works. + # Or, looking at previous tests, they tested the helper private methods. + # Now we should test the properties on config_manager directly OR verify they are set on GC after load. + + await gc_with_test_config.config_manager.load() + # Manually populate like start_tasks does to verify logic correctness (or assume start_tasks does it) + # Since start_tasks is the only place calling these, we might want to test the config_manager methods instead. + + positions = gc_with_test_config.config_manager.get_all_starting_positions() + assert "Attacker" in positions + assert "Defender" in positions -def test_get_goal_description_per_role(gc_with_test_config): - """Test that _get_goal_description_per_role returns descriptions for all roles.""" - gc_with_test_config._load_initialization_objects() - desc = gc_with_test_config._get_goal_description_per_role() - assert set(desc.keys()) == set(gc_with_test_config.ALLOWED_ROLES) +@pytest.mark.asyncio +async def test_goal_descriptions_loaded(gc_with_test_config): + """Test that goal descriptions are retrievable via config manager.""" + await gc_with_test_config.config_manager.load() + desc = gc_with_test_config.config_manager.get_all_goal_descriptions() + assert "Attacker" in desc + assert "Defender" in desc -def test_get_win_condition_per_role(gc_with_test_config): - """Test that _get_win_condition_per_role returns win conditions for all roles.""" - gc_with_test_config._load_initialization_objects() - win = gc_with_test_config._get_win_condition_per_role() - assert set(win.keys()) == set(gc_with_test_config.ALLOWED_ROLES) +@pytest.mark.asyncio +async def test_win_conditions_loaded(gc_with_test_config): + """Test that win conditions are retrievable via config manager.""" + await gc_with_test_config.config_manager.load() + win = gc_with_test_config.config_manager.get_all_win_conditions() + assert "Attacker" in win + assert "Defender" in win -def test_get_max_steps_per_role(gc_with_test_config): - """Test that _get_max_steps_per_role returns max steps for all roles.""" - gc_with_test_config._load_initialization_objects() - steps = gc_with_test_config._get_max_steps_per_role() +@pytest.mark.asyncio +async def test_max_steps_loaded(gc_with_test_config): + """Test that max steps are retrievable via config manager.""" + await gc_with_test_config.config_manager.load() + steps = gc_with_test_config.config_manager.get_all_max_steps() assert isinstance(steps, dict) - # values can be int or None - assert all(isinstance(v, int) or v is None for v in steps.values()) + assert "Attacker" in steps @pytest.mark.asyncio @@ -377,4 +397,87 @@ async def test_process_game_action_ongoing_episode(initialized_coordinator, empt assert '"status": "' + str(GameStatus.OK) + '"' in msg_json assert '"reward": 0' in msg_json assert '"end": false' in msg_json - assert '"info": {}' in msg_json \ No newline at end of file + assert '"info": {}' in msg_json + +# ----------------------- +# New tests for refactored methods (_parse_action, _dispatch_action, run_game) +# ----------------------- +class TestCoordinatorRefactoredMethods: + @pytest.fixture + def mock_coordinator_core(self): + # Create a mock coordinator slightly different from integration fixtures to purely test logic + coord = MagicMock(spec=GameCoordinator) + coord.logger = MagicMock() + coord._agent_action_queue = AsyncMock() + coord.shutdown_flag = MagicMock() + # Side effect to stop loop after one iteration + coord.shutdown_flag.is_set.side_effect = [False, True] + + # Bind refactored methods + coord._parse_action_message = GameCoordinator._parse_action_message.__get__(coord) + coord._dispatch_action = GameCoordinator._dispatch_action.__get__(coord) + coord.run_game = GameCoordinator.run_game.__get__(coord) + + # Set __name__ for the mocked handlers so assert .__name__ works + coord._process_join_game_action.__name__ = "_process_join_game_action" + coord._process_quit_game_action.__name__ = "_process_quit_game_action" + coord._process_reset_game_action.__name__ = "_process_reset_game_action" + coord._process_game_action.__name__ = "_process_game_action" + + return coord + + def test_parse_action_message_valid(self, mock_coordinator_core): + """New test for refactored method: _parse_action_message with valid input.""" + valid_json = '{"action_type": "ActionType.JoinGame", "parameters": {"agent_info": {"name": "TestAgent", "role": "Attacker"}}}' + agent_addr = ("127.0.0.1", 12345) + + action = mock_coordinator_core._parse_action_message(agent_addr, valid_json) + + assert action is not None + assert action.type == ActionType.JoinGame + assert action.parameters["agent_info"].role == AgentRole.Attacker + + def test_parse_action_message_invalid(self, mock_coordinator_core): + """New test for refactored method: _parse_action_message with invalid input.""" + invalid_json = '{"invalid": "json"}' + agent_addr = ("127.0.0.1", 12345) + + action = mock_coordinator_core._parse_action_message(agent_addr, invalid_json) + + assert action is None + mock_coordinator_core.logger.error.assert_called() + # Verify agent address is in the error log + args, _ = mock_coordinator_core.logger.error.call_args + assert str(agent_addr) in args[0] + + def test_dispatch_action(self, mock_coordinator_core): + """New test for refactored method: _dispatch_action routing.""" + action = Action(ActionType.ScanNetwork, parameters={}) + agent_addr = ("127.0.0.1", 12345) + + mock_coordinator_core._dispatch_action(agent_addr, action) + + mock_coordinator_core._spawn_task.assert_called_once() + args = mock_coordinator_core._spawn_task.call_args[0] + # Should route to _process_game_action for ScanNetwork + assert args[0].__name__ == "_process_game_action" + + @pytest.mark.asyncio + async def test_run_game_flow(self, mock_coordinator_core): + """New test for refactored method: run_game flow (parse -> dispatch).""" + agent_addr = ("127.0.0.1", 12345) + valid_json = '{"action_type": "ActionType.ScanNetwork", "parameters": {}}' + + # Setup queue + mock_coordinator_core._agent_action_queue.get.return_value = (agent_addr, valid_json) + + with patch.object(mock_coordinator_core, '_parse_action_message') as mock_parse, \ + patch.object(mock_coordinator_core, '_dispatch_action') as mock_dispatch: + + mock_action = Action(ActionType.ScanNetwork, {}) + mock_parse.return_value = mock_action + + await mock_coordinator_core.run_game() + + mock_parse.assert_called_once_with(agent_addr, valid_json) + mock_dispatch.assert_called_once_with(agent_addr, mock_action) \ No newline at end of file diff --git a/tests/coordinator/test_global_defender.py b/tests/game/test_global_defender.py similarity index 92% rename from tests/coordinator/test_global_defender.py rename to tests/game/test_global_defender.py index e47eb233..32234b0c 100644 --- a/tests/coordinator/test_global_defender.py +++ b/tests/game/test_global_defender.py @@ -1,6 +1,6 @@ import pytest -from AIDojoCoordinator.game_components import ActionType, Action -from AIDojoCoordinator.global_defender import GlobalDefender +from netsecgame.game_components import ActionType, Action +from netsecgame.game.global_defender import GlobalDefender from unittest.mock import patch @pytest.fixture @@ -56,7 +56,7 @@ def test_mock_stochastic_probabilities(defender, episode_actions): """Test stochastic function is only called when thresholds are crossed.""" action = Action(ActionType.ScanNetwork, {}) episode_actions += [{"action_type": str(ActionType.ScanNetwork)}] * 4 # Exceed threshold - - with patch("AIDojoCoordinator.global_defender.random", return_value=0.01): # Force detection probability + + with patch("netsecgame.game.global_defender.random", return_value=0.01): # Force detection probability result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5) assert result # Should be True since we forced a low probability value \ No newline at end of file diff --git a/tests/manual/three_nets/manual_test_three_net_scenario.py b/tests/manual/three_nets/manual_test_three_net_scenario.py index d1bec3ea..114495fb 100644 --- a/tests/manual/three_nets/manual_test_three_net_scenario.py +++ b/tests/manual/three_nets/manual_test_three_net_scenario.py @@ -5,7 +5,7 @@ PATH = path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) )) sys.path.append(path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) ))) from NetSecGameAgents.agents import base_agent -from AIDojoCoordinator.game_components import Action, ActionType, IP, Network, Service, Data +from netsecgame.game_components import Action, ActionType, IP, Network, Service, Data if __name__ == "__main__": diff --git a/tests/utils/test_trajectory_recorder.py b/tests/utils/test_trajectory_recorder.py new file mode 100644 index 00000000..0e80aefb --- /dev/null +++ b/tests/utils/test_trajectory_recorder.py @@ -0,0 +1,88 @@ +import pytest +from unittest.mock import patch +from netsecgame.utils.trajectory_recorder import TrajectoryRecorder +from netsecgame.game_components import Action, ActionType, GameState + +# Mock objects needed for tests +@pytest.fixture +def mock_action(): + return Action(ActionType.ScanNetwork, parameters={"target": "10.0.0.1"}) + +@pytest.fixture +def mock_gamestate(): + # Minimal GameState for testing + return GameState( + controlled_hosts=set(), + known_hosts=set(), + known_services={}, + known_data={}, + known_networks=set() + ) + +@pytest.fixture +def recorder(): + return TrajectoryRecorder(agent_name="test_agent", agent_role="Attacker") + +def test_initialization(recorder): + assert recorder.agent_name == "test_agent" + assert recorder.agent_role == "Attacker" + data = recorder.get_trajectory() + assert data["agent_name"] == "test_agent" + assert data["agent_role"] == "Attacker" + assert data["trajectory"]["states"] == [] + assert data["trajectory"]["actions"] == [] + assert data["trajectory"]["rewards"] == [] + assert data["end_reason"] is None + +def test_add_initial_state(recorder, mock_gamestate): + recorder.add_initial_state(mock_gamestate) + data = recorder.get_trajectory() + assert len(data["trajectory"]["states"]) == 1 + assert data["trajectory"]["states"][0] == mock_gamestate.as_dict + +def test_add_step(recorder, mock_action, mock_gamestate): + recorder.add_step(mock_action, reward=10.0, next_state=mock_gamestate, end_reason=None) + data = recorder.get_trajectory() + + assert len(data["trajectory"]["actions"]) == 1 + assert data["trajectory"]["actions"][0] == mock_action.as_dict + + assert len(data["trajectory"]["rewards"]) == 1 + assert data["trajectory"]["rewards"][0] == 10.0 + + assert len(data["trajectory"]["states"]) == 1 + assert data["trajectory"]["states"][0] == mock_gamestate.as_dict + + assert data["end_reason"] is None + +def test_add_step_with_end_reason(recorder, mock_action, mock_gamestate): + recorder.add_step(mock_action, reward=0, next_state=mock_gamestate, end_reason="Timeout") + data = recorder.get_trajectory() + assert data["end_reason"] == "Timeout" + +def test_reset(recorder, mock_action, mock_gamestate): + recorder.add_step(mock_action, 10, mock_gamestate) + recorder.reset() + data = recorder.get_trajectory() + + assert data["trajectory"]["states"] == [] + assert data["trajectory"]["actions"] == [] + assert data["trajectory"]["rewards"] == [] + assert data["end_reason"] is None + assert data["agent_name"] == "test_agent" + +@patch("netsecgame.utils.trajectory_recorder.store_trajectories_to_jsonl") +def test_save_to_file(mock_store, recorder): + recorder.save_to_file(location="/tmp/logs") + + # Check if called with correct args + mock_store.assert_called_once() + args, _ = mock_store.call_args + + saved_data = args[0] + location = args[1] + filename = args[2] + + assert saved_data == recorder.get_trajectory() + assert location == "/tmp/logs" + assert "test_agent_Attacker" in filename diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py new file mode 100644 index 00000000..069d782c --- /dev/null +++ b/tests/utils/test_utils.py @@ -0,0 +1,128 @@ +import pytest +import logging +from netsecgame.utils.utils import ( + get_str_hash, + state_as_ordered_string, + observation_as_dict, + observation_to_str, + observation_from_dict, + observation_from_str, + parse_log_content, + get_logging_level, + generate_valid_actions +) +from netsecgame.game_components import ( + GameState, + Observation, + ActionType, + IP, + Network, + Service, + Data +) + +# --- Fixtures --- + +@pytest.fixture +def sample_gamestate(): + net1 = Network("10.0.0.0", 24) + host1 = IP("10.0.0.1") + host2 = IP("10.0.0.2") + service1 = Service("http", "tcp", "80", False) + data1 = Data("root", "secret", "file", 100) + + return GameState( + controlled_hosts={host1}, + known_hosts={host1, host2}, + known_services={host2: {service1}}, + known_data={host1: {data1}}, + known_networks={net1}, + known_blocks={host1: {host2}} + ) + +@pytest.fixture +def sample_observation(sample_gamestate): + return Observation( + state=sample_gamestate, + reward=10.0, + end=False, + info={"reason": "test"} + ) + +# --- Tests --- + +def test_get_str_hash(): + s = "hello world" + # sha256 of "hello world" + expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9" + assert get_str_hash(s) == expected + +def test_state_as_ordered_string(sample_gamestate): + # This function produces a specific string format. + # We verify it contains expected substrings and is deterministic. + s1 = state_as_ordered_string(sample_gamestate) + s2 = state_as_ordered_string(sample_gamestate) + assert s1 == s2 + assert "nets:[10.0.0.0/24]" in s1 + assert "hosts:[10.0.0.1,10.0.0.2]" in s1 + assert "services:{10.0.0.2:[Service(name='http', type='tcp', version='80', is_local=False)]}" in s1 + +def test_observation_conversion_roundtrip(sample_observation): + # dict conversion + obs_dict = observation_as_dict(sample_observation) + assert obs_dict["reward"] == 10.0 + assert obs_dict["end"] is False + assert obs_dict["info"]["reason"] == "test" + + # restore from dict + obs_restored = observation_from_dict(obs_dict) + assert obs_restored.reward == sample_observation.reward + assert obs_restored.end == sample_observation.end + assert obs_restored.info == sample_observation.info + # State equality depends on GameState equality implementation + assert obs_restored.state.known_hosts == sample_observation.state.known_hosts + +def test_observation_json_roundtrip(sample_observation): + # str conversion + json_str = observation_to_str(sample_observation) + assert isinstance(json_str, str) + + # restore from str + obs_restored = observation_from_str(json_str) + assert obs_restored.reward == sample_observation.reward + assert obs_restored.end == sample_observation.end + assert obs_restored.state.known_hosts == sample_observation.state.known_hosts + +def test_observation_from_dict_error(): + # Invalid input + with pytest.raises(Exception): + observation_from_dict({"reward": 10}) # missing state + +def test_observation_from_str_error(): + with pytest.raises(Exception): + observation_from_str("invalid json") + +def test_parse_log_content(): + log_json = '[{"source_host": "10.0.0.1", "action_type": "ScanNetwork"}]' + logs = parse_log_content(log_json) + assert len(logs) == 1 + assert logs[0]["source_host"] == IP("10.0.0.1") + assert logs[0]["action_type"] == ActionType.ScanNetwork + +def test_parse_log_content_invalid(): + assert parse_log_content("invalid json") is None + +def test_get_logging_level(): + assert get_logging_level("DEBUG") == logging.DEBUG + assert get_logging_level("info") == logging.INFO + assert get_logging_level("UNKNOWN") == logging.ERROR + +def test_generate_valid_actions(sample_gamestate): + actions = generate_valid_actions(sample_gamestate, include_blocks=True) + assert isinstance(actions, list) + assert len(actions) > 0 + # Check for specific expected actions based on sample state + # Controlled host is 10.0.0.1 + # It should be able to ScanNetwork 10.0.0.0/24 + scan_actions = [a for a in actions if a.type == ActionType.ScanNetwork] + assert any(a.parameters["target_network"] == Network("10.0.0.0", 24) for a in scan_actions)