# Examples

This page collects a few slightly less standard use cases that build on
the main workflow from the quickstart and aggregation pages.

## Custom Group Aggregation

You can aggregate cohorts into custom groups by calling
`agg_latt(..., method="cohort_custom")` and passing a `custom_map`
through `agg_kwargs`.

``` python
import idid

df = idid.sim_stag_panel(
    n=10_000,
    T=5,
    E_cohorts=[0, 2, 3, 4, 5],
)

res = idid.estimate(
    df,
    cohort="E",
    time="t",
    outcome="Y_t",
    treatment="D_t",
    unit="id",
    covariates=["X"],
    control="never",
    method="dr",
    balanced=True,
    verbose=False,
)

# Map each cohort to the label of the custom group it belongs to.
custom_map = {
    2: "2-3",
    3: "2-3",
    4: "4-5",
    5: "4-5",
}

agg = idid.agg_latt(
    res,
    method="cohort_custom",
    agg_kwargs={"custom_map": custom_map},
)
agg.summary()
```

    Overall summary of ATT's based on custom cohort aggregation:
      LATT  Std. Error  [95% Conf. Band]
    1.1750      0.1013   0.9765   1.3734  *

    Custom cohort effects:
    Cohort  Estimate  Std. Error  [95% Pointwise Conf. Band]
       2-3    0.9925      0.1214   0.7545   1.2304  *
       4-5    1.3506      0.1584   1.0402   1.6610  *
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust

Compare that to the standard cohort aggregation:

``` python
agg = idid.agg_latt(res, method="cohort")
agg.summary()
```

    Overall summary of ATT's based on cohort aggregation:
      LATT  Std. Error  [95% Conf. Band]
    1.1522      0.0992   0.9577   1.3466  *

    Cohort effects:
    Cohort  Estimate  Std. Error  [95% Pointwise Conf. Band]
         2    0.9999      0.1625   0.6813   1.3185  *
         3    0.9818      0.1793   0.6304   1.3331  *
         4    1.5078      0.2244   1.0680   1.9477  *
         5    1.0982      0.2087   0.6891   1.5072  *
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust

------------------------------------------------------------------------

## Group Difference Aggregation

To estimate differences between two groups, estimate the cohort-time
LATTs within each group and then call `group_diff_idid`. The true
difference between the groups equals $1/2$ for all $(e, t)$.
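In symbols, and with notation introduced here only for illustration,
write $LATT_a(e, t)$ and $LATT_b(e, t)$ for the cohort-time effects in
two groups $a$ and $b$. A sketch of the quantity being compared, which
is not taken from the package documentation, is

$$
\Delta(e, t) = LATT_a(e, t) - LATT_b(e, t).
$$

If the two groups consist of disjoint, independently sampled units
(an assumption of this sketch), the two estimators are independent and
the variance of the estimated difference is the sum of the group-level
variances:

$$
\operatorname{Var}\bigl(\widehat{\Delta}(e, t)\bigr)
  = \operatorname{Var}\bigl(\widehat{LATT}_a(e, t)\bigr)
  + \operatorname{Var}\bigl(\widehat{LATT}_b(e, t)\bigr).
$$

Whether `group_diff_idid` computes its standard errors exactly this way
is not stated on this page.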
``` python
import polars as pl

from idid.aggregate import group_diff_idid
from idid.plotting import plot_agg, plot_evolution_groups, summarize_evolution

df = idid.sim_stag_panel(
    n=20_000,
    T=4,
    E_cohorts=[0, 2, 3, 4],
    confounded=True,
    with_group=True,
)
print(df.head())
```

    shape: (5, 7)
    ┌─────┬─────┬─────┬───────────┬─────┬──────────┬─────┐
    │ id  ┆ E   ┆ t   ┆ X         ┆ D_t ┆ Y_t      ┆ F   │
    │ --- ┆ --- ┆ --- ┆ ---       ┆ --- ┆ ---      ┆ --- │
    │ i64 ┆ i64 ┆ i64 ┆ f64       ┆ i64 ┆ f64      ┆ i64 │
    ╞═════╪═════╪═════╪═══════════╪═════╪══════════╪═════╡
    │ 0   ┆ 4   ┆ 1   ┆ 1.285605  ┆ 1   ┆ 7.275993 ┆ 1   │
    │ 0   ┆ 4   ┆ 2   ┆ 1.285605  ┆ 0   ┆ 2.930226 ┆ 1   │
    │ 0   ┆ 4   ┆ 3   ┆ 1.285605  ┆ 0   ┆ 3.64175  ┆ 1   │
    │ 0   ┆ 4   ┆ 4   ┆ 1.285605  ┆ 1   ┆ 7.260314 ┆ 1   │
    │ 1   ┆ 4   ┆ 1   ┆ -0.303553 ┆ 1   ┆ 3.297906 ┆ 0   │
    └─────┴─────┴─────┴───────────┴─────┴──────────┴─────┘

``` python
gp0 = summarize_evolution(df.filter(pl.col("F").eq(0)))
gp1 = summarize_evolution(df.filter(pl.col("F").eq(1)))

fig, ax = plot_evolution_groups(
    gp0,
    gp1,
    include_bands=True,
)
```

![](examples_files/figure-commonmark/cell-7-output-1.png)

``` python
ests = {}
for f in [0, 1]:
    res = idid.estimate(
        df.filter(pl.col("F").eq(f)).with_columns(
            id=pl.lit("M" if f == 0 else "F") + pl.col("id").cast(pl.Utf8),
        ),
        cohort="E",
        time="t",
        outcome="Y_t",
        treatment="D_t",
        unit="id",
        covariates=["X"],
        control="never",
        method="dr",
        balanced=True,
        verbose=False,
    )
    ests[f] = res

diff = group_diff_idid(ests[1], ests[0])
print(diff.estimates)
diff.summary()
```

    shape: (6, 7)
    ┌─────┬─────┬───────────┬──────────┬──────────┬───────────┬───────────┐
    │ E   ┆ t   ┆ latt      ┆ denom    ┆ se       ┆ lower     ┆ upper     │
    │ --- ┆ --- ┆ ---       ┆ ---      ┆ ---      ┆ ---       ┆ ---       │
    │ i64 ┆ i64 ┆ f64       ┆ f64      ┆ f64      ┆ f64       ┆ f64       │
    ╞═════╪═════╪═══════════╪══════════╪══════════╪═══════════╪═══════════╡
    │ 2   ┆ 2   ┆ -0.713377 ┆ 0.188419 ┆ 0.486029 ┆ -1.665993 ┆ 0.239239  │
    │ 2   ┆ 3   ┆ -0.721623 ┆ 0.19596  ┆ 0.452582 ┆ -1.608683 ┆ 0.165437  │
    │ 2   ┆ 4   ┆ -0.831864 ┆ 0.211267 ┆ 0.415799 ┆ -1.64683  ┆ -0.016897 │
    │ 3   ┆ 3   ┆ -0.747214 ┆ 0.214008 ┆ 0.41707  ┆ -1.564672 ┆ 0.070243  │
    │ 3   ┆ 4   ┆ -0.991648 ┆ 0.225592 ┆ 0.391804 ┆ -1.759584 ┆ -0.223713 │
    │ 4   ┆ 4   ┆ -0.591208 ┆ 0.211862 ┆ 0.417223 ┆ -1.408965 ┆ 0.226549  │
    └─────┴─────┴───────────┴──────────┴──────────┴───────────┴───────────┘
    Cohort-Time Local Average Treatment Effects on the Treated:
     E   t  AET(e, t)  LATT(e, t)  Std. Error  [95% Pointwise. Conf. Band]
     2   2     0.1884     -0.7134      0.4860  -1.6660   0.2392
     2   3     0.1960     -0.7216      0.4526  -1.6087   0.1654
     2   4     0.2113     -0.8319      0.4158  -1.6468  -0.0169  *
     3   3     0.2140     -0.7472      0.4171  -1.5647   0.0702
     3   4     0.2256     -0.9916      0.3918  -1.7596  -0.2237  *
     4   4     0.2119     -0.5912      0.4172  -1.4090   0.2265
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust

Aggregated dynamic effects:

``` python
agg = idid.agg_latt(diff, method="dynamic")
agg.summary()
```

    Overall summary of ATT's based on event-study/dynamic aggregation:
       LATT  Std. Error  [95% Conf. Band]
    -0.7932      0.2551  -1.2933  -0.2932  *

    Dynamic effects:
    Event time  Estimate  Std. Error  [95% Pointwise Conf. Band]
             0   -0.6834      0.2080  -1.0910  -0.2758  *
             1   -0.8644      0.2980  -1.4485  -0.2802  *
             2   -0.8319      0.4158  -1.6468  -0.0169  *
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust

``` python
fig, ax = plot_agg([agg])
```

![](examples_files/figure-commonmark/cell-10-output-1.png)

``` python
agg_diff = idid.agg_latt(diff, method="dynamic", boot=True)
agg_diff.summary()
```

    Overall summary of ATT's based on event-study/dynamic aggregation:
       LATT  Std. Error  [95% Simult. Conf. Band]
    -0.7932      0.2619  -1.3063  -0.2802  *

    Dynamic effects:
    Event time  Estimate  Std. Error  [95% Simult. Conf. Band]
             0   -0.6834      0.2119  -1.1694  -0.1975  *
             1   -0.8644      0.3044  -1.5624  -0.1663  *
             2   -0.8319      0.4086  -1.7690   0.1053
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust
    Multiplier bootstrap: B=1000, c=2.2934, overall c=1.9588

Compare that to:

``` python
agg0 = idid.agg_latt(ests[0], method="dynamic", boot=True)
agg0.summary()
```

    Overall summary of ATT's based on event-study/dynamic aggregation:
      LATT  Std. Error  [95% Simult. Conf. Band]
    1.1506      0.1841   0.7958   1.5054  *

    Dynamic effects:
    Event time  Estimate  Std. Error  [95% Simult. Conf. Band]
             0    1.0224      0.1550   0.6773   1.3675  *
             1    1.2295      0.2259   0.7265   1.7325  *
             2    1.2000      0.2980   0.5366   1.8633  *
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust
    Multiplier bootstrap: B=1000, c=2.2264, overall c=1.9274

``` python
agg1 = idid.agg_latt(ests[1], method="dynamic", boot=True)
agg1.summary()
```

    Overall summary of ATT's based on event-study/dynamic aggregation:
      LATT  Std. Error  [95% Simult. Conf. Band]
    0.3494      0.1739  -0.0236   0.7224

    Dynamic effects:
    Event time  Estimate  Std. Error  [95% Simult. Conf. Band]
             0    0.3214      0.1408   0.0042   0.6387  *
             1    0.3587      0.2134  -0.1220   0.8394
             2    0.3681      0.2922  -0.2901   1.0263
    ---
    Signif. codes: `*' confidence band does not cover 0
    Control group: Never treated
    Estimation Method: Doubly Robust
    Multiplier bootstrap: B=1000, c=2.2528, overall c=2.1445

We can plot the three aggregated objects:

``` python
fig, ax = plot_agg(
    [agg0, agg1, agg_diff],
    labels=[
        "F = 0",
        "F = 1",
        "Group Difference",
    ],
)
```

![](examples_files/figure-commonmark/cell-14-output-1.png)

------------------------------------------------------------------------

## Introspection of estimates

The implementation allows for inspection of specific $LATT(e, t)$
estimates. After calling `idid.estimate`, the cell-level results are
stored in the `latts` dictionary of the returned object. This is useful
when you want to inspect one particular effect and see which treated and
control observations entered that comparison.
``` python
from idid._types import FailedLATT
from idid.estimators import get_controls

res = idid.estimate(
    idid.sim_stag_panel(
        n=10_000,
        T=5,
        E_cohorts=[2, 3, 4, 5],
    ),
    cohort="E",
    time="t",
    outcome="Y_t",
    treatment="D_t",
    unit="id",
    covariates=["X"],
    control="notyet",
    method="dr",
    balanced=True,
    verbose=False,
)

# Look up the cell-level result for cohort e = 2 at time t = 3.
e = 2
t = 3
latt = res.latts[(e, t)]

# Stop early if this cell could not be estimated.
if isinstance(latt, FailedLATT):
    raise ValueError("Failed LATT estimation")

print(latt)
```

    LATT(g=2, t=3, latt=float64(), num=float64(), denom=float64(), ns=7417, ids=DataFrame[7417x2; id, E], IF=ndarray(7417, 1), IF_aet=ndarray(7417, 1), extra=dict[num, den])

Each `LATT` object stores the underlying unit-period ids used for that
comparison. In the `ids` data frame, `E = 1` denotes treated
observations and `E = 0` denotes controls.

``` python
print(latt.ids.head())
print(latt.ids["E"].value_counts())
```

    shape: (5, 2)
    ┌─────┬─────┐
    │ id  ┆ E   │
    │ --- ┆ --- │
    │ i64 ┆ i8  │
    ╞═════╪═════╡
    │ 1   ┆ 1   │
    │ 2   ┆ 0   │
    │ 3   ┆ 0   │
    │ 4   ┆ 0   │
    │ 5   ┆ 0   │
    └─────┴─────┘
    shape: (2, 2)
    ┌─────┬───────┐
    │ E   ┆ count │
    │ --- ┆ ---   │
    │ i8  ┆ u32   │
    ╞═════╪═══════╡
    │ 0   ┆ 4945  │
    │ 1   ┆ 2472  │
    └─────┴───────┘

A small helper returns the control units for a given `LATT` object and
`IDidResult`.

``` python
controls = get_controls(latt, res).sort("id")
print(controls.head())
```

    shape: (5, 2)
    ┌─────┬─────┐
    │ id  ┆ E   │
    │ --- ┆ --- │
    │ i64 ┆ i64 │
    ╞═════╪═════╡
    │ 2   ┆ 5   │
    │ 3   ┆ 4   │
    │ 4   ┆ 5   │
    │ 5   ┆ 4   │
    │ 6   ┆ 5   │
    └─────┴─────┘

``` python
print(f"#Controls = {controls.shape[0]}")
print(controls[res.dp.e_col].value_counts().sort(res.dp.e_col))
```

    #Controls = 4945
    shape: (2, 2)
    ┌─────┬───────┐
    │ E   ┆ count │
    │ --- ┆ ---   │
    │ i64 ┆ u32   │
    ╞═════╪═══════╡
    │ 4   ┆ 2453  │
    │ 5   ┆ 2492  │
    └─────┴───────┘

That is, the not-yet-exposed controls for $\widehat{LATT}(2, 3)$ are
split roughly evenly across the cohorts $E \in \{4, 5\}$, with 2453 and
2492 units respectively.
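As a quick arithmetic check on the printed counts, the minimal sketch
below confirms that treated and control rows add up to the reported
`ns`, and computes each control cohort's share. It uses only numbers
copied from the output on this page and does not call `idid`; the
variable names are chosen here for illustration.

``` python
# Counts copied from the printed output on this page (not recomputed).
n_treated = 2472   # rows with E = 1 in latt.ids
n_controls = 4945  # rows with E = 0 in latt.ids
ns = 7417          # `ns` shown in the printed LATT object
controls_by_cohort = {4: 2453, 5: 2492}

# Treated and control rows should account for every row in the cell.
assert n_treated + n_controls == ns
# The per-cohort control counts should add up to the control total.
assert sum(controls_by_cohort.values()) == n_controls

# Share of the controls contributed by each not-yet-exposed cohort.
shares = {e: n / n_controls for e, n in controls_by_cohort.items()}
for e, share in shares.items():
    print(f"cohort {e}: {share:.3f}")
```

The shares come out at about 0.496 and 0.504, which is what "roughly
evenly" refers to.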