Rename EarlyStoppingStrategy → ArmStoppingStrategy with arm-level decisions #5134
Open
lena-kashtelyan wants to merge 1 commit into facebook:main
Summary:

# Context

The current `EarlyStoppingStrategy` was built for single-arm `Trial`s only: it returns `dict[int, str | None]` (trial index → reason) and explicitly rejects `BatchTrial`s. This diff extends it to support `BatchTrial`s by making stopping decisions at the arm level. But we are going to need batch-level stopping.

# Key design decision: `ArmStoppingStrategy` or `TrialStoppingStrategy`? [in Ax TLs sync, we decided we liked `ArmStoppingStrategy`]

One choice we're going to need to make is whether we want arm-level and trial-level stopping strategies to be the same or separate. The use cases for arm stopping will have to do with safety, constraint violations, etc. The use cases for stopping trials will have more to do with normal orchestration, e.g. we need to stop a trial in order to run a new one.

# Next step: add `Runner.stop_arm` and `Runner.stop_trial`

I think that the two will be separate because they'll often entail different logic at their respective backends. How we choose to do these may impact how we choose to do the stopping strategy, too.

# Another likely next step: make GS and ESS decisions jointly

I know not everyone loves this idea; let's discuss.

Related discussion about a use case here: https://docs.google.com/document/d/19K3kBXX9c5WUIUu_t_gC9KZkeAMo_EffSymKjaPTSyI/edit?tab=t.0, re-pasting for convenience:

- Lena [on an experiment that does not yet do any generation, only stopping]: My preferred design would be that we use a GNode to fit the model in GS, then the ESS uses this, but at the moment ESS is applied first, then GS (in Orchestrator and thus Axolotl). I think this current order is right as long as we don't merge ESS and GS (which I'd like to do eventually, including for reasons like this). So what we can do for now is just put a GNode within an ESS, then worry about the rest later. And we can just have an empty GS for now to keep the Orchestrator happy.
- Sam: Yeah we could do that. Calling ESS before GS in the orchestrator fundamentally seems like the wrong order as we move toward model-based early stopping (e.g., in the conductor case). I wonder if we should merge GS and ESS sooner rather than later to resolve that, rather than create some tech debt by having ESS fit its own adapter.

Eventually we decided that we would like to (later this year) merge ESS and GS, such that the actual cycle is `gen` --> cache results --> compare ROI on new vs. running trial --> make decision on stopping and running jointly. We thought some kind of `DecisionNode` might do this:

{F1987649989}

----

# Claude stuff below

Key changes:
- Rename `BaseEarlyStoppingStrategy` → `BaseArmStoppingStrategy` (with backward-compat alias)
- Rename `ModelBasedEarlyStoppingStrategy` → `ModelBasedArmStoppingStrategy` (with alias)
- Change return type of `should_stop_trials_early` / `_should_stop_trials_early` from `dict[int, str | None]` → `dict[int, dict[str, str | None]]` (trial_index → {arm_name → reason})
- Remove `BatchTrial` rejection check in `is_eligible_any`
- Add `_wrap_trial_results_with_arms()` helper to convert trial-level decisions to arm-level format
- Update all subclasses (percentile, threshold, logical, multi_objective, quickbo) to use the new return type
- Update orchestrator to check if all arms are stopped before stopping a trial (raises `NotImplementedError` for partial arm stopping)
- Update `ax_client`, `api/client`, and `internal_client` to extract reasons from arm-level dict
- Update all tests

Differential Revision: D97304068
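To make the return-type change in the key changes list concrete, here is a minimal sketch of the two shapes and of the "all arms stopped" check. It is not the code from this diff: the function names `wrap_trial_results_with_arms_sketch` and `trials_to_stop_sketch`, the `arm_names_by_trial` argument, and the arm names are hypothetical, and the signature of the real `_wrap_trial_results_with_arms()` helper is not shown in the description. The sketch also assumes that a key being present in the dict means "stop", with the value being an optional reason.

```python
from typing import Optional

# Old shape: trial_index -> reason
TrialDecisions = dict[int, Optional[str]]
# New shape: trial_index -> {arm_name -> reason}
ArmDecisions = dict[int, dict[str, Optional[str]]]


def wrap_trial_results_with_arms_sketch(
    trial_decisions: TrialDecisions,
    arm_names_by_trial: dict[int, list[str]],  # hypothetical lookup, not an Ax API
) -> ArmDecisions:
    """Copy each trial-level reason onto every arm of that trial."""
    return {
        trial_index: {name: reason for name in arm_names_by_trial[trial_index]}
        for trial_index, reason in trial_decisions.items()
    }


def trials_to_stop_sketch(
    arm_decisions: ArmDecisions,
    arm_names_by_trial: dict[int, list[str]],
) -> TrialDecisions:
    """Collapse arm-level decisions back to trials, rejecting partial stops."""
    result: TrialDecisions = {}
    for trial_index, per_arm in arm_decisions.items():
        all_arms = set(arm_names_by_trial[trial_index])
        flagged = set(per_arm) & all_arms
        if not flagged:
            continue  # nothing to stop for this trial
        if flagged != all_arms:
            # Mirrors the behavior described above: partial arm stopping
            # is not supported yet.
            raise NotImplementedError(
                f"Partial arm stopping for trial {trial_index}: {sorted(flagged)}"
            )
        # Use the first non-empty reason as the trial-level reason.
        result[trial_index] = next(
            (reason for reason in per_arm.values() if reason is not None), None
        )
    return result
```

With `arm_names_by_trial = {0: ["0_0"], 1: ["1_0", "1_1"]}`, wrapping `{0: "worse than median"}` gives `{0: {"0_0": "worse than median"}}`, while passing `{1: {"1_0": "unsafe"}}` to the second function raises `NotImplementedError` because arm `1_1` is not flagged.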
@lena-kashtelyan has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97304068.
Codecov Report
❌ Patch coverage is
@@ Coverage Diff @@
## main #5134 +/- ##
==========================================
- Coverage 96.40% 96.40% -0.01%
==========================================
Files 613 613
Lines 68171 68216 +45
==========================================
+ Hits 65721 65764 +43
- Misses 2450 2452 +2