diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index c6ae3b49..6c13ddc8 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -975,26 +975,39 @@ def _validate_data( def predict(self, data: pd.DataFrame) -> np.ndarray: """ - Predict outcomes using fitted model. + Predict outcomes using the fitted model. + + Out-of-sample prediction is intentionally unsupported pending a broader + post-estimation design for estimator result objects. For fitted + training-data predictions, use ``results_.fitted_values`` after + :meth:`fit`. Parameters ---------- data : pd.DataFrame - DataFrame with same structure as training data. + Candidate prediction data. Currently unused because out-of-sample + prediction is unsupported. Returns ------- np.ndarray Predicted values. + + Raises + ------ + RuntimeError + If called before :meth:`fit`. + NotImplementedError + Always raised after fitting until the broader post-estimation + prediction contract is designed. """ if not self.is_fitted_: raise RuntimeError("Model must be fitted before calling predict()") - # This is a placeholder - would need to store column names - # for full implementation raise NotImplementedError( - "predict() is not yet implemented. " - "Use results_.fitted_values for training data predictions." + "out-of-sample predict() is unsupported pending a broader " + "post-estimation design. Use results_.fitted_values for fitted " + "training-data predictions." ) def get_params(self) -> Dict[str, Any]: diff --git a/docs/api/estimators.rst b/docs/api/estimators.rst index d9e63879..5ba2e632 100644 --- a/docs/api/estimators.rst +++ b/docs/api/estimators.rst @@ -30,6 +30,11 @@ DifferenceInDifferences (alias: ``DiD``) Basic 2x2 DiD estimator. +``DifferenceInDifferences.predict()`` is present for sklearn-like +discoverability, but out-of-sample prediction is not currently supported. Use +``results_.fitted_values`` for fitted training-data predictions until a broader +post-estimation result-object contract is designed. + .. autoclass:: diff_diff.DifferenceInDifferences :no-index: :members: @@ -42,6 +47,7 @@ Basic 2x2 DiD estimator. .. autosummary:: ~DifferenceInDifferences.fit + ~DifferenceInDifferences.predict ~DifferenceInDifferences.get_params ~DifferenceInDifferences.set_params @@ -84,4 +90,3 @@ Synthetic control combined with DiD (Arkhangelsky et al. 2021). :undoc-members: :show-inheritance: :inherited-members: - diff --git a/tests/test_methodology_did.py b/tests/test_methodology_did.py index 7d91a6a0..a0f48ae1 100644 --- a/tests/test_methodology_did.py +++ b/tests/test_methodology_did.py @@ -1549,3 +1549,16 @@ def test_residuals_and_fitted_values(self): assert np.allclose(reconstructed, original), \ "Residuals + fitted should equal original outcome" + + def test_predict_contract_points_to_fitted_values(self): + """predict() is intentionally unsupported until post-estimation is designed.""" + data, _ = generate_hand_calculable_data() + + did = DifferenceInDifferences() + did.fit(data, outcome='outcome', treatment='treated', time='post') + + with pytest.raises( + NotImplementedError, + match="out-of-sample.*post-estimation.*results_\\.fitted_values", + ): + did.predict(data)