Examples
This page collects a few slightly less standard use cases that build on the main workflow from the quickstart and aggregation pages.
Custom Group Aggregation
You can aggregate cohorts into custom groups by passing a custom_map,
which maps each cohort to a group label, through agg_kwargs to
agg_latt(..., method="cohort_custom").
import idid

df = idid.sim_stag_panel(
    n=10_000,
    T=5,
    E_cohorts=[0, 2, 3, 4, 5],
)

res = idid.estimate(
    df,
    cohort="E",
    time="t",
    outcome="Y_t",
    treatment="D_t",
    unit="id",
    covariates=["X"],
    control="never",
    method="dr",
    balanced=True,
    verbose=False,
)

# Map each cohort to the label of the custom group it belongs to.
custom_map = {
    2: "2-3",
    3: "2-3",
    4: "4-5",
    5: "4-5",
}

agg = idid.agg_latt(
    res,
    method="cohort_custom",
    agg_kwargs={"custom_map": custom_map},
)
agg.summary()
Overall summary of ATT's based on custom cohort aggregation:
LATT Std. Error [95% Conf. Band]
1.1750 0.1013 0.9765 1.3734 *
Custom cohort effects:
Cohort Estimate Std. Error [95% Pointwise Conf. Band]
2-3 0.9925 0.1214 0.7545 1.2304 *
4-5 1.3506 0.1584 1.0402 1.6610 *
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
Compare that to:
agg = idid.agg_latt(res, method="cohort")
agg.summary()
Overall summary of ATT's based on cohort aggregation:
LATT Std. Error [95% Conf. Band]
1.1522 0.0992 0.9577 1.3466 *
Cohort effects:
Cohort Estimate Std. Error [95% Pointwise Conf. Band]
2 0.9999 0.1625 0.6813 1.3185 *
3 0.9818 0.1793 0.6304 1.3331 *
4 1.5078 0.2244 1.0680 1.9477 *
5 1.0982 0.2087 0.6891 1.5072 *
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
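With many cohorts, writing custom_map by hand gets tedious. The following is a small sketch of a helper that builds the same mapping as in the custom aggregation example from a list of cohort groups. The name build_custom_map is hypothetical and not part of idid, and the "first-last" label format is only an assumption about what you want to see in the summary.

```python
def build_custom_map(groups):
    """Map every cohort to a label built from its group's smallest and largest cohort."""
    custom_map = {}
    for cohorts in groups:
        label = f"{min(cohorts)}-{max(cohorts)}"
        for cohort in cohorts:
            # A cohort can only belong to one custom group.
            if cohort in custom_map:
                raise ValueError(f"cohort {cohort} appears in more than one group")
            custom_map[cohort] = label
    return custom_map


custom_map = build_custom_map([[2, 3], [4, 5]])
print(custom_map)
```

The resulting dictionary can be passed as agg_kwargs={"custom_map": custom_map} exactly like the hand-written one.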
Group Difference Aggregation
To estimate differences between two groups, estimate the cohort-time
LATTs within each group and then call group_diff_idid. The true
difference between the groups equals \(1/2\) for all \((e, t)\).
import polars as pl

from idid.aggregate import group_diff_idid
from idid.plotting import plot_agg, plot_evolution_groups, summarize_evolution

df = idid.sim_stag_panel(
    n=20_000,
    T=4,
    E_cohorts=[0, 2, 3, 4],
    confounded=True,
    with_group=True,
)
print(df.head())
shape: (5, 7)
┌─────┬─────┬─────┬───────────┬─────┬──────────┬─────┐
│ id ┆ E ┆ t ┆ X ┆ D_t ┆ Y_t ┆ F │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ f64 ┆ i64 ┆ f64 ┆ i64 │
╞═════╪═════╪═════╪═══════════╪═════╪══════════╪═════╡
│ 0 ┆ 4 ┆ 1 ┆ 1.285605 ┆ 1 ┆ 7.275993 ┆ 1 │
│ 0 ┆ 4 ┆ 2 ┆ 1.285605 ┆ 0 ┆ 2.930226 ┆ 1 │
│ 0 ┆ 4 ┆ 3 ┆ 1.285605 ┆ 0 ┆ 3.64175 ┆ 1 │
│ 0 ┆ 4 ┆ 4 ┆ 1.285605 ┆ 1 ┆ 7.260314 ┆ 1 │
│ 1 ┆ 4 ┆ 1 ┆ -0.303553 ┆ 1 ┆ 3.297906 ┆ 0 │
└─────┴─────┴─────┴───────────┴─────┴──────────┴─────┘
gp0 = summarize_evolution(df.filter(pl.col("F").eq(0)))
gp1 = summarize_evolution(df.filter(pl.col("F").eq(1)))
fig, ax = plot_evolution_groups(
    gp0,
    gp1,
    include_bands=True,
)

ests = {}
for f in [0, 1]:
    res = idid.estimate(
        # Keep one group and prefix its unit ids with a group-specific letter.
        df.filter(pl.col("F").eq(f)).with_columns(
            id=pl.lit("M" if f == 0 else "F") + pl.col("id").cast(pl.Utf8),
        ),
        cohort="E",
        time="t",
        outcome="Y_t",
        treatment="D_t",
        unit="id",
        covariates=["X"],
        control="never",
        method="dr",
        balanced=True,
        verbose=False,
    )
    ests[f] = res

# Cell-level difference between the F = 1 and the F = 0 estimates.
diff = group_diff_idid(ests[1], ests[0])

print(diff.estimates)
diff.summary()
shape: (6, 7)
┌─────┬─────┬───────────┬──────────┬──────────┬───────────┬───────────┐
│ E ┆ t ┆ latt ┆ denom ┆ se ┆ lower ┆ upper │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════╪═════╪═══════════╪══════════╪══════════╪═══════════╪═══════════╡
│ 2 ┆ 2 ┆ -0.713377 ┆ 0.188419 ┆ 0.486029 ┆ -1.665993 ┆ 0.239239 │
│ 2 ┆ 3 ┆ -0.721623 ┆ 0.19596 ┆ 0.452582 ┆ -1.608683 ┆ 0.165437 │
│ 2 ┆ 4 ┆ -0.831864 ┆ 0.211267 ┆ 0.415799 ┆ -1.64683 ┆ -0.016897 │
│ 3 ┆ 3 ┆ -0.747214 ┆ 0.214008 ┆ 0.41707 ┆ -1.564672 ┆ 0.070243 │
│ 3 ┆ 4 ┆ -0.991648 ┆ 0.225592 ┆ 0.391804 ┆ -1.759584 ┆ -0.223713 │
│ 4 ┆ 4 ┆ -0.591208 ┆ 0.211862 ┆ 0.417223 ┆ -1.408965 ┆ 0.226549 │
└─────┴─────┴───────────┴──────────┴──────────┴───────────┴───────────┘
Cohort-Time Local Average Treatment Effects on the Treated:
E t AET(e, t) LATT(e, t) Std. Error [95% Pointwise. Conf. Band]
2 2 0.1884 -0.7134 0.4860 -1.6660 0.2392
2 3 0.1960 -0.7216 0.4526 -1.6087 0.1654
2 4 0.2113 -0.8319 0.4158 -1.6468 -0.0169 *
3 3 0.2140 -0.7472 0.4171 -1.5647 0.0702
3 4 0.2256 -0.9916 0.3918 -1.7596 -0.2237 *
4 4 0.2119 -0.5912 0.4172 -1.4090 0.2265
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
Aggregated dynamic effects:
agg = idid.agg_latt(diff, method="dynamic")
agg.summary()
Overall summary of ATT's based on event-study/dynamic aggregation:
LATT Std. Error [95% Conf. Band]
-0.7932 0.2551 -1.2933 -0.2932 *
Dynamic effects:
Event time Estimate Std. Error [95% Pointwise Conf. Band]
0 -0.6834 0.2080 -1.0910 -0.2758 *
1 -0.8644 0.2980 -1.4485 -0.2802 *
2 -0.8319 0.4158 -1.6468 -0.0169 *
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
fig, ax = plot_agg([agg])

agg_diff = idid.agg_latt(diff, method="dynamic", boot=True)
agg_diff.summary()
Overall summary of ATT's based on event-study/dynamic aggregation:
LATT Std. Error [95% Simult. Conf. Band]
-0.7932 0.2619 -1.3063 -0.2802 *
Dynamic effects:
Event time Estimate Std. Error [95% Simult. Conf. Band]
0 -0.6834 0.2119 -1.1694 -0.1975 *
1 -0.8644 0.3044 -1.5624 -0.1663 *
2 -0.8319 0.4086 -1.7690 0.1053
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
Multiplier bootstrap: B=1000, c=2.2934, overall c=1.9588
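The bands in the bootstrap summary for the group difference appear to be of the form estimate ± c × Std. Error, with c taken from the "Multiplier bootstrap" line. The sketch below checks this against the printed numbers only; conf_band is a hypothetical helper, and the formula is inferred from the output rather than taken from the idid source.

```python
def conf_band(estimate, std_error, crit):
    """Symmetric band: estimate plus/minus crit times the standard error."""
    half_width = crit * std_error
    return estimate - half_width, estimate + half_width


# Event time 0, using the simultaneous critical value c=2.2934.
sim_lower, sim_upper = conf_band(-0.6834, 0.2119, 2.2934)

# Overall effect, using the separate "overall c" of 1.9588.
all_lower, all_upper = conf_band(-0.7932, 0.2619, 1.9588)

print(f"event time 0: [{sim_lower:.4f}, {sim_upper:.4f}]")
print(f"overall:      [{all_lower:.4f}, {all_upper:.4f}]")
```

Up to rounding of the printed inputs, both intervals agree with the summary. Because c=2.2934 is larger than the usual pointwise value of about 1.96, the simultaneous band is wider, which is why event time 2 no longer excludes zero once boot=True is used.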
Compare that to:
agg0 = idid.agg_latt(ests[0], method="dynamic", boot=True)
agg0.summary()
Overall summary of ATT's based on event-study/dynamic aggregation:
LATT Std. Error [95% Simult. Conf. Band]
1.1506 0.1841 0.7958 1.5054 *
Dynamic effects:
Event time Estimate Std. Error [95% Simult. Conf. Band]
0 1.0224 0.1550 0.6773 1.3675 *
1 1.2295 0.2259 0.7265 1.7325 *
2 1.2000 0.2980 0.5366 1.8633 *
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
Multiplier bootstrap: B=1000, c=2.2264, overall c=1.9274
agg1 = idid.agg_latt(ests[1], method="dynamic", boot=True)
agg1.summary()
Overall summary of ATT's based on event-study/dynamic aggregation:
LATT Std. Error [95% Simult. Conf. Band]
0.3494 0.1739 -0.0236 0.7224
Dynamic effects:
Event time Estimate Std. Error [95% Simult. Conf. Band]
0 0.3214 0.1408 0.0042 0.6387 *
1 0.3587 0.2134 -0.1220 0.8394
2 0.3681 0.2922 -0.2901 1.0263
---
Signif. codes: `*' confidence band does not cover 0
Control group: Never treated
Estimation Method: Doubly Robust
Multiplier bootstrap: B=1000, c=2.2528, overall c=2.1445
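As a rough sanity check on the group difference, the two per-group summaries can be combined by hand. The sketch below assumes that the two groups contain disjoint units, so that their estimates are independent and the variances add; this is an illustration and not necessarily how group_diff_idid computes its standard errors. Event time 2 is used because only the cell (E, t) = (2, 4) contributes to it, so aggregation weights play no role.

```python
import math


def independent_diff(est_a, se_a, est_b, se_b):
    """Difference of two independent estimates and its standard error."""
    diff = est_a - est_b
    se = math.sqrt(se_a**2 + se_b**2)
    return diff, se


# Event time 2: F = 1 (0.3681, SE 0.2922) minus F = 0 (1.2000, SE 0.2980).
diff_est, diff_se = independent_diff(0.3681, 0.2922, 1.2000, 0.2980)
print(f"difference = {diff_est:.4f}, se = {diff_se:.4f}")
```

The difference reproduces the -0.8319 reported for event time 2, and the standard error of roughly 0.417 is close to the 0.4158 (analytical) and 0.4086 (bootstrap) values in the summaries of the difference.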
We can plot the three aggregated objects:
fig, ax = plot_agg(
    [agg0, agg1, agg_diff],
    labels=[
        "F = 0",
        "F = 1",
        "Group Difference",
    ],
)

Introspection of estimates
The implementation lets you inspect individual \(LATT(e, t)\)
estimates. After calling idid.estimate, the cell-level results are
stored in the latts dictionary of the returned object, keyed by
(e, t). This is useful when you want to inspect one particular effect
and see which treated and control observations entered that comparison.
from idid._types import FailedLATT
from idid.estimators import get_controls

res = idid.estimate(
    idid.sim_stag_panel(
        n=10_000,
        T=5,
        E_cohorts=[2, 3, 4, 5],
    ),
    cohort="E",
    time="t",
    outcome="Y_t",
    treatment="D_t",
    unit="id",
    covariates=["X"],
    control="notyet",
    method="dr",
    balanced=True,
    verbose=False,
)

# Look up the cell for cohort e = 2 at time t = 3.
e = 2
t = 3
latt = res.latts[(e, t)]
if isinstance(latt, FailedLATT):
    raise ValueError("Failed LATT estimation")

print(latt)
LATT(g=2, t=3, latt=float64(), num=float64(), denom=float64(), ns=7417, ids=DataFrame[7417x2; id, E], IF=ndarray(7417, 1), IF_aet=ndarray(7417, 1), extra=dict[num, den])
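The same isinstance check can be applied to every cell at once to see which (e, t) combinations failed. The sketch below uses stand-in classes so that it runs on its own: FakeLATT, FakeFailedLATT, and split_cells are all hypothetical names, and the only assumption carried over from the example above is that latts is a dictionary keyed by (e, t) whose values are either results or failure markers.

```python
from dataclasses import dataclass


@dataclass
class FakeLATT:
    """Stand-in for a successfully estimated cell."""
    latt: float


@dataclass
class FakeFailedLATT:
    """Stand-in for a cell whose estimation failed."""
    reason: str


def split_cells(latts, failed_type):
    """Split a {(e, t): cell} dictionary into successful and failed cells."""
    ok, failed = {}, {}
    for key, cell in latts.items():
        target = failed if isinstance(cell, failed_type) else ok
        target[key] = cell
    return ok, failed


latts = {
    (2, 2): FakeLATT(1.0),
    (2, 3): FakeFailedLATT("too few controls"),
    (3, 3): FakeLATT(0.9),
}
ok, failed = split_cells(latts, FakeFailedLATT)
print("estimated:", sorted(ok))
print("failed:   ", sorted(failed))
```

With a real result object, the corresponding call would presumably be split_cells(res.latts, FailedLATT).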
Each LATT object stores the underlying unit-period ids used for that
comparison. In the ids data frame, E = 1 denotes treated
observations and E = 0 denotes controls.
print(latt.ids.head())
print(latt.ids["E"].value_counts())
shape: (5, 2)
┌─────┬─────┐
│ id ┆ E │
│ --- ┆ --- │
│ i64 ┆ i8 │
╞═════╪═════╡
│ 1 ┆ 1 │
│ 2 ┆ 0 │
│ 3 ┆ 0 │
│ 4 ┆ 0 │
│ 5 ┆ 0 │
└─────┴─────┘
shape: (2, 2)
┌─────┬───────┐
│ E ┆ count │
│ --- ┆ --- │
│ i8 ┆ u32 │
╞═════╪═══════╡
│ 0 ┆ 4945 │
│ 1 ┆ 2472 │
└─────┴───────┘
The small helper get_controls returns the control units, together with
their original cohort, for a given LATT object and IDidResult.
controls = get_controls(latt, res).sort("id")
print(controls.head())
shape: (5, 2)
┌─────┬─────┐
│ id ┆ E │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2 ┆ 5 │
│ 3 ┆ 4 │
│ 4 ┆ 5 │
│ 5 ┆ 4 │
│ 6 ┆ 5 │
└─────┴─────┘
print(f"#Controls = {controls.shape[0]}")
print(controls[res.dp.e_col].value_counts().sort(res.dp.e_col))
#Controls = 4945
shape: (2, 2)
┌─────┬───────┐
│ E ┆ count │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞═════╪═══════╡
│ 4 ┆ 2453 │
│ 5 ┆ 2492 │
└─────┴───────┘
That is, the 4945 not-yet-exposed controls for \(\widehat{LATT}(2, 3)\) are split roughly evenly across the cohorts \(E \in \{4, 5\}\).
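The composition of the control group can be reproduced with a few lines of plain Python. The sketch below assumes that control="notyet" selects cohorts first exposed strictly after period t; this matches the output above, but it is an assumption about the definition rather than a statement about the idid internals. Both helper names are hypothetical.

```python
def not_yet_treated_cohorts(cohorts, t):
    """Cohorts whose first exposure lies strictly after period t."""
    return [e for e in cohorts if e > t]


def shares(counts):
    """Turn a {cohort: count} dictionary into {cohort: share of total}."""
    total = sum(counts.values())
    return {cohort: n / total for cohort, n in counts.items()}


eligible = not_yet_treated_cohorts([2, 3, 4, 5], t=3)
control_shares = shares({4: 2453, 5: 2492})

print(eligible)
for cohort, share in control_shares.items():
    print(f"E = {cohort}: {share:.1%}")
```

The two shares come out at about 49.6% and 50.4%.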