diff --git a/lexsi_sdk/common/types.py b/lexsi_sdk/common/types.py index 03a13a2..ea60dc4 100644 --- a/lexsi_sdk/common/types.py +++ b/lexsi_sdk/common/types.py @@ -162,9 +162,6 @@ class ProjectConfig(TypedDict): """ Configuration keys required to describe a project. - :param project_type: Project type identifier. - :type project_type: Optional[str] - :param model_name: Model name associated with the project. :type model_name: Optional[str] diff --git a/lexsi_sdk/core/image.py b/lexsi_sdk/core/image.py index d2e9f88..b75545a 100644 --- a/lexsi_sdk/core/image.py +++ b/lexsi_sdk/core/image.py @@ -562,10 +562,10 @@ def update_inference_model_status(self, model_name: str, activate: bool) -> str: def model_inference( self, + pod: str, tag: Optional[str] = None, file_name: Optional[str] = None, model_name: Optional[str] = None, - pod: Optional[str] = None ) -> pd.DataFrame: """Run model inference on tag or file_name data. Either tag or file_name is required for running inference @@ -622,9 +622,8 @@ def model_inference( "project_name": self.project_name, "model_name": model, "tags": tag, + "instance_type": pod, } - if pod: - run_model_payload["instance_type"] = pod if filepath: run_model_payload["filepath"] = filepath diff --git a/lexsi_sdk/core/project.py b/lexsi_sdk/core/project.py index 79c426c..e32313c 100755 --- a/lexsi_sdk/core/project.py +++ b/lexsi_sdk/core/project.py @@ -12,7 +12,7 @@ GDriveConfig, SFTPConfig, ) -from lexsi_sdk.common.utils import normalize_time, parse_datetime, parse_float, poll_events +from lexsi_sdk.common.utils import normalize_time, parse_float, poll_events from lexsi_sdk.common.validation import Validate import pandas as pd from lexsi_sdk.common.xai_uris import ( @@ -46,7 +46,6 @@ LIST_FILEPATHS, UPLOAD_FILE_DATA_CONNECTORS, DROPBOX_OAUTH, - VALIDATE_POLICY_URI, ) import io from lexsi_sdk.core.alert import Alert @@ -1062,78 +1061,3 @@ def build_expression(expression_string): configuration.append(logical_operators[log_operator]) return configuration, metadata_expression - - -def validate_configuration( - configuration, params, project_name="", api_client=APIClient(), observations=False -): - """Validate an expression provided configuration against allowed features/operators. - Raises exceptions for invalid columns/operators/values and can validate observation comparisons. - - :param configuration: Configuration token list (from `build_expression`). - :param params: Allowed features/operators payload fetched from the backend. - :param project_name: Project name used for backend validation calls. - :param api_client: API client used for optional backend validation. - :param observations: If True, validate observation column-vs-column comparisons. - :raises Exception: If the configuration is invalid.""" - for expression in configuration: - if isinstance(expression, str): - if expression not in ["(", ")", *params.get("logical_operators")]: - raise Exception(f"{expression} not a valid logical operator") - - if isinstance(expression, dict): - # validate column name - Validate.value_against_list( - "feature", - expression.get("column"), - list(params.get("features", {}).keys()), - ) - - # validate operator - Validate.value_against_list( - "condition_operator", - expression.get("expression"), - params.get("condition_operators"), - ) - - # validate value(s) - expression_value = expression.get("value") - valid_feature_values = params.get("features").get(expression.get("column")) - if observations and isinstance(valid_feature_values, list): - condition_operators = { - "_NOTEQ": "!==", - "_ISEQ": "==", - "_GRT": ">", - "_LST": "<", - } - res = api_client.get( - f"{VALIDATE_POLICY_URI}?project_name={project_name}&column1_name={expression.get('column')}&column2_name={expression.get('value')}&operation={condition_operators[expression.get('expression')]}" - ) - if not res.get("success"): - raise Exception(res.get("message")) - if isinstance(valid_feature_values, str): - # if valid_feature_values == "input" and not parse_float( - # expression_value - # ): - # raise Exception( - # f"Invalid value comparison with {expression_value} for {expression.get('column')}" - # ) - if valid_feature_values == "datetime" and not parse_datetime( - expression_value - ): - raise Exception( - f"Invalid value comparison with {expression_value} for {expression.get('column')}" - ) - - else: - condition_operators = { - "_NOTEQ": "!==", - "_ISEQ": "==", - "_GRT": ">", - "_LST": "<", - } - res = api_client.get( - f"{VALIDATE_POLICY_URI}?project_name={project_name}&column1_name={expression.get('column')}&column2_name={expression.get('value')}&operation={condition_operators[expression.get('expression')]}" - ) - if not res.get("success"): - raise Exception(res.get("message")) diff --git a/lexsi_sdk/core/tabular.py b/lexsi_sdk/core/tabular.py index 97cefa6..18cab86 100644 --- a/lexsi_sdk/core/tabular.py +++ b/lexsi_sdk/core/tabular.py @@ -11,9 +11,9 @@ from lexsi_sdk.common.types import CatBoostParams, DataConfig, FoundationalModelParams, InferenceCompute, LightGBMParams, PEFTParams, ProcessorParams, ProjectConfig, RandomForestParams, SyntheticDataConfig, SyntheticModelHyperParams, TuningParams, XGBoostParams from lexsi_sdk.common.utils import normalize_time, poll_events from lexsi_sdk.common.validation import Validate -from lexsi_sdk.common.xai_uris import ALL_DATA_FILE_URI, AVAILABLE_BATCH_SERVERS_URI, AVAILABLE_SYNTHETIC_CUSTOM_SERVERS_URI, CASE_DTREE_URI, CASE_INFO_TEXT_URI, CASE_INFO_URI, CREATE_OBSERVATION_URI, CREATE_POLICY_URI, CREATE_SYNTHETIC_PROMPT_URI, DELETE_CASE_URI, DELETE_SYNTHETIC_MODEL_URI, DELETE_SYNTHETIC_TAG_URI, DOWNLOAD_DASHBOARD_LOGS_URI, DOWNLOAD_SYNTHETIC_DATA_URI, DOWNLOAD_TAG_DATA_URI, DUPLICATE_OBSERVATION_URI, DUPLICATE_POLICY_URI, GENERATE_DASHBOARD_URI, GET_CASES_URI, GET_DASHBOARD_SCORE_URI, GET_DATA_DIAGNOSIS_URI, GET_DATA_DRIFT_DIAGNOSIS_URI, GET_DATA_SUMMARY_URI, GET_FEATURE_IMPORTANCE_URI, GET_LABELS_URI, GET_MODELS_URI, GET_OBSERVATION_PARAMS_URI, GET_OBSERVATIONS_URI, GET_POLICIES_URI, GET_POLICY_PARAMS_URI, GET_PROJECT_CONFIG, GET_SYNTHETIC_DATA_TAGS_URI, GET_SYNTHETIC_MODEL_DETAILS_URI, GET_SYNTHETIC_MODEL_PARAMS_URI, GET_SYNTHETIC_MODELS_URI, GET_SYNTHETIC_PROMPT_URI, LIST_DATA_CONNECTORS, MODEL_INFERENCE_SETTINGS_URI, MODEL_INFERENCES_URI, MODEL_PARAMETERS_URI, MODEL_SUMMARY_URI, PROJECT_OVERVIEW_TEXT_URI, RUN_DATA_DRIFT_DIAGNOSIS_URI, RUN_MODEL_ON_DATA_URI, SEARCH_CASE_URI, TABULAR_ML, TEXT_MODEL_INFERENCE_SETTINGS_URI, TRAIN_MODEL_URI, TRAIN_SYNTHETIC_MODEL_URI, UPDATE_ACTIVE_INFERENCE_MODEL_URI, UPDATE_OBSERVATION_URI, UPDATE_POLICY_URI, UPDATE_SYNTHETIC_PROMPT_URI, UPLOAD_DATA_FILE_URI, UPLOAD_DATA_PROJECT_URI, UPLOAD_DATA_URI, UPLOAD_FILE_DATA_CONNECTORS, AVAILABLE_BATCH_SERVERS_URI, CREATE_TRIGGER_URI, DASHBOARD_LOGS_URI, DELETE_TRIGGER_URI, DUPLICATE_MONITORS_URI, EXECUTED_TRIGGER_URI, GENERATE_DASHBOARD_URI, GET_DASHBOARD_SCORE_URI, GET_DASHBOARD_URI, GET_EXECUTED_TRIGGER_INFO, GET_MODEL_TYPES_URI, GET_MODELS_URI, GET_MONITORS_ALERTS, GET_PROJECT_CONFIG, GET_TRIGGERS_URI, LIST_DATA_CONNECTORS, MODEL_PARAMETERS_URI, MODEL_PERFORMANCE_DASHBOARD_URI, UPLOAD_DATA_FILE_INFO_URI, UPLOAD_DATA_FILE_URI, UPLOAD_DATA_URI, UPLOAD_DATA_WITH_CHECK_URI, UPLOAD_FILE_DATA_CONNECTORS, UPLOAD_MODEL_URI, EXPLAINABILITY_SUMMARY, GET_TRIGGERS_DAYS_URI +from lexsi_sdk.common.xai_uris import ALL_DATA_FILE_URI, AVAILABLE_BATCH_SERVERS_URI, AVAILABLE_SYNTHETIC_CUSTOM_SERVERS_URI, CASE_DTREE_URI, CASE_INFO_TEXT_URI, CASE_INFO_URI, CREATE_OBSERVATION_URI, CREATE_POLICY_URI, CREATE_SYNTHETIC_PROMPT_URI, DELETE_CASE_URI, DELETE_SYNTHETIC_MODEL_URI, DELETE_SYNTHETIC_TAG_URI, DOWNLOAD_DASHBOARD_LOGS_URI, DOWNLOAD_SYNTHETIC_DATA_URI, DOWNLOAD_TAG_DATA_URI, DUPLICATE_OBSERVATION_URI, DUPLICATE_POLICY_URI, GENERATE_DASHBOARD_URI, GET_CASES_URI, GET_DASHBOARD_SCORE_URI, GET_DATA_DIAGNOSIS_URI, GET_DATA_DRIFT_DIAGNOSIS_URI, GET_DATA_SUMMARY_URI, GET_FEATURE_IMPORTANCE_URI, GET_LABELS_URI, GET_MODELS_URI, GET_OBSERVATION_PARAMS_URI, GET_OBSERVATIONS_URI, GET_POLICIES_URI, GET_PROJECT_CONFIG, GET_SYNTHETIC_DATA_TAGS_URI, GET_SYNTHETIC_MODEL_DETAILS_URI, GET_SYNTHETIC_MODEL_PARAMS_URI, GET_SYNTHETIC_MODELS_URI, GET_SYNTHETIC_PROMPT_URI, LIST_DATA_CONNECTORS, MODEL_INFERENCE_SETTINGS_URI, MODEL_INFERENCES_URI, MODEL_PARAMETERS_URI, MODEL_SUMMARY_URI, PROJECT_OVERVIEW_TEXT_URI, RUN_DATA_DRIFT_DIAGNOSIS_URI, RUN_MODEL_ON_DATA_URI, SEARCH_CASE_URI, TABULAR_ML, TEXT_MODEL_INFERENCE_SETTINGS_URI, TRAIN_MODEL_URI, TRAIN_SYNTHETIC_MODEL_URI, UPDATE_ACTIVE_INFERENCE_MODEL_URI, UPDATE_OBSERVATION_URI, UPDATE_POLICY_URI, UPDATE_SYNTHETIC_PROMPT_URI, UPLOAD_DATA_FILE_URI, UPLOAD_DATA_PROJECT_URI, UPLOAD_DATA_URI, UPLOAD_FILE_DATA_CONNECTORS, AVAILABLE_BATCH_SERVERS_URI, CREATE_TRIGGER_URI, DASHBOARD_LOGS_URI, DELETE_TRIGGER_URI, DUPLICATE_MONITORS_URI, EXECUTED_TRIGGER_URI, GENERATE_DASHBOARD_URI, GET_DASHBOARD_SCORE_URI, GET_DASHBOARD_URI, GET_EXECUTED_TRIGGER_INFO, GET_MODEL_TYPES_URI, GET_MODELS_URI, GET_MONITORS_ALERTS, GET_PROJECT_CONFIG, GET_TRIGGERS_URI, LIST_DATA_CONNECTORS, MODEL_PARAMETERS_URI, MODEL_PERFORMANCE_DASHBOARD_URI, UPLOAD_DATA_FILE_INFO_URI, UPLOAD_DATA_FILE_URI, UPLOAD_DATA_URI, UPLOAD_DATA_WITH_CHECK_URI, UPLOAD_FILE_DATA_CONNECTORS, UPLOAD_MODEL_URI, EXPLAINABILITY_SUMMARY, GET_TRIGGERS_DAYS_URI from lexsi_sdk.core.dashboard import DASHBOARD_TYPES, Dashboard -from lexsi_sdk.core.project import Project, build_expression, generate_expression, validate_configuration +from lexsi_sdk.core.project import Project, build_expression, generate_expression from lexsi_sdk.core.synthetic import SyntheticDataTag, SyntheticModel, SyntheticPrompt from lexsi_sdk.core.utils import build_list_data_connector_url from pydantic import BaseModel, ConfigDict @@ -315,7 +315,6 @@ def upload_file_and_return_path(data, data_type, tag=None) -> str: if project_config == "Not Found": if not config: config = { - "project_type": "", "unique_identifier": "", "true_label": "", "pred_label": "", @@ -332,66 +331,22 @@ def upload_file_and_return_path(data, data_type, tag=None) -> str: config, ["unique_identifier", "true_label"] ) - # Validate.value_against_list( - # "project_type", config, ["classification", "regression"] - # ) - - uploaded_path = upload_file_and_return_path(data, "data", tag) - - file_info = self.api_client.post( - UPLOAD_DATA_FILE_INFO_URI, {"path": uploaded_path} - ) - - column_names = file_info.get("details").get("column_names") - - Validate.value_against_list( - "unique_identifier", - config["unique_identifier"], - column_names, - lambda: self.delete_file(uploaded_path), - ) - - if config.get("feature_exclude"): - Validate.value_against_list( - "feature_exclude", - config["feature_exclude"], - column_names, - lambda: self.delete_file(uploaded_path), - ) - - feature_exclude = [ - config["unique_identifier"], - config["true_label"], - *config.get("feature_exclude", []), - ] - - feature_include = [ - feature - for feature in column_names - if feature not in feature_exclude - ] - feature_encodings = config.get("feature_encodings", {}) if feature_encodings: - Validate.value_against_list( - "feature_encodings_feature", - list(feature_encodings.keys()), - column_names, - ) Validate.value_against_list( "feature_encodings_feature", list(feature_encodings.values()), ["labelencode", "countencode", "onehotencode"], ) + custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) available_custom_batch_servers = custom_batch_servers.get("details", []) + custom_batch_servers.get("available_gpu_custom_servers", []) - + if config.get("model_name") and config.get("model_name") in ["TabPFN","TabICL","TabDPT","OrionMSP", "OrionBix","Mitra", "ContextTab"] and not compute_type: valid_list = [ server["instance_name"] for server in available_custom_batch_servers ] - self.delete_file(uploaded_path) raise Exception(f"For Foundational models compute_type is mandatory. select from \n {valid_list}") if tunning_strategy != "inference" and compute_type and "gova" not in compute_type: @@ -404,10 +359,30 @@ def upload_file_and_return_path(data, data_type, tag=None) -> str: ], ) + uploaded_path = upload_file_and_return_path(data, "data", tag) + + file_info = self.api_client.post( + UPLOAD_DATA_FILE_INFO_URI, {"path": uploaded_path} + ) + + column_names = file_info.get("details").get("column_names") + + feature_exclude = [ + config["unique_identifier"], + config["true_label"], + *config.get("feature_exclude", []), + ] + + feature_include = [ + feature + for feature in column_names + if feature not in feature_exclude + ] + payload = { "project_name": self.project_name, - "unique_identifier": config["unique_identifier"], - "true_label": config["true_label"], + "unique_identifier": config.get("unique_identifier"), + "true_label": config.get("true_label"), "pred_label": config.get("pred_label"), "metadata": { "path": uploaded_path, @@ -446,7 +421,11 @@ def upload_file_and_return_path(data, data_type, tag=None) -> str: payload["metadata"]["finetune_mode"] = finetune_mode if tunning_strategy: payload["metadata"]["tunning_strategy"] = tunning_strategy - res = self.api_client.post(UPLOAD_DATA_WITH_CHECK_URI, payload) + try: + res = self.api_client.post(UPLOAD_DATA_WITH_CHECK_URI, payload) + except Exception as e: + self.delete_file(uploaded_path) + raise e if not res["success"]: self.delete_file(uploaded_path) @@ -469,14 +448,18 @@ def upload_file_and_return_path(data, data_type, tag=None) -> str: "type": "data", "project_name": self.project_name, } - res = self.api_client.post(UPLOAD_DATA_URI, payload) + try: + res = self.api_client.post(UPLOAD_DATA_URI, payload) + except Exception as e: + self.delete_file(uploaded_path) + raise e if not res["success"]: self.delete_file(uploaded_path) raise Exception(res.get("details")) return res.get("details") - + def upload_data_dataconnectors( self, data_connector_name: str, @@ -499,7 +482,6 @@ def upload_data_dataconnectors( :param file_path: filepath from the bucket for the data to read :param config: project config { - "project_type": "", "unique_identifier": "", "true_label": "", "pred_label": "", @@ -603,11 +585,8 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: project_config = self.config() if project_config == "Not Found": - if not config.get("project_type"): - config["project_type"] = self.metadata.get("project_type") if not config: config = { - "project_type": "", "unique_identifier": "", "true_label": "", "pred_label": "", @@ -620,55 +599,11 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: ) Validate.check_for_missing_keys( - config, ["project_type", "unique_identifier", "true_label"] - ) - - Validate.value_against_list( - "project_type", config, ["classification", "regression"] - ) - - uploaded_path = upload_file_and_return_path(file_path, "data", tag) - - file_info = self.api_client.post( - UPLOAD_DATA_FILE_INFO_URI, {"path": uploaded_path} - ) - - column_names = file_info.get("details").get("column_names") - - Validate.value_against_list( - "unique_identifier", - config["unique_identifier"], - column_names, - lambda: self.delete_file(uploaded_path), + config, ["unique_identifier", "true_label"] ) - if config.get("feature_exclude"): - Validate.value_against_list( - "feature_exclude", - config["feature_exclude"], - column_names, - lambda: self.delete_file(uploaded_path), - ) - - feature_exclude = [ - config["unique_identifier"], - config["true_label"], - *config.get("feature_exclude", []), - ] - - feature_include = [ - feature - for feature in column_names - if feature not in feature_exclude - ] - feature_encodings = config.get("feature_encodings", {}) if feature_encodings: - Validate.value_against_list( - "feature_encodings_feature", - list(feature_encodings.keys()), - column_names, - ) Validate.value_against_list( "feature_encodings_feature", list(feature_encodings.values()), @@ -677,13 +612,12 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) available_custom_batch_servers = custom_batch_servers.get("details", []) + custom_batch_servers.get("available_gpu_custom_servers", []) - + if config.get("model_name") and config.get("model_name") in ["TabPFN","TabICL","TabDPT","OrionMSP", "OrionBix","Mitra", "ContextTab"] and not compute_type: valid_list = [ server["instance_name"] for server in available_custom_batch_servers ] - self.delete_file(uploaded_path) raise Exception(f"For Foundational models compute_type is mandatory. select from \n {valid_list}") if tunning_strategy != "inference" and compute_type and "gova" not in compute_type: @@ -696,6 +630,26 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: ], ) + uploaded_path = upload_file_and_return_path(file_path, "data", tag) + + file_info = self.api_client.post( + UPLOAD_DATA_FILE_INFO_URI, {"path": uploaded_path} + ) + + column_names = file_info.get("details").get("column_names") + + feature_exclude = [ + config["unique_identifier"], + config["true_label"], + *config.get("feature_exclude", []), + ] + + feature_include = [ + feature + for feature in column_names + if feature not in feature_exclude + ] + payload = { "project_name": self.project_name, "unique_identifier": config["unique_identifier"], @@ -729,13 +683,21 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: if tunning_strategy: payload["metadata"]["tunning_strategy"] = tunning_strategy - res = self.api_client.post(UPLOAD_DATA_WITH_CHECK_URI, payload) + try: + res = self.api_client.post(UPLOAD_DATA_WITH_CHECK_URI, payload) + except Exception as e: + self.delete_file(uploaded_path) + raise e if not res["success"]: self.delete_file(uploaded_path) raise Exception(res.get("details")) - poll_events(self.api_client, self.project_name, res["event_id"]) + try: + poll_events(self.api_client, self.project_name, res["event_id"]) + except Exception as e: + self.delete_file(uploaded_path) + raise e return res.get("details") @@ -750,7 +712,11 @@ def upload_file_and_return_path(file_path, data_type, tag=None) -> str: "type": "data", "project_name": self.project_name, } - res = self.api_client.post(UPLOAD_DATA_URI, payload) + try: + res = self.api_client.post(UPLOAD_DATA_URI, payload) + except Exception as e: + self.delete_file(uploaded_path) + raise e if not res["success"]: self.delete_file(uploaded_path) @@ -775,8 +741,8 @@ def upload_model( model_type: str, model_name: str, model_train: list, - model_test: Optional[list], - pod: Optional[str] = None, + pod: str, + model_test: Optional[list] = None, xai_method: Optional[list] = ["shap"], feature_list: Optional[list] = None, ): @@ -788,8 +754,8 @@ def upload_model( use upload_model_types() method to get all allowed model_types :param model_name: name of the model :param model_train: data tags for model + :param pod: pod to be used for uploading model (required) :param model_test: test tags for model (optional) - :param pod: pod to be used for uploading model (optional) :param xai_method: xai method to be used while uploading model ["shap", "lime"] (optional) :param feature_list: list of features in sequence which are to be passed in the model (optional) """ @@ -825,24 +791,26 @@ def upload_file_and_return_path() -> str: if model_test: Validate.value_against_list("model_test", model_test, tags) - uploaded_path = upload_file_and_return_path() + if not pod: + raise Exception("pod is required to upload a model.") - if pod: - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) + custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) + Validate.value_against_list( + "pod", + pod, + [ + server["instance_name"] + for server in custom_batch_servers.get("details", []) + ], + ) if xai_method: Validate.value_against_list( "explainability_method", xai_method, ["shap", "lime", "ig", "dlb"] ) + uploaded_path = upload_file_and_return_path() + payload = { "project_name": self.project_name, "model_name": model_name, @@ -853,11 +821,9 @@ def upload_file_and_return_path() -> str: "model_test_tags": model_test, "explainability_method": xai_method, "feature_list": feature_list, + "instance_type": pod, } - if pod: - payload["instance_type"] = pod - res = self.api_client.post(UPLOAD_MODEL_URI, payload) if not res.get("success"): @@ -878,7 +844,7 @@ def upload_model_dataconnectors( model_name: str, model_train: list, model_test: Optional[list], - pod: Optional[str] = None, + pod: str, xai_method: Optional[list] = ["shap"], bucket_name: Optional[str] = None, file_path: Optional[str] = None, @@ -892,7 +858,7 @@ def upload_model_dataconnectors( :param model_name: name of the model :param model_train: data tags for model :param model_test: test tags for model (optional) - :param pod: pod to be used for uploading model (optional) + :param pod: pod to be used for uploading model (required) :param xai_method: explainability method to be used while uploading model ["shap", "lime"] (optional) :param bucket_name: if data connector has buckets # Example: s3/gcs buckets :param file_path: filepath from the bucket for the data to read @@ -951,7 +917,7 @@ def upload_file_and_return_path() -> str: uploaded_path = res.get("metadata").get("filepath") return uploaded_path - + model_types = self.api_client.get(GET_MODEL_TYPES_URI) valid_model_architecture = model_types.get("model_architecture").keys() Validate.value_against_list( @@ -967,24 +933,26 @@ def upload_file_and_return_path() -> str: if model_test: Validate.value_against_list("model_test", model_test, tags) - uploaded_path = upload_file_and_return_path() + if not pod: + raise Exception("pod is required to upload a model.") - if pod: - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) + custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) + Validate.value_against_list( + "pod", + pod, + [ + server["instance_name"] + for server in custom_batch_servers.get("details", []) + ], + ) if xai_method: Validate.value_against_list( "explainability_method", xai_method, ["shap", "lime"] ) + uploaded_path = upload_file_and_return_path() + payload = { "project_name": self.project_name, "model_name": model_name, @@ -994,11 +962,9 @@ def upload_file_and_return_path() -> str: "model_data_tags": model_train, "model_test_tags": model_test, "explainability_method": xai_method, + "instance_type": pod, } - if pod: - payload["instance_type"] = pod - res = self.api_client.post(UPLOAD_MODEL_URI, payload) if not res.get("success"): @@ -1043,12 +1009,6 @@ def get_all_dashboards(self, type: str, page: Optional[int] = 1) -> pd.DataFrame :return: Result DataFrame """ - Validate.value_against_list( - "type", - type, - DASHBOARD_TYPES, - ) - res = self.api_client.get( f"{DASHBOARD_LOGS_URI}?project_name={self.project_name}&type={type}&page={page}", ) @@ -1088,11 +1048,6 @@ def get_dashboard_metadata(self, type: str, dashboard_id: str) -> Dashboard: - print_config(): Pretty-print the dashboard configuration :rtype: Dashboard """ - Validate.value_against_list( - "type", - type, - DASHBOARD_TYPES, - ) res = self.api_client.get( f"{GET_DASHBOARD_URI}?type={type}&project_name={self.project_name}&dashboard_id={dashboard_id}" @@ -1118,11 +1073,6 @@ def get_dashboard(self, type: str, dashboard_id: str) -> Dashboard: - print_config(): Pretty-print the dashboard configuration :rtype: Dashboard """ - Validate.value_against_list( - "type", - type, - DASHBOARD_TYPES, - ) res = self.api_client.get( f"{GET_DASHBOARD_URI}?type={type}&project_name={self.project_name}&dashboard_id={dashboard_id}" @@ -1249,15 +1199,6 @@ def create_monitor(self, payload: dict) -> str: """ payload["project_name"] = self.project_name - required_payload_keys = [ - "trigger_type", - "priority", - "mail_list", - "frequency", - "trigger_name", - ] - - Validate.check_for_missing_keys(payload, required_payload_keys) if payload.get("pod", None): payload["instance_type"] = payload["pod"] payload = { @@ -1699,7 +1640,7 @@ def data_observations(self, tag: str) -> pd.DataFrame: raise Exception("Data summary not available, please upload data first.") if tag not in valid_tags: - raise Exception(f"Not a vaild tag. Pick a valid tag from {valid_tags}") + raise Exception(f"Not a valid tag. Pick a valid tag from {valid_tags}") data = { "Total Data Volume": res["data"]["overview"]["Total Data Volumn"], @@ -1756,21 +1697,6 @@ def data_drift_diagnosis( """ if baseline_tags and current_tags: - if pod not in [ - "small", - "xsmall", - "2xsmall", - "3xsmall", - "medium", - "xmedium", - "2xmedium", - "3xmedium", - "large", - "xlarge", - "2xlarge", - "3xlarge", - ]: - return "pod is not valid. Valid types are small, xsmall, 2xsmall, 3xsmall, medium, xmedium, 2xmedium, 3xmedium, large, xlarge, 2xlarge, 3xlarge" payload = { "project_name": self.project_name, @@ -1781,10 +1707,10 @@ def data_drift_diagnosis( res = self.api_client.post(RUN_DATA_DRIFT_DIAGNOSIS_URI, payload) if not res["success"]: - if res.get("details").get("reason"): - raise Exception(res.get("details").get("reason")) - else: - raise Exception(res.get("message")) + details = res.get("details") + if isinstance(details, dict) and details.get("reason"): + raise Exception(details.get("reason")) + raise Exception(details or res.get("message")) poll_events(self.api_client, self.project_name, res["task_id"]) res = self.api_client.post( @@ -1872,39 +1798,6 @@ def get_data_drift_dashboard( payload["project_name"] = self.project_name - # validate required fields - Validate.check_for_missing_keys(payload, DATA_DRIFT_DASHBOARD_REQUIRED_FIELDS) - - # validate tags and labels - tags_info = self.available_tags() - all_tags = tags_info["alltags"] - - Validate.value_against_list("base_line_tag", payload["base_line_tag"], all_tags) - Validate.value_against_list("current_tag", payload["current_tag"], all_tags) - - Validate.validate_date_feature_val(payload, tags_info["alldatetimefeatures"]) - - if payload.get("features_to_use"): - Validate.value_against_list( - "features_to_use", - payload.get("features_to_use", []), - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "stat_test_name", payload["stat_test_name"], DATA_DRIFT_STAT_TESTS - ) - - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) - if payload.get("pod", None): payload["instance_type"] = payload["pod"] if pod: @@ -1987,46 +1880,6 @@ def get_target_drift_dashboard( payload["project_name"] = self.project_name - # validate required fields - Validate.check_for_missing_keys(payload, TARGET_DRIFT_DASHBOARD_REQUIRED_FIELDS) - - # validate tags and labels - tags_info = self.available_tags() - all_tags = tags_info["alltags"] - - Validate.value_against_list("base_line_tag", payload["base_line_tag"], all_tags) - Validate.value_against_list("current_tag", payload["current_tag"], all_tags) - - Validate.validate_date_feature_val(payload, tags_info["alldatetimefeatures"]) - - Validate.value_against_list("model_type", payload["model_type"], MODEL_TYPES) - - Validate.value_against_list( - "stat_test_name", payload["stat_test_name"], TARGET_DRIFT_STAT_TESTS - ) - - Validate.value_against_list( - "baseline_true_label", - [payload["baseline_true_label"]], - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "current_true_label", - [payload["current_true_label"]], - tags_info["alluniquefeatures"], - ) - - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) - if payload.get("pod", None): payload["instance_type"] = payload["pod"] if pod: @@ -2080,50 +1933,6 @@ def get_bias_monitoring_dashboard( payload["project_name"] = self.project_name - # validate required fields - Validate.check_for_missing_keys( - payload, BIAS_MONITORING_DASHBOARD_REQUIRED_FIELDS - ) - - # validate tags and labels - tags_info = self.available_tags() - all_tags = tags_info["alltags"] - - Validate.value_against_list("base_line_tag", payload["base_line_tag"], all_tags) - - Validate.validate_date_feature_val(payload, tags_info["alldatetimefeatures"]) - - Validate.value_against_list("model_type", payload["model_type"], MODEL_TYPES) - - Validate.value_against_list( - "baseline_true_label", - [payload["baseline_true_label"]], - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "baseline_pred_label", - [payload["baseline_pred_label"]], - tags_info["alluniquefeatures"], - ) - - if payload.get("features_to_use"): - Validate.value_against_list( - "features_to_use", - payload.get("features_to_use", []), - tags_info["alluniquefeatures"], - ) - - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) - if payload.get("pod", None): payload["instance_type"] = payload["pod"] if pod: @@ -2179,61 +1988,6 @@ def get_model_performance_dashboard( payload["project_name"] = self.project_name - tags_info = self.available_tags() - all_tags = tags_info["alltags"] - - if self.metadata.get("modality") == "image": - Validate.check_for_missing_keys(payload, ["base_line_tag", "current_tag"]) - - Validate.value_against_list("base_line_tag", payload["base_line_tag"], all_tags) - Validate.value_against_list("current_tag", payload["current_tag"], all_tags) - - if self.metadata.get("modality") == "tabular": - Validate.check_for_missing_keys( - payload, MODEL_PERF_DASHBOARD_REQUIRED_FIELDS - ) - Validate.validate_date_feature_val( - payload, tags_info["alldatetimefeatures"] - ) - - Validate.value_against_list( - "model_type", payload["model_type"], MODEL_TYPES - ) - - Validate.value_against_list( - "baseline_true_label", - [payload["baseline_true_label"]], - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "baseline_pred_label", - [payload["baseline_pred_label"]], - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "current_true_label", - [payload["current_true_label"]], - tags_info["alluniquefeatures"], - ) - - Validate.value_against_list( - "current_pred_label", - [payload["current_pred_label"]], - tags_info["alluniquefeatures"], - ) - - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - Validate.value_against_list( - "pod", - pod, - [ - server["instance_name"] - for server in custom_batch_servers.get("details", []) - ], - ) - if payload.get("pod", None): payload["instance_type"] = payload["pod"] if pod: @@ -2260,11 +2014,6 @@ def get_dashboard_log_data(self, type: str): :return: DataFrame :rtype: pd.DataFrame """ - Validate.value_against_list( - "type", - type, - DASHBOARD_TYPES, - ) self.api_client.refresh_bearer_token() auth_token = self.api_client.get_auth_token() query_params = ( @@ -2275,11 +2024,13 @@ def get_dashboard_log_data(self, type: str): res = self.api_client.base_request("GET", uri) if res.status_code != 200: - raise Exception( - res.get( - "details", f"Error Downloading Dasboard Logs, {res.status_code}" + try: + details = res.json().get( + "details", f"Error Downloading Dashboard Logs, {res.status_code}" ) - ) + except Exception: + details = f"Error Downloading Dashboard Logs, {res.status_code}" + raise Exception(details) try: df = pd.read_csv(io.StringIO(res.text)) @@ -2337,11 +2088,6 @@ def model_inference( models = self.models() - available_models = models["model_name"].to_list() - - if model_name: - Validate.value_against_list("model_name", model_name, available_models) - model = ( model_name or models.loc[models["status"] == "active"]["model_name"].values[0] @@ -2575,158 +2321,11 @@ def train_model( if project_config == "Not Found": raise Exception("Upload files first") - available_models = self.available_models() - - Validate.value_against_list("model_type", model_type, available_models) - all_unique_features = [ *project_config["metadata"]["feature_exclude"], *project_config["metadata"]["feature_include"], ] - if tunning_strategy != "inference" and compute_type and "gova" not in compute_type: - custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI) - available_custom_batch_servers = custom_batch_servers.get("details", []) + custom_batch_servers.get("available_gpu_custom_servers", []) - Validate.value_against_list( - "pod", - compute_type, - [ - server["instance_name"] - for server in available_custom_batch_servers - ], - ) - - if data_config: - if data_config.get("feature_exclude"): - Validate.value_against_list( - "feature_exclude", - data_config["feature_exclude"], - all_unique_features, - ) - - if data_config.get("tags"): - available_tags = self.tags() - Validate.value_against_list("tags", data_config["tags"], available_tags) - - if data_config.get("test_tags"): - available_tags = self.tags() - Validate.value_against_list( - "test_tags", data_config["test_tags"], available_tags - ) - - if data_config.get("feature_encodings"): - Validate.value_against_list( - "feature_encodings_feature", - list(data_config["feature_encodings"].keys()), - list(project_config["metadata"]["feature_encodings"].keys()), - ) - Validate.value_against_list( - "feature_encodings_feature", - list(data_config["feature_encodings"].values()), - ["labelencode", "countencode", "onehotencode"], - ) - - if data_config.get("sample_percentage"): - if ( - data_config["sample_percentage"] < 0 - or data_config["sample_percentage"] > 1 - ): - raise Exception( - "Data sample percentage is invalid, select between 0 and 1" - ) - - if data_config.get("explainability_sample_percentage"): - if ( - data_config["explainability_sample_percentage"] < 0 - or data_config["explainability_sample_percentage"] > 1 - ): - raise Exception( - "Explainability sample percentage is invalid, select between 0 and 1" - ) - - if data_config.get("lime_explainability_iterations"): - if ( - data_config["lime_explainability_iterations"] < 1 - or data_config["lime_explainability_iterations"] > 10000 - ): - raise Exception( - "Lime explainability iterations is invalid, select between 1 and 10000" - ) - - if data_config.get("xai_method"): - Validate.value_against_list( - "xai_method", - data_config["xai_method"], - ["shap", "lime"], - ) - - if model_config: - model_params = self.api_client.get(MODEL_PARAMETERS_URI) - model_name = f"{model_type}_{project_config['project_type']}".lower() - model_parameters = model_params.get(model_name) - - if model_parameters: - - def validate_params(param_group, config_group): - """Validate config values against model parameter constraints. - Checks select options and numeric min/max bounds, raising exceptions on invalid values. - - :param param_group: Parameter definition dict (select/input types with constraints). - :param config_group: User-supplied config dict to validate against `param_group`. - :raises Exception: If any value violates the declared constraints. - """ - if config_group: - for param_name, param_value in config_group.items(): - model_param = param_group.get(param_name) - if not model_param: - # raise Exception( - # f"Invalid model config for {model_type} \n {json.dumps(model_parameters)}" - # ) - continue - - param_type = model_param["type"] - - if param_type == "select": - Validate.value_against_list( - param_name, param_value, model_param["value"] - ) - elif param_type == "input": - if param_value > model_param["max"]: - raise Exception( - f"{param_name} value cannot be greater than {model_param['max']}" - ) - if param_value < model_param["min"]: - raise Exception( - f"{param_name} value cannot be less than {model_param['min']}" - ) - - if model_type in ["TabPFN","TabICL","TabDPT","OrionMSP", "OrionBix","Mitra", "ContextTab"]: - validate_params( - model_parameters.get("model_params", {}), model_config - ) - validate_params( - model_parameters.get("tunning_params", {}), tunning_config - ) - validate_params( - model_parameters.get("processor_params", {}), processor_config - ) - validate_params( - model_parameters.get("peft_params", {}), peft_config - ) - else: - validate_params(model_parameters, model_config) - if finetune_mode: - Validate.value_against_list( - "finetune_mode", - finetune_mode, - ["meta-learning", "sft"], - ) - if tunning_strategy: - Validate.value_against_list( - "tunning_strategy", - tunning_strategy, - ["base-ft", "inference", "peft", "finetune"], - ) data_conf = data_config or {} feature_exclude = [ @@ -2748,10 +2347,10 @@ def validate_params(param_group, config_group): ) explainability_method = ( - data_conf.get("explainability_method") + data_conf.get("explainability_method") or data_conf.get("xai_method") - or project_config.get("metadata", {}).get("xai_method") - or project_config.get("metadata", {}).get("explainability_method") + or project_config.get("metadata", {}).get("xai_method") + or project_config.get("metadata", {}).get("explainability_method") ) tags = data_conf.get("tags") or project_config["metadata"]["tags"] @@ -2969,9 +2568,6 @@ def delete_cases( :param tag: tag of case, defaults to None :return: response """ - if tag: - all_tags = self.all_tags() - Validate.value_against_list("tag", tag, all_tags) paylod = { "project_name": self.project_name, @@ -3058,19 +2654,11 @@ def model_inference_settings( server_config = inference_compute.get("custom_server_config", {}) server_config["start"] = normalize_time(server_config.get("start")) server_config["stop"] = normalize_time(server_config.get("stop")) - if server_config["start"] and not server_config["stop"]: - raise ValueError("If start is provided, stop cannot be None.") - - if server_config["stop"] and not server_config["start"]: - raise ValueError("If stop is provided, start cannot be None.") if server_config["start"] and server_config["stop"]: start_dt = datetime.fromisoformat(server_config["start"]) stop_dt = datetime.fromisoformat(server_config["stop"]) - if stop_dt - start_dt < timedelta(minutes=15): - raise ValueError("Stop time must be at least 15 minutes greater than start time.") - if not server_config.get("op_hours") and server_config.get("auto_start"): server_config["op_hours"] = True @@ -3254,29 +2842,10 @@ def create_observation( :param linked_features: linked features of observation :return: response """ - observation_params = self.api_client.get( - f"{GET_OBSERVATION_PARAMS_URI}?project_name={self.project_name}" - ) - Validate.string("expression", expression) - Validate.string("statement", statement) - - Validate.value_against_list( - "linked_feature", - linked_features, - list(observation_params["details"]["features"].keys()), - ) configuration, expression = build_expression(expression) - validate_configuration( - configuration, - observation_params["details"], - self.project_name, - self.api_client, - True, - ) - payload = { "project_name": self.project_name, "observation_name": observation_name, @@ -3321,8 +2890,6 @@ def update_observation( :param linked_features: new linked features for observation, defaults to None :return: response """ - if not status and not expression and not statement and not linked_features: - raise Exception("update parameters for observation not passed") payload = { "project_name": self.project_name, @@ -3331,36 +2898,19 @@ def update_observation( "update_keys": {}, } - observation_params = self.api_client.get( - f"{GET_OBSERVATION_PARAMS_URI}?project_name={self.project_name}" - ) - if expression: Validate.string("expression", expression) configuration, expression = build_expression(expression) - validate_configuration( - configuration, - observation_params["details"], - self.project_name, - self.api_client, - ) payload["update_keys"]["configuration"] = configuration payload["update_keys"]["metadata"] = {"expression": expression} if linked_features: - Validate.value_against_list( - "linked_feature", - linked_features, - list(observation_params["details"]["features"].keys()), - ) payload["update_keys"]["linked_features"] = linked_features if statement: - Validate.string("statement", statement) payload["update_keys"]["statement"] = [statement] if status: - Validate.value_against_list("status", status, ["active", "inactive"]) payload["update_keys"]["status"] = status res = self.api_client.post(UPDATE_OBSERVATION_URI, payload) @@ -3573,21 +3123,6 @@ def create_policy( """ configuration, expression = build_expression(expression) - policy_params = self.api_client.get( - f"{GET_POLICY_PARAMS_URI}?project_name={self.project_name}" - ) - - validate_configuration( - configuration, policy_params["details"], self.project_name, self.api_client - ) - - Validate.value_against_list( - "decision", decision, list(policy_params["details"]["decision"].values())[0] - ) - - if decision == "input": - Validate.string("Decision input", input) - payload = { "project_name": self.project_name, "policy_name": policy_name, @@ -3640,8 +3175,6 @@ def update_policy( :param priority: Priority of the policy. Lower number indicates higher priority. Defaults to 5 :return: response """ - if not status and not expression and not statement and not decision: - raise Exception("update parameters for policy not passed") payload = { "project_name": self.project_name, @@ -3650,38 +3183,19 @@ def update_policy( "update_keys": {}, } - policy_params = self.api_client.get( - f"{GET_POLICY_PARAMS_URI}?project_name={self.project_name}" - ) - if expression: Validate.string("expression", expression) configuration, expression = build_expression(expression) - validate_configuration( - configuration, - policy_params["details"], - self.project_name, - self.api_client, - ) payload["update_keys"]["configuration"] = configuration payload["update_keys"]["metadata"] = {"expression": expression} if statement: - Validate.string("statement", statement) payload["update_keys"]["statement"] = [statement] if status: - Validate.value_against_list("status", status, ["active", "inactive"]) payload["update_keys"]["status"] = status if decision: - Validate.value_against_list( - "decision", - decision, - list(policy_params["details"]["decision"].values())[0], - ) - if decision == "input": - Validate.string("Decision input", input) payload["update_keys"]["decision"] = ( input if decision == "input" else decision ) @@ -3787,42 +3301,20 @@ def train_synthetic_model( servers = list( map(lambda instance: instance["instance_name"], available_servers) ) - Validate.value_against_list("instance_type", pod, servers) - - all_models_param = self.get_synthetic_model_params() - - try: - model_params = all_models_param[model_name] - except KeyError as e: - availabel_models = list(all_models_param.keys()) - Validate.value_against_list("model", [model_name], availabel_models) # validate and prepare data config data_config["model_name"] = model_name - available_tags = self.tags() - tags = data_config.get("tags", available_tags) - - Validate.value_against_list("tag", tags, available_tags) + tags = data_config.get("tags", []) feature_exclude = data_config.get( "feature_exclude", project_config["feature_exclude"] ) - Validate.value_against_list( - "feature_exclude", feature_exclude, project_config["avaialble_options"] - ) - feature_include = data_config.get( "feature_include", project_config["feature_include"] ) - Validate.value_against_list( - "feature_include", - feature_include, - project_config["avaialble_options"], - ) - drop_duplicate_uid = data_config.get( "drop_duplicate_uid", project_config["drop_duplicate_uid"] ) @@ -3830,28 +3322,6 @@ def train_synthetic_model( SYNTHETIC_MODELS_DEFAULT_HYPER_PARAMS[model_name].update(hyper_params) hyper_params = SYNTHETIC_MODELS_DEFAULT_HYPER_PARAMS[model_name] - # validate model hyper parameters - for key, value in hyper_params.items(): - model_param = model_params.get(key, None) - - if model_param: - if model_param["type"] == "input": - if model_param["value"] == "int": - if not isinstance(value, int): - raise Exception(f"{key} value should be integer") - elif model_param["value"] == "float": - if not isinstance(value, float): - raise Exception(f"{key} value should be float") - - if value < model_param["min"] or value > model_param["max"]: - raise Exception( - f"{key} value should be between {model_param['min']} and {model_param['max']}" - ) - elif model_param["type"] == "select": - Validate.value_against_list( - "value", [value], model_param["value"] - ) - print(f"Using data config: {json.dumps(data_config, indent=4)}") print(f"Using hyper params: {json.dumps(hyper_params, indent=4)}") @@ -3886,13 +3356,6 @@ def remove_synthetic_model(self, model_name: str) -> str: :raises Exception: _description_ :return: response message """ - models_df = self.synthetic_models() - valid_models = models_df["model_name"].tolist() - - if model_name not in valid_models: - raise ValueError( - f"{model_name} is not valid. Pick a valid value from {valid_models}" - ) payload = {"project_name": self.project_name, "model_name": model_name} @@ -3928,13 +3391,6 @@ def synthetic_model(self, model_name: str) -> SyntheticModel: :raises Exception: _description_ :return: _description_ """ - models_df = self.synthetic_models() - valid_models = models_df["model_name"].tolist() - - if model_name not in valid_models: - raise ValueError( - f"{model_name} is not valid. Pick a valid value from {valid_models}" - ) url = f"{GET_SYNTHETIC_MODEL_DETAILS_URI}?project_name={self.project_name}&model_name={model_name}" @@ -4024,13 +3480,6 @@ def synthetic_tag_datapoints(self, tag: str) -> pd.DataFrame: :raises Exception: _description_ :return: datapoints """ - all_tags = self.all_tags() - - Validate.value_against_list( - "tag", - tag, - all_tags, - ) res = self.api_client.base_request( "GET", @@ -4048,13 +3497,6 @@ def remove_synthetic_tag(self, tag: str) -> str: :raises Exception: _description_ :return: response messsage """ - all_tags = self.all_tags() - - Validate.value_against_list( - "tag", - tag, - all_tags, - ) payload = { "project_name": self.project_name, @@ -4098,11 +3540,6 @@ def create_synthetic_prompt(self, name: str, expression: str) -> str: configuration, expression = build_expression(expression) - prompt_params = self.get_observation_params() - validate_configuration( - configuration, prompt_params, self.project_name, self.api_client - ) - payload = { "prompt_name": name, "project_name": self.project_name,