From 5a6358f4b61f64334113186f35cdd0996884412e Mon Sep 17 00:00:00 2001 From: Khush Date: Sat, 20 Jun 2026 11:05:41 -0400 Subject: [PATCH 1/6] feat(Aggregation): Add ExcessMTLWeighting --- CHANGELOG.md | 1 + NOTICES | 28 +++ docs/source/docs/aggregation/excess_mtl.rst | 7 + docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_excess_mtl.py | 184 +++++++++++++++++ tests/unit/aggregation/test_excess_mtl.py | 209 ++++++++++++++++++++ 7 files changed, 432 insertions(+) create mode 100644 docs/source/docs/aggregation/excess_mtl.rst create mode 100644 src/torchjd/aggregation/_excess_mtl.py create mode 100644 tests/unit/aggregation/test_excess_mtl.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 52107b54..287b3b14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ changelog does not include internal changes that do not affect the user. Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a `Scalarizer` that decomposes the values into a component along a preference direction and a penalized perpendicular component. +- Added `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin. ## [0.15.0] - 2026-06-15 diff --git a/NOTICES b/NOTICES index 098695c3..0f1ac03a 100644 --- a/NOTICES +++ b/NOTICES @@ -143,6 +143,34 @@ SOFTWARE. 
------------------------------------------------------------------------------- +Project: ExcessMTL +Source: https://github.com/uiuctml/ExcessMTL/blob/main/LibMTL/LibMTL/weighting/ExcessMTL.py +Used in: src/torchjd/aggregation/_excess_mtl.py + +MIT License + +Copyright (c) 2024 UIUC TML Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +------------------------------------------------------------------------------- + Project: SDMGrad Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py Used in: src/torchjd/aggregation/_sdmgrad.py diff --git a/docs/source/docs/aggregation/excess_mtl.rst b/docs/source/docs/aggregation/excess_mtl.rst new file mode 100644 index 00000000..e79fcc88 --- /dev/null +++ b/docs/source/docs/aggregation/excess_mtl.rst @@ -0,0 +1,7 @@ +:hide-toc: + +ExcessMTL +========= + +.. 
autoclass:: torchjd.aggregation.ExcessMTLWeighting + :members: __call__, reset diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 867e8d8c..b77332db 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -30,6 +30,7 @@ Abstract base classes constant.rst cr_mogm.rst dualproj.rst + excess_mtl.rst fairgrad.rst graddrop.rst gradvac.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 5285e6bf..61cae873 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -45,6 +45,7 @@ from ._constant import Constant, ConstantWeighting from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting +from ._excess_mtl import ExcessMTLWeighting from ._fairgrad import FairGrad, FairGradWeighting from ._graddrop import GradDrop from ._gradvac import GradVac, GradVacWeighting @@ -74,6 +75,7 @@ "CRMOGMWeighting", "DualProj", "DualProjWeighting", + "ExcessMTLWeighting", "FairGrad", "FairGradWeighting", "GradDrop", diff --git a/src/torchjd/aggregation/_excess_mtl.py b/src/torchjd/aggregation/_excess_mtl.py new file mode 100644 index 00000000..266fe25c --- /dev/null +++ b/src/torchjd/aggregation/_excess_mtl.py @@ -0,0 +1,184 @@ +# Partly adapted from https://github.com/uiuctml/ExcessMTL — MIT License, Copyright (c) 2024 UIUC TML Lab. +# See NOTICES for the full license text. 
+from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd._mixins import Stateful +from torchjd.aggregation._mixins import _NonDifferentiable +from torchjd.linalg import Matrix + +from ._weighting_bases import _MatrixWeighting + + +class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.Stateful` + :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Robust + Multi-Task Learning with Excess Risks + <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024). + + At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven + by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a + second-order Taylor expansion (Equations 6-7): + + :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update. + Must be positive. + :param n_warmup_steps: Number of forward calls during which weights stay uniform + (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess + risk is set to the average excess risk observed during warmup. When ``0`` (default), the + first call's excess risk is used as the baseline and weights are updated immediately + (matching the official implementation). + + .. warning:: + The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients + across **all** calls, where :math:`n` is the total number of model parameters. For large + models this can be a significant memory cost. Call :meth:`reset` between experiments. + + .. note:: + The weight update is adapted from the `official implementation + `_ and `LibMTL + `_. + The warmup strategy follows Appendix C.1 of the paper, which recommends collecting + gradient statistics for several epochs before beginning weight updates; set + ``n_warmup_steps`` accordingly (e.g. ``3 * len(dataloader)``). + + .. admonition:: Example + + .. 
testcode:: + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd import autojac + from torchjd.aggregation import ExcessMTLWeighting, WeightedAggregator + from torchjd.autojac import jac_to_grad + + inputs = torch.randn(8, 5) + targets = torch.randn(8, 2) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 2)) + optimizer = SGD(model.parameters()) + criterion = MSELoss() + aggregator = WeightedAggregator(ExcessMTLWeighting()) + + outputs = model(inputs) + losses = [criterion(outputs[:, i], targets[:, i]) for i in range(2)] + autojac.backward(losses) + jac_to_grad(model.parameters(), aggregator) + optimizer.step() + optimizer.zero_grad() + """ + + def __init__( + self, + robust_step_size: float = 1.0, + n_warmup_steps: int = 0, + ) -> None: + super().__init__() + self.robust_step_size = robust_step_size + self.n_warmup_steps = n_warmup_steps + self.register_buffer("_weights", None) + self.register_buffer("_grad_sum", None) + self.register_buffer("_initial_w", None) + self.register_buffer("_warmup_w_sum", None) + self.register_buffer("_n_steps", torch.zeros((), dtype=torch.long)) + self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None + + @property + def robust_step_size(self) -> float: + return self._robust_step_size + + @robust_step_size.setter + def robust_step_size(self, value: float) -> None: + if value <= 0.0: + raise ValueError( + f"Attribute `robust_step_size` must be positive. Found robust_step_size={value!r}." + ) + self._robust_step_size = value + + @property + def n_warmup_steps(self) -> int: + return self._n_warmup_steps + + @n_warmup_steps.setter + def n_warmup_steps(self, value: int) -> None: + if value < 0: + raise ValueError( + f"Attribute `n_warmup_steps` must be non-negative. Found n_warmup_steps={value!r}." 
+ ) + self._n_warmup_steps = value + + def reset(self) -> None: + """Clears all state so the next forward starts from uniform weights and re-enters + warmup.""" + + self._weights = None + self._grad_sum = None + self._initial_w = None + self._warmup_w_sum = None + self._n_steps.zero_() + self._state_key = None + + def forward(self, matrix: Matrix, /) -> Tensor: + self._ensure_state(matrix) + + # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7) + grad_sum = cast(Tensor, self._grad_sum) + grad_sum = grad_sum + matrix.detach() ** 2 + self._grad_sum = grad_sum + + # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6) + h = torch.sqrt(grad_sum + 1e-7) + w = (matrix.detach() ** 2 / h).sum(dim=1) # shape [m] + + n_steps = int(self._n_steps.item()) + self._n_steps = self._n_steps + 1 + + # Warmup: collect excess risk stats but return uniform weights + if n_steps < self._n_warmup_steps: + warmup_w_sum = self._warmup_w_sum + self._warmup_w_sum = w if warmup_w_sum is None else cast(Tensor, warmup_w_sum) + w + return cast(Tensor, self._weights) + + # Set baseline on the first non-warmup call + if self._initial_w is None: + if self._n_warmup_steps > 0: + # Average excess risk observed during warmup (Appendix C.1) + self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps + w = w / (cast(Tensor, self._initial_w) + 1e-7) + else: + # Official impl behaviour: first call's excess is the baseline; use w raw + self._initial_w = w + else: + w = w / (cast(Tensor, self._initial_w) + 1e-7) + + # Exponentiated gradient weight update (Equation 9) + weights = cast(Tensor, self._weights) + weights = weights * torch.exp(w * self._robust_step_size) + weights = weights / weights.sum() + self._weights = weights + return weights + + def _ensure_state(self, matrix: Matrix) -> None: + key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device) + if self._state_key == key and self._grad_sum is not None: + return + m, n = matrix.shape + 
self._grad_sum = matrix.new_zeros(m, n) + self._weights = matrix.new_full((m,), 1.0 / m) + self._initial_w = None + self._warmup_w_sum = None + self._n_steps.zero_() + self._state_key = key + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"robust_step_size={self.robust_step_size!r}, " + f"n_warmup_steps={self.n_warmup_steps!r})" + ) diff --git a/tests/unit/aggregation/test_excess_mtl.py b/tests/unit/aggregation/test_excess_mtl.py new file mode 100644 index 00000000..f0fcb9dc --- /dev/null +++ b/tests/unit/aggregation/test_excess_mtl.py @@ -0,0 +1,209 @@ +import torch +from pytest import raises +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation._excess_mtl import ExcessMTLWeighting + + +def test_representations() -> None: + W = ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0) + assert repr(W) == "ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0)" + + +def test_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + W = ExcessMTLWeighting() + first = W(J) + W(J) + W.reset() + assert_close(first, W(J)) + + +def test_robust_step_size_setter_accepts_valid() -> None: + W = ExcessMTLWeighting() + W.robust_step_size = 0.1 + assert W.robust_step_size == 0.1 + W.robust_step_size = 10.0 + assert W.robust_step_size == 10.0 + + +def test_robust_step_size_setter_rejects_non_positive() -> None: + W = ExcessMTLWeighting() + with raises(ValueError, match="robust_step_size"): + W.robust_step_size = 0.0 + with raises(ValueError, match="robust_step_size"): + W.robust_step_size = -1.0 + + +def test_n_warmup_steps_setter_accepts_valid() -> None: + W = ExcessMTLWeighting() + W.n_warmup_steps = 0 + assert W.n_warmup_steps == 0 + W.n_warmup_steps = 100 + assert W.n_warmup_steps == 100 + + +def test_n_warmup_steps_setter_rejects_negative() -> None: + W = ExcessMTLWeighting() + with raises(ValueError, match="n_warmup_steps"): + W.n_warmup_steps = -1 + + +def 
test_output_lies_on_simplex() -> None: + """The exponentiated update followed by normalisation keeps the weights on the simplex.""" + + J = randn_((4, 10)) + W = ExcessMTLWeighting() + # Call twice so the second call exercises the normalised-w branch + W(J) + weights = W(J) + assert weights.shape == (4,) + assert (weights >= 0).all() + assert_close(weights.sum(), tensor_(1.0)) + + +def test_warmup_returns_uniform() -> None: + """During warmup every call must return [1/m, ..., 1/m] regardless of the input.""" + + m, n_warmup = 3, 5 + W = ExcessMTLWeighting(n_warmup_steps=n_warmup) + expected = tensor_([1.0 / m] * m) + for _ in range(n_warmup): + assert_close(W(randn_((m, 8))), expected) + + +def test_weights_change_after_warmup() -> None: + """After warmup ends the weights must diverge from uniform when tasks have different excess risks.""" + + W = ExcessMTLWeighting(n_warmup_steps=2, robust_step_size=1.0) + # Symmetric warmup: equal excess risk for both tasks → equal initial_w + J_sym = tensor_([[1.0, 0.0], [1.0, 0.0]]) + W(J_sym) + W(J_sym) + + # Asymmetric step: task 0 has larger gradient → higher excess → weight must exceed task 1 + J_unequal = tensor_([[2.0, 0.0], [1.0, 0.0]]) + weights = W(J_unequal) + assert weights[0] > weights[1] + + +def test_update_recurrence() -> None: + """Verify the first weight update manually (n_warmup_steps=0, LibMTL behaviour). 
+ + With J = [[2., 0.], [1., 0.]] and robust_step_size=1.0: + grad_sum = [[4., 0.], [1., 0.]] + h ≈ [[2., sqrt(eps)], [1., sqrt(eps)]] (eps = 1e-7, negligible in float32 for nonzero entries) + w = [4/2 + 0, 1/1 + 0] = [2, 1] + initial_w = [2, 1] (first call: save raw excess as baseline) + weights = [exp(2), exp(1)] / (exp(2) + exp(1)) + """ + J = tensor_([[2.0, 0.0], [1.0, 0.0]]) + W = ExcessMTLWeighting(robust_step_size=1.0) + e2 = torch.exp(tensor_(2.0)) + e1 = torch.exp(tensor_(1.0)) + assert_close(W(J), tensor_([e2 / (e2 + e1), e1 / (e2 + e1)])) + + +def test_two_consecutive_steps() -> None: + """Verify warm-started carry-over across two calls. + + Call 1: J = [[2., 0.], [1., 0.]] → weights = [e^2, e] / (e^2 + e) (from test above) + Call 2: J = [[1., 0.], [2., 0.]] + grad_sum = [[4+1., 0.], [1+4., 0.]] = [[5., 0.], [5., 0.]] + h ≈ [[sqrt(5), sqrt(eps)], [sqrt(5), sqrt(eps)]] + w = [1/sqrt(5), 4/sqrt(5)] + initial_w = [2, 1] (from call 1) + w_norm = [1/(2*sqrt(5)), 4/sqrt(5)] + weights_2 = weights_1 * [exp(w_norm_0), exp(w_norm_1)] / normalization + """ + J1 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + J2 = tensor_([[1.0, 0.0], [2.0, 0.0]]) + W = ExcessMTLWeighting(robust_step_size=1.0) + + e2 = torch.exp(tensor_(2.0)) + e1 = torch.exp(tensor_(1.0)) + weights_1 = tensor_([e2 / (e2 + e1), e1 / (e2 + e1)]) + assert_close(W(J1), weights_1) + + sqrt5 = torch.sqrt(tensor_(5.0)) + w_norm_0 = tensor_(1.0) / (tensor_(2.0) * sqrt5) + w_norm_1 = tensor_(4.0) / sqrt5 + unnorm_0 = weights_1[0] * torch.exp(w_norm_0) + unnorm_1 = weights_1[1] * torch.exp(w_norm_1) + weights_2 = tensor_([unnorm_0 / (unnorm_0 + unnorm_1), unnorm_1 / (unnorm_0 + unnorm_1)]) + assert_close(W(J2), weights_2) + + +def test_warmup_baseline_is_average() -> None: + """initial_w after warmup must equal the average excess risk collected during warmup. 
+ + With n_warmup_steps=2 and J1=[[2,0],[1,0]], J2=[[1,0],[2,0]]: + + Warmup call 1 — grad_sum_1 = J1**2 = [[4,0],[1,0]]: + h_1 ≈ [[2, sqrt(eps)], [1, sqrt(eps)]] + w_1 = [4/2, 1/1] = [2, 1] + + Warmup call 2 — grad_sum_2 = J1**2 + J2**2 = [[5,0],[5,0]]: + h_2 ≈ [[sqrt(5), sqrt(eps)], [sqrt(5), sqrt(eps)]] + w_2 = [1/sqrt(5), 4/sqrt(5)] + + initial_w = (w_1 + w_2) / 2 (Appendix C.1 average) + + Post-warmup call 3 with J3 = J1 — grad_sum_3 = [[9,0],[6,0]]: + h_3 ≈ [[3, sqrt(eps)], [sqrt(6), sqrt(eps)]] + w_3 = [4/3, 1/sqrt(6)] + w_norm = w_3 / (initial_w + 1e-7) + weights = [0.5, 0.5] * exp(w_norm) / normalize + """ + + J1 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + J2 = tensor_([[1.0, 0.0], [2.0, 0.0]]) + J3 = tensor_([[2.0, 0.0], [1.0, 0.0]]) + W = ExcessMTLWeighting(n_warmup_steps=2, robust_step_size=1.0) + + W(J1) # warmup step 1 — grad_sum becomes J1**2 + W(J2) # warmup step 2 — grad_sum becomes J1**2 + J2**2 + + grad_sum_1 = J1**2 + h_1 = torch.sqrt(grad_sum_1 + 1e-7) + w_1 = (J1**2 / h_1).sum(dim=1) + + grad_sum_2 = grad_sum_1 + J2**2 + h_2 = torch.sqrt(grad_sum_2 + 1e-7) + w_2 = (J2**2 / h_2).sum(dim=1) + + initial_w = (w_1 + w_2) / 2 # Appendix C.1 baseline + + grad_sum_3 = grad_sum_2 + J3**2 + h_3 = torch.sqrt(grad_sum_3 + 1e-7) + w_3 = (J3**2 / h_3).sum(dim=1) + w_norm = w_3 / (initial_w + 1e-7) + pre_norm = tensor_([0.5, 0.5]) * torch.exp(w_norm) + expected = pre_norm / pre_norm.sum() + + assert_close(W(J3), expected) + + +def test_n_steps_resets_on_m_change() -> None: + """When the number of objectives changes the warmup counter must restart.""" + + W = ExcessMTLWeighting(n_warmup_steps=10) + # Burn through 5 warmup steps + for _ in range(5): + W(randn_((3, 8))) + + # Switch to 2 objectives — state including step counter resets + fresh = ExcessMTLWeighting(n_warmup_steps=10) + J = randn_((2, 8)) + assert_close(W(J), fresh(J)) + + +def test_non_differentiable() -> None: + """The _NonDifferentiable mixin must prevent autograd graph construction.""" + + J = 
randn_((3, 8)) + J.requires_grad_(True) + W = ExcessMTLWeighting() + weights = W(J) + assert not weights.requires_grad From 20b7f08e53d8b9f2a0b419c2ce1eb412520f896a Mon Sep 17 00:00:00 2001 From: Khush Date: Tue, 23 Jun 2026 12:04:29 -0400 Subject: [PATCH 2/6] Address Valerian's review comments on ExcessMTL --- docs/source/docs/aggregation/excess_mtl.rst | 3 + src/torchjd/aggregation/__init__.py | 3 +- src/torchjd/aggregation/_excess_mtl.py | 135 ++++++++++++-------- tests/unit/aggregation/test_excess_mtl.py | 41 ++---- 4 files changed, 99 insertions(+), 83 deletions(-) diff --git a/docs/source/docs/aggregation/excess_mtl.rst b/docs/source/docs/aggregation/excess_mtl.rst index e79fcc88..2effb570 100644 --- a/docs/source/docs/aggregation/excess_mtl.rst +++ b/docs/source/docs/aggregation/excess_mtl.rst @@ -3,5 +3,8 @@ ExcessMTL ========= +.. autoclass:: torchjd.aggregation.ExcessMTL + :members: __call__, reset + .. autoclass:: torchjd.aggregation.ExcessMTLWeighting :members: __call__, reset diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 61cae873..55e855ee 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -45,7 +45,7 @@ from ._constant import Constant, ConstantWeighting from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting -from ._excess_mtl import ExcessMTLWeighting +from ._excess_mtl import ExcessMTL, ExcessMTLWeighting from ._fairgrad import FairGrad, FairGradWeighting from ._graddrop import GradDrop from ._gradvac import GradVac, GradVacWeighting @@ -75,6 +75,7 @@ "CRMOGMWeighting", "DualProj", "DualProjWeighting", + "ExcessMTL", "ExcessMTLWeighting", "FairGrad", "FairGradWeighting", diff --git a/src/torchjd/aggregation/_excess_mtl.py b/src/torchjd/aggregation/_excess_mtl.py index 266fe25c..f5d6e202 100644 --- a/src/torchjd/aggregation/_excess_mtl.py +++ b/src/torchjd/aggregation/_excess_mtl.py @@ -11,6 +11,7 @@ from 
torchjd.aggregation._mixins import _NonDifferentiable from torchjd.linalg import Matrix +from ._aggregator_bases import WeightedAggregator from ._weighting_bases import _MatrixWeighting @@ -29,64 +30,37 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): Must be positive. :param n_warmup_steps: Number of forward calls during which weights stay uniform (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess - risk is set to the average excess risk observed during warmup. When ``0`` (default), the - first call's excess risk is used as the baseline and weights are updated immediately - (matching the official implementation). + risk is then set to the average excess risk observed during warmup. When ``0``, the first + call's excess risk is used immediately as the baseline. The default ``1`` matches the + behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends + collecting statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``. .. warning:: The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients - across **all** calls, where :math:`n` is the total number of model parameters. For large + across calls, where :math:`n` is the total number of model parameters. For large models this can be a significant memory cost. Call :meth:`reset` between experiments. .. note:: The weight update is adapted from the `official implementation `_ and `LibMTL `_. - The warmup strategy follows Appendix C.1 of the paper, which recommends collecting - gradient statistics for several epochs before beginning weight updates; set - ``n_warmup_steps`` accordingly (e.g. ``3 * len(dataloader)``). - - .. admonition:: Example - - .. 
testcode:: - - import torch - from torch.nn import Linear, MSELoss, ReLU, Sequential - from torch.optim import SGD - - from torchjd import autojac - from torchjd.aggregation import ExcessMTLWeighting, WeightedAggregator - from torchjd.autojac import jac_to_grad - - inputs = torch.randn(8, 5) - targets = torch.randn(8, 2) - - model = Sequential(Linear(5, 4), ReLU(), Linear(4, 2)) - optimizer = SGD(model.parameters()) - criterion = MSELoss() - aggregator = WeightedAggregator(ExcessMTLWeighting()) - - outputs = model(inputs) - losses = [criterion(outputs[:, i], targets[:, i]) for i in range(2)] - autojac.backward(losses) - jac_to_grad(model.parameters(), aggregator) - optimizer.step() - optimizer.zero_grad() + Unlike those implementations, which initialize task weights to ``1``, we follow the paper + and initialize them to ``1/m`` so that they always lie on the probability simplex. """ def __init__( self, robust_step_size: float = 1.0, - n_warmup_steps: int = 0, + n_warmup_steps: int = 1, ) -> None: super().__init__() self.robust_step_size = robust_step_size self.n_warmup_steps = n_warmup_steps self.register_buffer("_weights", None) - self.register_buffer("_grad_sum", None) + self.register_buffer("_sq_grad_sum", None) self.register_buffer("_initial_w", None) self.register_buffer("_warmup_w_sum", None) - self.register_buffer("_n_steps", torch.zeros((), dtype=torch.long)) + self._n_steps: int = 0 self._state_key: tuple[int, int, torch.dtype, torch.device] | None = None @property @@ -118,31 +92,31 @@ def reset(self) -> None: warmup.""" self._weights = None - self._grad_sum = None + self._sq_grad_sum = None self._initial_w = None self._warmup_w_sum = None - self._n_steps.zero_() + self._n_steps = 0 self._state_key = None def forward(self, matrix: Matrix, /) -> Tensor: self._ensure_state(matrix) + sq_matrix = matrix.detach() ** 2 + # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7) - grad_sum = cast(Tensor, self._grad_sum) - grad_sum = grad_sum 
+ matrix.detach() ** 2 - self._grad_sum = grad_sum + sq_grad_sum = cast(Tensor, self._sq_grad_sum) + sq_grad_sum += sq_matrix # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6) - h = torch.sqrt(grad_sum + 1e-7) - w = (matrix.detach() ** 2 / h).sum(dim=1) # shape [m] + h = torch.sqrt(sq_grad_sum + 1e-7) + w = (sq_matrix / h).sum(dim=1) # shape [m] - n_steps = int(self._n_steps.item()) - self._n_steps = self._n_steps + 1 + n_steps = self._n_steps + self._n_steps += 1 # Warmup: collect excess risk stats but return uniform weights if n_steps < self._n_warmup_steps: - warmup_w_sum = self._warmup_w_sum - self._warmup_w_sum = w if warmup_w_sum is None else cast(Tensor, warmup_w_sum) + w + cast(Tensor, self._warmup_w_sum).add_(w) return cast(Tensor, self._weights) # Set baseline on the first non-warmup call @@ -152,7 +126,7 @@ def forward(self, matrix: Matrix, /) -> Tensor: self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps w = w / (cast(Tensor, self._initial_w) + 1e-7) else: - # Official impl behaviour: first call's excess is the baseline; use w raw + # Official impl behavior: first call's excess is the baseline; use w raw self._initial_w = w else: w = w / (cast(Tensor, self._initial_w) + 1e-7) @@ -166,14 +140,14 @@ def forward(self, matrix: Matrix, /) -> Tensor: def _ensure_state(self, matrix: Matrix) -> None: key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device) - if self._state_key == key and self._grad_sum is not None: + if self._state_key == key and self._sq_grad_sum is not None: return m, n = matrix.shape - self._grad_sum = matrix.new_zeros(m, n) + self._sq_grad_sum = matrix.new_zeros(m, n) + self._warmup_w_sum = matrix.new_zeros(m) self._weights = matrix.new_full((m,), 1.0 / m) self._initial_w = None - self._warmup_w_sum = None - self._n_steps.zero_() + self._n_steps = 0 self._state_key = key def __repr__(self) -> str: @@ -182,3 +156,58 @@ def __repr__(self) -> str: f"robust_step_size={self.robust_step_size!r}, " 
f"n_warmup_steps={self.n_warmup_steps!r})" ) + + +class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable): + r""" + :class:`~torchjd.aggregation.WeightedAggregator` from `Robust Multi-Task Learning with Excess + Risks <https://proceedings.mlr.press/v235/he24n.html>`_ (ICML 2024). + + At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven + by per-task excess risk estimates. See :class:`~torchjd.aggregation.ExcessMTLWeighting` for + details on the algorithm and state management. + + :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update. + Must be positive. + :param n_warmup_steps: Number of forward calls during which weights stay uniform + (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. When ``0``, the first + call's excess risk is used as the baseline immediately. Defaults to ``1``. + """ + + weighting: ExcessMTLWeighting + + def __init__( + self, + robust_step_size: float = 1.0, + n_warmup_steps: int = 1, + ) -> None: + super().__init__(ExcessMTLWeighting(robust_step_size, n_warmup_steps)) + + @property + def robust_step_size(self) -> float: + return self.weighting.robust_step_size + + @robust_step_size.setter + def robust_step_size(self, value: float) -> None: + self.weighting.robust_step_size = value + + @property + def n_warmup_steps(self) -> int: + return self.weighting.n_warmup_steps + + @n_warmup_steps.setter + def n_warmup_steps(self, value: int) -> None: + self.weighting.n_warmup_steps = value + + def reset(self) -> None: + """Clears all state so the next forward starts from uniform weights and re-enters + warmup.""" + + self.weighting.reset() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"robust_step_size={self.robust_step_size!r}, " + f"n_warmup_steps={self.n_warmup_steps!r})" + ) diff --git a/tests/unit/aggregation/test_excess_mtl.py b/tests/unit/aggregation/test_excess_mtl.py index f0fcb9dc..5bc3da56 100644 --- a/tests/unit/aggregation/test_excess_mtl.py +++ 
b/tests/unit/aggregation/test_excess_mtl.py @@ -88,23 +88,6 @@ def test_weights_change_after_warmup() -> None: assert weights[0] > weights[1] -def test_update_recurrence() -> None: - """Verify the first weight update manually (n_warmup_steps=0, LibMTL behaviour). - - With J = [[2., 0.], [1., 0.]] and robust_step_size=1.0: - grad_sum = [[4., 0.], [1., 0.]] - h ≈ [[2., sqrt(eps)], [1., sqrt(eps)]] (eps = 1e-7, negligible in float32 for nonzero entries) - w = [4/2 + 0, 1/1 + 0] = [2, 1] - initial_w = [2, 1] (first call: save raw excess as baseline) - weights = [exp(2), exp(1)] / (exp(2) + exp(1)) - """ - J = tensor_([[2.0, 0.0], [1.0, 0.0]]) - W = ExcessMTLWeighting(robust_step_size=1.0) - e2 = torch.exp(tensor_(2.0)) - e1 = torch.exp(tensor_(1.0)) - assert_close(W(J), tensor_([e2 / (e2 + e1), e1 / (e2 + e1)])) - - def test_two_consecutive_steps() -> None: """Verify warm-started carry-over across two calls. @@ -119,7 +102,7 @@ def test_two_consecutive_steps() -> None: """ J1 = tensor_([[2.0, 0.0], [1.0, 0.0]]) J2 = tensor_([[1.0, 0.0], [2.0, 0.0]]) - W = ExcessMTLWeighting(robust_step_size=1.0) + W = ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0) e2 = torch.exp(tensor_(2.0)) e1 = torch.exp(tensor_(1.0)) @@ -140,17 +123,17 @@ def test_warmup_baseline_is_average() -> None: With n_warmup_steps=2 and J1=[[2,0],[1,0]], J2=[[1,0],[2,0]]: - Warmup call 1 — grad_sum_1 = J1**2 = [[4,0],[1,0]]: + Warmup call 1 — sq_grad_sum_1 = J1**2 = [[4,0],[1,0]]: h_1 ≈ [[2, sqrt(eps)], [1, sqrt(eps)]] w_1 = [4/2, 1/1] = [2, 1] - Warmup call 2 — grad_sum_2 = J1**2 + J2**2 = [[5,0],[5,0]]: + Warmup call 2 — sq_grad_sum_2 = J1**2 + J2**2 = [[5,0],[5,0]]: h_2 ≈ [[sqrt(5), sqrt(eps)], [sqrt(5), sqrt(eps)]] w_2 = [1/sqrt(5), 4/sqrt(5)] initial_w = (w_1 + w_2) / 2 (Appendix C.1 average) - Post-warmup call 3 with J3 = J1 — grad_sum_3 = [[9,0],[6,0]]: + Post-warmup call 3 with J3 = J1 — sq_grad_sum_3 = [[9,0],[6,0]]: h_3 ≈ [[3, sqrt(eps)], [sqrt(6), sqrt(eps)]] w_3 = [4/3, 1/sqrt(6)] 
w_norm = w_3 / (initial_w + 1e-7) @@ -162,21 +145,21 @@ def test_warmup_baseline_is_average() -> None: J3 = tensor_([[2.0, 0.0], [1.0, 0.0]]) W = ExcessMTLWeighting(n_warmup_steps=2, robust_step_size=1.0) - W(J1) # warmup step 1 — grad_sum becomes J1**2 - W(J2) # warmup step 2 — grad_sum becomes J1**2 + J2**2 + W(J1) # warmup step 1 — sq_grad_sum becomes J1**2 + W(J2) # warmup step 2 — sq_grad_sum becomes J1**2 + J2**2 - grad_sum_1 = J1**2 - h_1 = torch.sqrt(grad_sum_1 + 1e-7) + sq_grad_sum_1 = J1**2 + h_1 = torch.sqrt(sq_grad_sum_1 + 1e-7) w_1 = (J1**2 / h_1).sum(dim=1) - grad_sum_2 = grad_sum_1 + J2**2 - h_2 = torch.sqrt(grad_sum_2 + 1e-7) + sq_grad_sum_2 = sq_grad_sum_1 + J2**2 + h_2 = torch.sqrt(sq_grad_sum_2 + 1e-7) w_2 = (J2**2 / h_2).sum(dim=1) initial_w = (w_1 + w_2) / 2 # Appendix C.1 baseline - grad_sum_3 = grad_sum_2 + J3**2 - h_3 = torch.sqrt(grad_sum_3 + 1e-7) + sq_grad_sum_3 = sq_grad_sum_2 + J3**2 + h_3 = torch.sqrt(sq_grad_sum_3 + 1e-7) w_3 = (J3**2 / h_3).sum(dim=1) w_norm = w_3 / (initial_w + 1e-7) pre_norm = tensor_([0.5, 0.5]) * torch.exp(w_norm) From ceaaf1289c32334bb543fc2b6d41d688f8074d06 Mon Sep 17 00:00:00 2001 From: Khush Date: Tue, 23 Jun 2026 12:18:05 -0400 Subject: [PATCH 3/6] test: Add ExcessMTL aggregator coverage and fix redundant casts --- src/torchjd/aggregation/_excess_mtl.py | 4 +-- tests/unit/aggregation/test_excess_mtl.py | 33 ++++++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/_excess_mtl.py b/src/torchjd/aggregation/_excess_mtl.py index f5d6e202..958ee4bb 100644 --- a/src/torchjd/aggregation/_excess_mtl.py +++ b/src/torchjd/aggregation/_excess_mtl.py @@ -124,12 +124,12 @@ def forward(self, matrix: Matrix, /) -> Tensor: if self._n_warmup_steps > 0: # Average excess risk observed during warmup (Appendix C.1) self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps - w = w / (cast(Tensor, self._initial_w) + 1e-7) + w = w / (self._initial_w + 1e-7) else: 
# Official impl behavior: first call's excess is the baseline; use w raw self._initial_w = w else: - w = w / (cast(Tensor, self._initial_w) + 1e-7) + w = w / (self._initial_w + 1e-7) # Exponentiated gradient weight update (Equation 9) weights = cast(Tensor, self._weights) diff --git a/tests/unit/aggregation/test_excess_mtl.py b/tests/unit/aggregation/test_excess_mtl.py index 5bc3da56..ec6f2abc 100644 --- a/tests/unit/aggregation/test_excess_mtl.py +++ b/tests/unit/aggregation/test_excess_mtl.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation._excess_mtl import ExcessMTLWeighting +from torchjd.aggregation._excess_mtl import ExcessMTL, ExcessMTLWeighting def test_representations() -> None: @@ -190,3 +190,34 @@ def test_non_differentiable() -> None: W = ExcessMTLWeighting() weights = W(J) assert not weights.requires_grad + + +# ExcessMTL (aggregator wrapper) tests + + +def test_excess_mtl_representations() -> None: + agg = ExcessMTL(robust_step_size=2.0, n_warmup_steps=3) + assert repr(agg) == "ExcessMTL(robust_step_size=2.0, n_warmup_steps=3)" + + +def test_excess_mtl_properties_delegate() -> None: + agg = ExcessMTL(robust_step_size=1.0, n_warmup_steps=0) + assert agg.robust_step_size == 1.0 + assert agg.n_warmup_steps == 0 + + agg.robust_step_size = 0.5 + assert agg.robust_step_size == 0.5 + assert agg.weighting.robust_step_size == 0.5 + + agg.n_warmup_steps = 5 + assert agg.n_warmup_steps == 5 + assert agg.weighting.n_warmup_steps == 5 + + +def test_excess_mtl_reset_delegates() -> None: + J = randn_((3, 8)) + agg = ExcessMTL(n_warmup_steps=0) + first = agg(J) + agg(J) + agg.reset() + assert_close(first, agg(J)) From 25bd5ac347002fb41653ac3aa506ad2172d416ae Mon Sep 17 00:00:00 2001 From: Khush Date: Wed, 24 Jun 2026 11:12:24 -0400 Subject: [PATCH 4/6] Address Valerian's review comments on ExcessMTL --- src/torchjd/aggregation/_excess_mtl.py | 58 ++++++++++++++------------ 1 file changed, 31 
insertions(+), 27 deletions(-) diff --git a/src/torchjd/aggregation/_excess_mtl.py b/src/torchjd/aggregation/_excess_mtl.py index 958ee4bb..ff8062fd 100644 --- a/src/torchjd/aggregation/_excess_mtl.py +++ b/src/torchjd/aggregation/_excess_mtl.py @@ -24,16 +24,16 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven by per-task excess risk estimates. The excess risk for task :math:`i` is approximated via a - second-order Taylor expansion (Equations 6-7): + second-order Taylor expansion (Equations 6-7). :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update. Must be positive. :param n_warmup_steps: Number of forward calls during which weights stay uniform (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess - risk is then set to the average excess risk observed during warmup. When ``0``, the first - call's excess risk is used immediately as the baseline. The default ``1`` matches the - behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends - collecting statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``. + risk is then set to the average excess risk observed during warmup. When ``0`` (default), + the first call's excess risk is used immediately as the baseline, matching the behavior of + the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting + statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``. .. 
warning:: The state tensor :math:`S \in \mathbb{R}^{m \times n}` accumulates squared gradients @@ -51,7 +51,7 @@ class ExcessMTLWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): def __init__( self, robust_step_size: float = 1.0, - n_warmup_steps: int = 1, + n_warmup_steps: int = 0, ) -> None: super().__init__() self.robust_step_size = robust_step_size @@ -87,17 +87,6 @@ def n_warmup_steps(self, value: int) -> None: ) self._n_warmup_steps = value - def reset(self) -> None: - """Clears all state so the next forward starts from uniform weights and re-enters - warmup.""" - - self._weights = None - self._sq_grad_sum = None - self._initial_w = None - self._warmup_w_sum = None - self._n_steps = 0 - self._state_key = None - def forward(self, matrix: Matrix, /) -> Tensor: self._ensure_state(matrix) @@ -105,31 +94,31 @@ def forward(self, matrix: Matrix, /) -> Tensor: # Accumulate squared gradients for AdaGrad-style diagonal Hessian (Equation 7) sq_grad_sum = cast(Tensor, self._sq_grad_sum) - sq_grad_sum += sq_matrix + sq_grad_sum.add_(sq_matrix) # Excess risk proxy: Ê_i ≈ g_i^T H_i^{-1} g_i (Equation 6) h = torch.sqrt(sq_grad_sum + 1e-7) w = (sq_matrix / h).sum(dim=1) # shape [m] - n_steps = self._n_steps - self._n_steps += 1 - # Warmup: collect excess risk stats but return uniform weights - if n_steps < self._n_warmup_steps: + if self._n_steps < self._n_warmup_steps: cast(Tensor, self._warmup_w_sum).add_(w) + self._n_steps += 1 return cast(Tensor, self._weights) + self._n_steps += 1 + # Set baseline on the first non-warmup call if self._initial_w is None: if self._n_warmup_steps > 0: # Average excess risk observed during warmup (Appendix C.1) self._initial_w = cast(Tensor, self._warmup_w_sum) / self._n_warmup_steps - w = w / (self._initial_w + 1e-7) + w = w / (self._initial_w + 1e-7) # Scale processing (Section 3.2) else: # Official impl behavior: first call's excess is the baseline; use w raw self._initial_w = w else: - w = w / (self._initial_w + 1e-7) + w = w / 
(self._initial_w + 1e-7) # Scale processing (Section 3.2) # Exponentiated gradient weight update (Equation 9) weights = cast(Tensor, self._weights) @@ -138,6 +127,17 @@ def forward(self, matrix: Matrix, /) -> Tensor: self._weights = weights return weights + def reset(self) -> None: + """Clears all state so the next forward starts from uniform weights and re-enters + warmup.""" + + self._weights = None + self._sq_grad_sum = None + self._initial_w = None + self._warmup_w_sum = None + self._n_steps = 0 + self._state_key = None + def _ensure_state(self, matrix: Matrix) -> None: key = (matrix.shape[0], matrix.shape[1], matrix.dtype, matrix.device) if self._state_key == key and self._sq_grad_sum is not None: @@ -160,6 +160,7 @@ def __repr__(self) -> str: class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable): r""" + :class:`~torchjd.Stateful` :class:`~torchjd.aggregation.WeightedAggregator` from `Robust Multi-Task Learning with Excess Risks `_ (ICML 2024). @@ -170,8 +171,11 @@ class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable): :param robust_step_size: Step size :math:`\eta_\alpha` for the exponentiated weight update. Must be positive. :param n_warmup_steps: Number of forward calls during which weights stay uniform - (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. When ``0``, the first - call's excess risk is used as the baseline immediately. Defaults to ``1``. + (:math:`[1/m, \ldots, 1/m]`) and gradient statistics are collected. The baseline excess + risk is then set to the average excess risk observed during warmup. When ``0`` (default), + the first call's excess risk is used immediately as the baseline, matching the behavior of + the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting + statistics for 3 full epochs, i.e. ``n_warmup_steps = 3 * len(dataloader)``. 
""" weighting: ExcessMTLWeighting @@ -179,7 +183,7 @@ class ExcessMTL(WeightedAggregator, Stateful, _NonDifferentiable): def __init__( self, robust_step_size: float = 1.0, - n_warmup_steps: int = 1, + n_warmup_steps: int = 0, ) -> None: super().__init__(ExcessMTLWeighting(robust_step_size, n_warmup_steps)) From 616eb8eb829047bade14b6bdce7fe776a8e317cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Jun 2026 21:06:03 +0200 Subject: [PATCH 5/6] Fix changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8147aa4..62b491a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ changelog does not include internal changes that do not affect the user. Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a `Scalarizer` that decomposes the values into a component along a preference direction and a penalized perpendicular component. -- Added `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin. +- Added `ExcessMTL` and `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin. 
## [0.15.0] - 2026-06-15 From 0e140a38fc12257a5ceb12a3e2a03d6cf2c16041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Jun 2026 21:16:48 +0200 Subject: [PATCH 6/6] test: Add expected_structure test and fix changelog placement for ExcessMTL Co-Authored-By: Claude Sonnet 4.6 --- CHANGELOG.md | 10 +++++++++- tests/unit/aggregation/test_excess_mtl.py | 14 +++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62b491a3..eb67ee64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,15 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `ExcessMTL` and `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess + Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). `ExcessMTLWeighting` is a + stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update + driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style + diagonal Hessian. An optional `n_warmup_steps` parameter controls how many forward calls collect + gradient statistics before weight updates begin. + ## [0.16.0] - 2026-06-22 ### Added @@ -20,7 +29,6 @@ changelog does not include internal changes that do not affect the user. Algorithm Based on Decomposition](https://ieeexplore.ieee.org/document/4358754) (IEEE TEVC 2007), a `Scalarizer` that decomposes the values into a component along a preference direction and a penalized perpendicular component. -- Added `ExcessMTL` and `ExcessMTLWeighting` from [Robust Multi-Task Learning with Excess Risks](https://proceedings.mlr.press/v235/he24n.html) (ICML 2024). It is a stateful `Weighting` that maintains task weights across calls via an exponentiated gradient update driven by per-task excess risk estimates. The excess risk is approximated using an AdaGrad-style diagonal Hessian. 
An optional `n_warmup_steps` parameter controls how many forward calls collect gradient statistics before weight updates begin. ## [0.15.0] - 2026-06-15 diff --git a/tests/unit/aggregation/test_excess_mtl.py b/tests/unit/aggregation/test_excess_mtl.py index ec6f2abc..04367062 100644 --- a/tests/unit/aggregation/test_excess_mtl.py +++ b/tests/unit/aggregation/test_excess_mtl.py @@ -1,10 +1,21 @@ import torch -from pytest import raises +from pytest import mark, raises +from torch import Tensor from torch.testing import assert_close from utils.tensors import randn_, tensor_ from torchjd.aggregation._excess_mtl import ExcessMTL, ExcessMTLWeighting +from ._asserts import assert_expected_structure +from ._inputs import typical_matrices + +typical_pairs = [(ExcessMTL(), m) for m in typical_matrices] + + +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_expected_structure(aggregator: ExcessMTL, matrix: Tensor) -> None: + assert_expected_structure(aggregator, matrix) + def test_representations() -> None: W = ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0) @@ -198,6 +209,7 @@ def test_non_differentiable() -> None: def test_excess_mtl_representations() -> None: agg = ExcessMTL(robust_step_size=2.0, n_warmup_steps=3) assert repr(agg) == "ExcessMTL(robust_step_size=2.0, n_warmup_steps=3)" + assert str(agg) == "ExcessMTL" def test_excess_mtl_properties_delegate() -> None: