Skip to content

Weight tying improvements#456

Open
le1nux wants to merge 5 commits into
mainfrom
3B_training_prep
Open

Weight tying improvements#456
le1nux wants to merge 5 commits into
mainfrom
3B_training_prep

Conversation

@le1nux

@le1nux le1nux commented Jun 20, 2026

Copy link
Copy Markdown
Member

This PR addresses the initialization of tied weights in the Llama3Initializer, ensuring it correctly infers weight tying from the model and rejects non-GPT2 models. It also adds tests for weight tying functionality and checks for proper initialization in pipeline parallelism scenarios.

General Changes

  • Fixed initialization logic for tied weights.
  • Added tests for weight tying and Llama3 initialization checks.
  • Updated handling of weight tying in pipeline parallelism.

Breaking Changes

  • Removed the use_weight_tying parameter from the Llama3Initializer constructor.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@le1nux le1nux requested a review from rrutmann June 27, 2026 22:30

@le1nux le1nux left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Automated review (Claude Code, /code-review, high effort). The core fix is correct: with tying, transformer.wte.weight is the output projection, so initializing it at 1/sqrt(n_embd) instead of std=1 is right, and the regression test genuinely fails without it. All 14 tests pass. Findings below concern the config contract and code clarity, not the central change.

  • #1 (medium): the initializer's use_weight_tying default decouples it from the model's flag with no cross-check — see inline comment.
  • #2 (low / altitude): define-then-override on the same wte regex key — see inline comment.
  • #3 (low / test cleanup): _make_norm_config re-encodes a mapping the LayerNorms enum already owns — see inline comment.

Nit (not posted inline): has_tied_word_embeddings uses "wte" not in self.transformer while forward_impl uses hasattr(self.transformer, "wte") — behaviorally identical for nn.ModuleDict, pure consistency nit.

num_layers: Annotated[int, Field(strict=True, gt=0)]
n_embd: Annotated[int, Field(strict=True, gt=0)]
use_weight_tying: bool
use_weight_tying: bool = False

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Medium — initializer's use_weight_tying is now decoupled from the model's, with no cross-check.

This field changed from required to = False. The model's use_weight_tying and the initializer's are set independently in YAML (the production config and llama3_config_initalization.yaml already omit it here and rely on the default), and nothing validates that they agree. On a mismatch:

  • Model tied + initializer defaults to False: the dict registers a transformer\.lm_head\.weight regex, but a tied tensor only ever surfaces from named_parameters() as wte.weight (verified: wte is registered first), so that regex gets 0 hits and the hits == 0 guard raises ValueError: Regex transformer\.lm_head\.weight did not match any FQNs. The model specification probably does not match LLama3. Loud, but the message misdirects debugging toward "not a Llama3 model" when the real cause is a tying-flag mismatch (and the tied tensor is set to std=1 before the raise).
  • Initializer True + model untied: no lm_head regex exists, so lm_head.weight matches nothing and only emits a logger.warning — the output projection is left uninitialized. Silent.

For a PR whose purpose is fixing tied-init, relaxing this contract reopens a footgun. Suggest keeping the field required, or adding a validator asserting initializer.use_weight_tying == model.has_tied_word_embeddings (and improving that error message).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed

# small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
# Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
# large at init, causing the initial loss/grad norm to explode.
self.regex_to_init[r"transformer\.wte\.weight"] = (

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low (altitude) — define-then-override on the same wte regex key.

Line 42 already registers transformer\.wte\.weight with std=1; this else branch re-inserts the same key with std=1/sqrt(n_embd). It's correct only because the two regex strings are byte-identical, so the dict overwrites rather than double-matching (verified benign). But the std=1 entry is dead whenever tying is on, the wte init spec is split across two distant locations, and this output-projection arg dict is duplicated verbatim with the lm_head branch above. A future edit to the line 42 std=1 would be silently discarded for tied models.

Cleaner: compute wte's args once, conditionally, and register a single entry — mirroring how lm_head is handled — or pick the key (wte vs lm_head) by the flag with one shared output-projection arg dict.

norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
)

def _make_norm_config() -> LayerNormWrapperConfig:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low (cleanup) — test helper re-encodes a mapping the enum already owns.

_make_norm_config hand-maps pytorch_rms_norm -> PytorchRMSLayerNormConfig and everything else -> LayerNormConfig. The LayerNorms enum already pairs each norm type with its config class; a future test passing a third norm type (e.g. rms_norm) silently falls through to the wrong config and produces a confusing construction error rather than using the correct config class. Minor, test-only.

@le1nux le1nux changed the title fix: fixed initialization of tied weights in Llama3Initializer Weight tying improvements Jun 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant