Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 69 additions & 12 deletions sql/engines/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
from common.utils.timer import FuncTimer
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from common.config import SysConfig
import logging

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import TransportError


logger = logging.getLogger("default")
Expand Down Expand Up @@ -583,7 +580,7 @@ def execute_check(self, db_name=None, sql=""):
sql=doc.sql,
)
else:
if is_pass == False:
if is_pass is False:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot 这里的 is_pass 有点过于简单, 能否帮忙换一个变量名让他表意更明确

is_pass = True
elif not doc.api_endpoint:
if doc.method == "PUT":
Expand Down Expand Up @@ -635,10 +632,10 @@ def execute_check(self, db_name=None, sql=""):
sql=doc.sql,
)
else:
if is_pass == False:
if is_pass is False:
is_pass = True
elif doc.method == "POST":
if is_pass == False:
if is_pass is False:
is_pass = True
else:
result = ReviewResult(
Expand Down Expand Up @@ -669,7 +666,7 @@ def execute_check(self, db_name=None, sql=""):
sql=doc.sql,
)
else:
if is_pass == False:
if is_pass is False:
is_pass = True
else:
result = ReviewResult(
Expand All @@ -683,7 +680,7 @@ def execute_check(self, db_name=None, sql=""):
)
elif doc.api_endpoint == "_update_by_query":
if doc.method == "POST":
if is_pass == False:
if is_pass is False:
is_pass = True
Comment on lines +683 to 684
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_pass is False:
is_pass = True
is_pass = True

这个判断在我看来没有意义, 你看看这样简单写行不行

else:
result = ReviewResult(
Expand All @@ -695,12 +692,41 @@ def execute_check(self, db_name=None, sql=""):
affected_rows=0,
execute_time=0,
)
elif doc.api_endpoint not in ["", "_doc", "_update_by_query", "_update"]:
elif doc.api_endpoint == "_mapping":
if doc.method == "PUT":
if not doc.doc_data_body or "properties" not in doc.doc_data_body:
result = ReviewResult(
id=rowid,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage="PUT请求更新索引映射时请求体必须包含properties字段。",
sql=doc.sql,
)
else:
if is_pass is False:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

像这种语句可以简洁的写成

Suggested change
if is_pass is False:
if not is_pass:

is_pass = True
else:
result = ReviewResult(
id=rowid,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage=f"不支持此操作,{doc.api_endpoint}需要使用PUT方法。解析结果:{doc_desc}",
sql=doc.sql,
affected_rows=0,
execute_time=0,
)
elif doc.api_endpoint not in [
"",
"_doc",
"_update_by_query",
"_update",
"_mapping",
]:
result = ReviewResult(
id=rowid,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage="API操作端点(API Endpoint)仅支持: 空, _doc、_update、_update_by_query。",
errormessage="API操作端点(API Endpoint)仅支持: 空, _doc、_update、_update_by_query、_mapping。",
sql=doc.sql,
)
else:
Expand Down Expand Up @@ -778,6 +804,10 @@ def execute_workflow(self, workflow):
reviewResult = self.__add_or_update(conn, doc)
reviewResult.id = line
execute_result.rows.append(reviewResult)
elif doc.api_endpoint == "_mapping":
reviewResult = self.__put_mapping(conn, doc)
reviewResult.id = line
execute_result.rows.append(reviewResult)
else:
raise Exception(f"不支持的API类型:{doc.api_endpoint}")
except Exception as e:
Expand Down Expand Up @@ -806,7 +836,7 @@ def execute_workflow(self, workflow):
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage=f"前序语句失败, 未执行",
errormessage="前序语句失败, 未执行",
sql=doc.sql,
affected_rows=0,
execute_time=0,
Expand Down Expand Up @@ -915,6 +945,34 @@ def __create_index(self, conn, doc):
execute_time=t.cost,
)

def __put_mapping(self, conn, doc):
"""ES的 更新索引映射方法"""
errlevel = 0
with FuncTimer() as t:
try:
response = conn.indices.put_mapping(
index=doc.index_name, body=doc.doc_data_body
)
successful_count = response.get("_shards", {}).get("successful", None)
Copy link

Copilot AI Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable successful_count is not initialized when an exception occurs and the condition 'index_not_found_exception' in error_message.lower() is False. This will cause a NameError when trying to use it in the return statement at line 972. The variable should be initialized before the try block, similar to other methods like __update and __create_index.

Copilot uses AI. Check for mistakes.
response_str = str(response)
Copy link

Copilot AI Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable response_str is not initialized when an exception occurs and the condition 'index_not_found_exception' in error_message.lower() is False. This will cause a NameError when trying to use it in the return statement at line 970. The variable should be initialized before the try block.

Copilot uses AI. Check for mistakes.
except Exception as e:
error_message = str(e)
if "index_not_found_exception" in error_message.lower():
response_str = "index not found: " + error_message
successful_count = 0
errlevel = 2
else:
raise

return ReviewResult(
errlevel=errlevel,
stagestatus="Execute Successfully",
errormessage=response_str,
sql=doc.sql,
affected_rows=successful_count if successful_count is not None else 0,
execute_time=t.cost,
)

def __delete_data(self, conn, doc):
"""
数据删除
Expand Down Expand Up @@ -951,7 +1009,6 @@ def __get_document_from_sql(self, sql):
"""
result = ElasticsearchDocument(sql=sql)
if re.match(r"^POST |^PUT |^DELETE ", sql, re.I):

# 提取方法和路径
method, path_with_params = sql.split(maxsplit=1)
if path_with_params.startswith("{"):
Expand Down
192 changes: 192 additions & 0 deletions sql/engines/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,195 @@ def test_execute_workflow_delete_data_exception(self, mockElasticsearch):
self.assertEqual(result.rows[0].errlevel, 1)
self.assertIn("Execute Successfully", result.rows[0].stagestatus)
self.assertIn("Document not found", result.rows[0].errormessage)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_put_mapping_with_properties(self, mockElasticsearch):
"""测试 PUT _mapping 请求,包含 properties 字段"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = 'PUT /test_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 0)
self.assertEqual(result.rows[0].errlevel, 0)
self.assertIn("审核通过", result.rows[0].errormessage)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_put_mapping_without_properties(self, mockElasticsearch):
"""测试 PUT _mapping 请求,不包含 properties 字段"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = 'PUT /test_index/_mapping {"settings": {"number_of_shards": 1}}'
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("必须包含properties字段", result.rows[0].errormessage.lower())

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_put_mapping_empty_body(self, mockElasticsearch):
"""测试 PUT _mapping 请求,请求体为空"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = "PUT /test_index/_mapping"
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("必须包含properties字段", result.rows[0].errormessage.lower())

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_put_mapping_empty_dict(self, mockElasticsearch):
"""测试 PUT _mapping 请求,请求体为空字典"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = "PUT /test_index/_mapping {}"
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("必须包含properties字段", result.rows[0].errormessage.lower())

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_post_mapping(self, mockElasticsearch):
"""测试 POST _mapping 请求(不支持)"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = (
'POST /test_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
)
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("需要使用PUT方法", result.rows[0].errormessage)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_delete_mapping(self, mockElasticsearch):
"""测试 DELETE _mapping 请求(不支持,因为 DELETE 方法需要 doc_id)"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = "DELETE /test_index/_mapping"
result = self.engine.execute_check(sql=sql)

# DELETE 方法的检查在 _mapping 检查之前,所以会先检查 doc_id
# 由于没有 doc_id,会返回"删除操作必须包含id条件。"的错误
self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("删除操作必须包含id条件", result.rows[0].errormessage)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_check_invalid_endpoint_with_mapping_in_message(
self, mockElasticsearch
):
"""测试不支持的 API 端点,错误消息中包含 _mapping"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

sql = 'PUT /test_index/_invalid_endpoint {"properties": {"new_field": {"type": "text"}}}'
result = self.engine.execute_check(sql=sql)

self.assertEqual(result.error_count, 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("_mapping", result.rows[0].errormessage)
self.assertIn("API操作端点", result.rows[0].errormessage)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_workflow_put_mapping_success(self, mockElasticsearch):
"""测试 execute_workflow 方法的 PUT _mapping 请求执行成功"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

workflow = Mock()
workflow.sqlworkflowcontent.sql_content = (
'PUT /test_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
)
workflow.db_name = "test_db"

mock_conn.indices.put_mapping.return_value = {
"acknowledged": True,
"_shards": {"successful": 1, "failed": 0},
}

result = self.engine.execute_workflow(workflow)

self.assertEqual(len(result.rows), 1)
self.assertEqual(result.rows[0].errlevel, 0)
self.assertIn("Execute Successfully", result.rows[0].stagestatus)
self.assertEqual(result.rows[0].affected_rows, 1)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_workflow_put_mapping_index_not_found(self, mockElasticsearch):
"""测试 execute_workflow 方法的 PUT _mapping 请求,索引不存在"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

workflow = Mock()
workflow.sqlworkflowcontent.sql_content = 'PUT /nonexistent_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
Copy link

Copilot AI Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This line exceeds typical line length conventions (appears to be >100 characters). Consider breaking the SQL string into a multi-line format for better readability, consistent with the formatting used in other test methods like test_execute_workflow_put_mapping_success.

Suggested change
workflow.sqlworkflowcontent.sql_content = 'PUT /nonexistent_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
workflow.sqlworkflowcontent.sql_content = (
'PUT /nonexistent_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
)

Copilot uses AI. Check for mistakes.
workflow.db_name = "test_db"

mock_conn.indices.put_mapping.side_effect = Exception(
"index_not_found_exception: Index not found"
)

result = self.engine.execute_workflow(workflow)

self.assertEqual(len(result.rows), 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("Execute Successfully", result.rows[0].stagestatus)
self.assertIn("index not found", result.rows[0].errormessage)
self.assertEqual(result.rows[0].affected_rows, 0)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_workflow_put_mapping_no_shards(self, mockElasticsearch):
"""测试 execute_workflow 方法的 PUT _mapping 请求,响应中没有 _shards"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

workflow = Mock()
workflow.sqlworkflowcontent.sql_content = (
'PUT /test_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
)
workflow.db_name = "test_db"

mock_conn.indices.put_mapping.return_value = {
"acknowledged": True,
}

result = self.engine.execute_workflow(workflow)

self.assertEqual(len(result.rows), 1)
self.assertEqual(result.rows[0].errlevel, 0)
self.assertIn("Execute Successfully", result.rows[0].stagestatus)
self.assertEqual(result.rows[0].affected_rows, 0)

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_execute_workflow_put_mapping_other_exception(self, mockElasticsearch):
"""测试 execute_workflow 方法的 PUT _mapping 请求,其他异常(非 index_not_found)"""
mock_conn = Mock()
mockElasticsearch.return_value = mock_conn

workflow = Mock()
workflow.sqlworkflowcontent.sql_content = (
'PUT /test_index/_mapping {"properties": {"new_field": {"type": "text"}}}'
)
workflow.db_name = "test_db"

# 模拟其他类型的异常(不是 index_not_found_exception)
mock_conn.indices.put_mapping.side_effect = Exception(
"mapper_parsing_exception: Failed to parse mapping"
)

result = self.engine.execute_workflow(workflow)

# 异常应该被重新抛出并在 execute_workflow 中被捕获
self.assertEqual(len(result.rows), 1)
self.assertEqual(result.rows[0].errlevel, 2)
self.assertIn("Execute Failed", result.rows[0].stagestatus)
self.assertIn("异常信息", result.rows[0].errormessage)
Loading