diff --git a/gui_agents/s3/core/mllm.py b/gui_agents/s3/core/mllm.py index fb49e4b0..58beb93c 100644 --- a/gui_agents/s3/core/mllm.py +++ b/gui_agents/s3/core/mllm.py @@ -35,6 +35,69 @@ def __init__(self, engine_params=None, system_prompt=None, engine=None): self.engine = LMMEngineOpenRouter(**engine_params) elif engine_type == "parasail": self.engine = LMMEngineParasail(**engine_params) + elif engine_type == "ollama": + # Reuse LMMEngineOpenAI for Ollama + if not engine_params.get("base_url"): + import os + + base_url = os.getenv("OLLAMA_HOST") + if base_url: + if not base_url.endswith("/v1"): + base_url = base_url.rstrip("/") + "/v1" + engine_params["base_url"] = base_url + else: + # RAISE ERROR instead of default + raise ValueError( + "Ollama endpoint must be provided via 'base_url' parameter or 'OLLAMA_HOST' environment variable." + ) + if not engine_params.get("api_key"): + engine_params["api_key"] = "ollama" + self.engine = LMMEngineOpenAI(**engine_params) + elif engine_type == "deepseek": + if "base_url" not in engine_params: + import os + + base_url = os.getenv("DEEPSEEK_ENDPOINT_URL") + if not base_url: + base_url = "https://api.deepseek.com" + if not base_url.endswith("/v1"): + base_url = base_url.rstrip("/") + "/v1" + engine_params["base_url"] = base_url + + if not engine_params.get("api_key"): + import os + + api_key = os.getenv("DEEPSEEK_API_KEY") + if not api_key: + raise ValueError( + "DeepSeek API key must be provided via 'api_key' parameter or 'DEEPSEEK_API_KEY' environment variable." + ) + engine_params["api_key"] = api_key + + self.engine = LMMEngineOpenAI(**engine_params) + elif engine_type == "qwen": + if not engine_params.get("base_url"): + import os + + base_url = os.getenv("QWEN_ENDPOINT_URL") + if not base_url: + base_url = ( + "https://dashscope.aliyuncs.com/compatible-mode/v1" + ) + if not base_url.endswith("/v1"): + base_url = base_url.rstrip("/") + "/v1" + engine_params["base_url"] = base_url + + if not engine_params.get("api_key"): + import os + + api_key = os.getenv("QWEN_API_KEY") + if not api_key: + raise ValueError( + "Qwen API key must be provided via 'api_key' parameter or 'QWEN_API_KEY' environment variable." + ) + engine_params["api_key"] = api_key + self.engine = LMMEngineOpenAI(**engine_params) else: raise ValueError(f"engine_type '{engine_type}' is not supported") else: diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 00000000..2e1b24d2 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,72 @@ +import os +import unittest +from unittest.mock import patch, MagicMock +from gui_agents.s3.core.mllm import LMMAgent +from gui_agents.s3.core.engine import LMMEngineOpenAI + + +class TestProviders(unittest.TestCase): + def setUp(self): + # Clear env vars before each test + if "OLLAMA_HOST" in os.environ: + del os.environ["OLLAMA_HOST"] + if "DEEPSEEK_API_KEY" in os.environ: + del os.environ["DEEPSEEK_API_KEY"] + if "QWEN_API_KEY" in os.environ: + del os.environ["QWEN_API_KEY"] + if "DEEPSEEK_ENDPOINT_URL" in os.environ: + del os.environ["DEEPSEEK_ENDPOINT_URL"] + if "QWEN_ENDPOINT_URL" in os.environ: + del os.environ["QWEN_ENDPOINT_URL"] + + def test_ollama_missing_config(self): + """Test that Ollama raises ValueError if no endpoint is provided""" + with self.assertRaises(ValueError) as cm: + LMMAgent(engine_params={"engine_type": "ollama", "model": "llama3"}) + self.assertIn("Ollama endpoint must be provided", str(cm.exception)) + + def test_ollama_valid_config_param(self): + """Test Ollama init with base_url param""" + agent = LMMAgent( + engine_params={ + "engine_type": "ollama", + "model": "llama3", + "base_url": "http://example.com/v1", + } + ) + self.assertIsInstance(agent.engine, LMMEngineOpenAI) + self.assertEqual(agent.engine.base_url, "http://example.com/v1") + + def test_ollama_valid_config_env(self): + """Test Ollama init with OLLAMA_HOST env var""" + with patch.dict(os.environ, {"OLLAMA_HOST": "http://env-host:11434"}): + agent = LMMAgent(engine_params={"engine_type": "ollama", "model": "llama3"}) + self.assertIsInstance(agent.engine, LMMEngineOpenAI) + # Check for /v1 addition + self.assertEqual(agent.engine.base_url, "http://env-host:11434/v1") + + def test_deepseek_init(self): + """Test DeepSeek initialization""" + with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "sk-test"}): + agent = LMMAgent( + engine_params={"engine_type": "deepseek", "model": "deepseek-coder"} + ) + self.assertIsInstance(agent.engine, LMMEngineOpenAI) + # Default URL + self.assertEqual(agent.engine.base_url, "https://api.deepseek.com/v1") + # (Note: engine.py logic resolves default at generate() time or if client created, + # but init just stores what's passed. Let's verify prompt generation to ensure it doesn't crash on init) + + def test_qwen_init(self): + """Test Qwen initialization""" + with patch.dict(os.environ, {"QWEN_API_KEY": "sk-qwen"}): + agent = LMMAgent(engine_params={"engine_type": "qwen", "model": "qwen-max"}) + self.assertIsInstance(agent.engine, LMMEngineOpenAI) + self.assertEqual( + agent.engine.base_url, + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ) + + +if __name__ == "__main__": + unittest.main()