Weight tying improvements #456
le1nux left a comment
Automated review (Claude Code, /code-review, high effort). The core fix is correct: with tying, transformer.wte.weight is the output projection, so initializing it at 1/sqrt(n_embd) instead of std=1 is right, and the regression test genuinely fails without it. All 14 tests pass. Findings below concern the config contract and code clarity, not the central change.
- #1 (medium): the initializer's `use_weight_tying` default decouples it from the model's flag with no cross-check; see inline comment.
- #2 (low / altitude): define-then-override on the same `wte` regex key; see inline comment.
- #3 (low / test cleanup): `_make_norm_config` re-encodes a mapping the `LayerNorms` enum already owns; see inline comment.
Nit (not posted inline): has_tied_word_embeddings uses "wte" not in self.transformer while forward_impl uses hasattr(self.transformer, "wte") — behaviorally identical for nn.ModuleDict, pure consistency nit.
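The scale argument behind the core fix can be checked numerically. The following is a minimal sketch, not code from the PR: it assumes the hidden state entering the output projection has roughly unit-variance components, and the function name `logit_std` and the sizes are made up for illustration. It compares the spread of logits when the tied matrix is initialized with std=1 versus std=1/sqrt(n_embd).

```python
import math
import random
import statistics


def logit_std(n_embd: int, weight_std: float, vocab_size: int = 1000, seed: int = 0) -> float:
    """Std of logits W @ h for one random hidden vector h (hypothetical helper)."""
    rng = random.Random(seed)
    # Assumption: unit-variance hidden components, e.g. after a final norm layer.
    hidden = [rng.gauss(0.0, 1.0) for _ in range(n_embd)]
    logits = []
    for _ in range(vocab_size):
        # One row of the (tied) embedding matrix acting as output projection.
        row = [rng.gauss(0.0, weight_std) for _ in range(n_embd)]
        logits.append(sum(h * w for h, w in zip(hidden, row)))
    return statistics.pstdev(logits)


n_embd = 256
std_embedding_init = logit_std(n_embd, weight_std=1.0)
std_output_init = logit_std(n_embd, weight_std=1.0 / math.sqrt(n_embd))
ratio = std_embedding_init / std_output_init
print(f"std=1: {std_embedding_init:.2f}, std=1/sqrt(n_embd): {std_output_init:.2f}, ratio: {ratio:.2f}")
```

Because both runs share a seed, the ratio comes out at sqrt(n_embd), which is the "~sqrt(n_embd)x too large" factor the code comment in the PR refers to.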
      num_layers: Annotated[int, Field(strict=True, gt=0)]
      n_embd: Annotated[int, Field(strict=True, gt=0)]
    - use_weight_tying: bool
    + use_weight_tying: bool = False
Medium — initializer's use_weight_tying is now decoupled from the model's, with no cross-check.
This field changed from required to = False. The model's use_weight_tying and the initializer's are set independently in YAML (the production config and llama3_config_initalization.yaml already omit it here and rely on the default), and nothing validates that they agree. On a mismatch:
- Model tied + initializer defaults to False: the dict registers a `transformer\.lm_head\.weight` regex, but a tied tensor only ever surfaces from `named_parameters()` as `wte.weight` (verified: wte is registered first), so that regex gets 0 hits and the `hits == 0` guard raises `ValueError: Regex transformer\.lm_head\.weight did not match any FQNs. The model specification probably does not match LLama3.` Loud, but the message misdirects debugging toward "not a Llama3 model" when the real cause is a tying-flag mismatch (and the tied tensor is set to `std=1` before the raise).
- Initializer True + model untied: no `lm_head` regex exists, so `lm_head.weight` matches nothing and only emits a `logger.warning`; the output projection is left uninitialized. Silent.
For a PR whose purpose is fixing tied-init, relaxing this contract reopens a footgun. Suggest keeping the field required, or adding a validator asserting initializer.use_weight_tying == model.has_tied_word_embeddings (and improving that error message).
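A minimal sketch of the suggested cross-check, assuming both flags are available as plain booleans at the point where the initializer is applied. The function name `check_tying_agreement` and the error wording are hypothetical, not part of the codebase.

```python
def check_tying_agreement(
    initializer_use_weight_tying: bool,
    model_has_tied_word_embeddings: bool,
) -> None:
    """Fail early, with a message naming the real cause, if the flags disagree."""
    if initializer_use_weight_tying != model_has_tied_word_embeddings:
        raise ValueError(
            "Weight tying mismatch: initializer.use_weight_tying="
            f"{initializer_use_weight_tying} but the model reports "
            f"has_tied_word_embeddings={model_has_tied_word_embeddings}. "
            "Set both to the same value in the config."
        )


# Matching flags pass silently.
check_tying_agreement(True, True)
check_tying_agreement(False, False)

# A mismatch is reported before any tensor is touched.
try:
    check_tying_agreement(False, True)
except ValueError as err:
    mismatch_message = str(err)
    print(mismatch_message)
```

Running a check like this before any regex is applied would also avoid the tied tensor being set to `std=1` ahead of the raise.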
    # small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
    # Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
    # large at init, causing the initial loss/grad norm to explode.
    self.regex_to_init[r"transformer\.wte\.weight"] = (
Low (altitude) — define-then-override on the same wte regex key.
Line 42 already registers transformer\.wte\.weight with std=1; this else branch re-inserts the same key with std=1/sqrt(n_embd). It's correct only because the two regex strings are byte-identical, so the dict overwrites rather than double-matching (verified benign). But the std=1 entry is dead whenever tying is on, the wte init spec is split across two distant locations, and this output-projection arg dict is duplicated verbatim with the lm_head branch above. A future edit to the line 42 std=1 would be silently discarded for tied models.
Cleaner: compute wte's args once, conditionally, and register a single entry — mirroring how lm_head is handled — or pick the key (wte vs lm_head) by the flag with one shared output-projection arg dict.
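One possible shape for that refactor, as a sketch only. The regex keys and std values come from the diff; the function name `build_regex_to_init` and the `{"std": ...}` argument dicts are simplified stand-ins for whatever the initializer actually stores per regex.

```python
import math


def build_regex_to_init(n_embd: int, use_weight_tying: bool) -> dict[str, dict[str, float]]:
    """Register each regex exactly once, with no define-then-override."""
    # Shared by whichever tensor acts as the output projection.
    output_projection_args = {"std": 1.0 / math.sqrt(n_embd)}
    embedding_args = {"std": 1.0}

    regex_to_init: dict[str, dict[str, float]] = {}
    if use_weight_tying:
        # Tied: wte.weight is the output projection, and lm_head has no own FQN.
        regex_to_init[r"transformer\.wte\.weight"] = output_projection_args
    else:
        regex_to_init[r"transformer\.wte\.weight"] = embedding_args
        regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_args
    return regex_to_init


tied = build_regex_to_init(n_embd=256, use_weight_tying=True)
untied = build_regex_to_init(n_embd=256, use_weight_tying=False)
print(tied)
print(untied)
```

With this layout the wte spec lives in one place, so an edit to the embedding std cannot be silently discarded for tied models.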
        norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
    )


    def _make_norm_config() -> LayerNormWrapperConfig:
Low (cleanup) — test helper re-encodes a mapping the enum already owns.
_make_norm_config hand-maps pytorch_rms_norm -> PytorchRMSLayerNormConfig and everything else -> LayerNormConfig. The LayerNorms enum already pairs each norm type with its config class; a future test passing a third norm type (e.g. rms_norm) silently falls through to the wrong config and produces a confusing construction error rather than using the correct config class. Minor, test-only.
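To illustrate the failure mode and one way around it, here is a self-contained sketch using stand-ins. `NormKind`, the `*Stub` config classes, and `CONFIG_FOR_NORM` are all hypothetical; the real `LayerNorms` enum and its config classes are not reproduced here, and how the enum exposes the pairing is not shown in this review. The point is only that a lookup fails loudly on an unmapped norm type, where an if/else falls through to the wrong config.

```python
from dataclasses import dataclass
from enum import Enum


@dataclass
class LayerNormConfigStub:  # hypothetical stand-in for LayerNormConfig
    normalized_shape: int


@dataclass
class PytorchRMSNormConfigStub:  # hypothetical stand-in for PytorchRMSLayerNormConfig
    normalized_shape: int


class NormKind(Enum):  # hypothetical stand-in for the LayerNorms enum
    layer_norm = "layer_norm"
    pytorch_rms_norm = "pytorch_rms_norm"
    rms_norm = "rms_norm"


# Single source of truth for the norm-type -> config-class pairing.
CONFIG_FOR_NORM = {
    NormKind.layer_norm: LayerNormConfigStub,
    NormKind.pytorch_rms_norm: PytorchRMSNormConfigStub,
}


def make_norm_config(kind: NormKind, n_embd: int):
    """Look the config class up instead of branching with a catch-all else."""
    try:
        config_cls = CONFIG_FOR_NORM[kind]
    except KeyError:
        raise ValueError(f"No config class registered for norm type {kind.name}") from None
    return config_cls(normalized_shape=n_embd)


layer_cfg = make_norm_config(NormKind.layer_norm, n_embd=64)
rms_cfg = make_norm_config(NormKind.pytorch_rms_norm, n_embd=64)
print(layer_cfg, rms_cfg)
```

In the actual test helper the table would be replaced by whatever pairing the enum already provides, so the mapping is not maintained twice.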
…ject non-GPT2 models
This PR addresses the initialization of tied weights in the Llama3Initializer, ensuring it correctly infers weight tying from the model and rejects non-GPT2 models. It also adds tests for weight tying functionality and checks for proper initialization in pipeline parallelism scenarios.
General Changes
Breaking Changes
`use_weight_tying` parameter from the Llama3Initializer constructor.
Checklist before submitting final PR
- `python tests/tests.py`
- `CHANGELOG_DEV.md`