Skip to content

Fix custom training scripts that define train()#23

Open
Kaladin-Stormborn wants to merge 1 commit into
kubeflow:mainfrom
Kaladin-Stormborn:fix-custom-training-uncalled-train
Open

Fix custom training scripts that define train()#23
Kaladin-Stormborn wants to merge 1 commit into
kubeflow:mainfrom
Kaladin-Stormborn:fix-custom-training-uncalled-train

Conversation

@Kaladin-Stormborn
Copy link
Copy Markdown

PR Draft: #19

Title

Fix custom training scripts that define an uncalled train()

Summary

This fixes run_custom_training script wrapping when a user provides a script that already defines a top-level train() function but does not call it.

Previously _make_train_func() always wrapped the submitted script inside another generated def train():. For scripts like:

def train():
    print("Training started")

the user-defined train() became a nested function and was never invoked when the generated wrapper ran.

The patch detects a top-level def train() with no module-level train() call and appends a call inside the generated wrapper. Scripts that already call train() are left alone to avoid double execution. Matching func_args are forwarded to a user-defined train(...) signature.

Tests

uv sync --extra dev
.venv/bin/ruff check kubeflow_mcp/trainer/api/training.py kubeflow_mcp/trainer/api/sdk_contracts_test.py
.venv/bin/python -m pytest kubeflow_mcp/trainer/api/sdk_contracts_test.py -q
.venv/bin/python -m pytest kubeflow_mcp/trainer/api/sdk_contracts_test.py tests/unit/trainer/test_architecture.py -q
.venv/bin/python -m pytest -q

Results:

  • Ruff: passed.
  • SDK contract test: 74 passed.
  • Full local suite: 139 passed, 3 upstream dependency deprecation warnings.

Current diff size:

  • Implementation: 59 insertions, 1 deletion.
  • Tests: 84 insertions.
  • Total: 143 insertions, 1 deletion.

Changed Files

  • kubeflow_mcp/trainer/api/training.py
  • kubeflow_mcp/trainer/api/sdk_contracts_test.py

@google-oss-prow
Copy link
Copy Markdown

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign kramaranya for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Details Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@abhijeet-dhumal
Copy link
Copy Markdown
Member

/ok-to-test

@abhijeet-dhumal
Copy link
Copy Markdown
Member

Hey @Kaladin-Stormborn, can you sign off your commit here ⬆️
git commit -s

Signed-off-by: Kaladin-Stormborn <Kaladin-Stormborn@users.noreply.github.com>
@abhijeet-dhumal
Copy link
Copy Markdown
Member

Thanks @Kaladin-Stormborn for this fix 🚀
Just adding some nit picks

return ns[func_name]


def _uncalled_train_call(script: str, func_args: dict[str, Any] | None = None) -> str | None:
Copy link
Copy Markdown
Member

@abhijeet-dhumal abhijeet-dhumal May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): thinking can we consider treating top-level async def train() like def train() for the append path (ast.AsyncFunctionDef). If we append a call, it likely needs await train(...) inside the generated sync wrapper (or document that only sync def train() is supported for now).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants