# Inference API
This module provides two main functions for structural estimation with valid inference.
## API Overview
| Function | Use Case | Target | Lambda |
|---|---|---|---|
| `structural_dml()` | Production, 13 families | E[β(X)] fixed | Estimated |
| `inference()` | Flexible targets, regimes | Custom h(θ) | Auto-selected |
## `structural_dml()` - Legacy API
The production-ready API, supporting 13 model families.
### Signature
```python
from deep_inference import structural_dml

result = structural_dml(
    Y,                       # (n,) outcomes
    T,                       # (n,) treatments
    X,                       # (n, d) covariates
    family='linear',         # Family name: 'linear', 'logit', 'poisson', etc.
    target=None,             # Target variant (e.g., 'ame' for logit)
    hidden_dims=[64, 32],    # Network architecture
    epochs=100,              # Training epochs
    n_folds=50,              # Cross-fitting folds
    lr=0.01,                 # Learning rate
    lambda_method='ridge',   # Lambda estimation: ridge (default), aggregate, lgbm
    verbose=False,           # Print progress
)
```
### Supported Families
| Family | Model | θ_dim | Notes |
|---|---|---|---|
| Linear | Y = α + βT + ε | 2 | OLS-equivalent |
| Logit | P(Y=1) = σ(α + βT) | 2 | Binary outcomes |
| Poisson | Y ~ Pois(exp(α + βT)) | 2 | Count data |
| Gamma | Y ~ Gamma(k, exp(α + βT)) | 2 | Positive continuous |
| Gumbel | Y ~ Gumbel(α + βT, σ) | 2 | Extreme values |
| Tobit | Y = max(0, α + βT + σε) | 3 | Censored |
| Negative binomial | Y ~ NegBin(exp(α + βT), r) | 2 | Overdispersed counts |
| Weibull | Y ~ Weibull(k, exp(α + βT)) | 2 | Survival/duration |
| Multinomial logit | P(Y=j) = softmax(α_j + X′_jβ) | (J-1)+K | Discrete choice (J ≥ 3) |
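Each family can be read as a per-observation loss ℓ(y, t, θ) in θ = (α, β), and the score ℓ_θ of that loss is what enters the influence-function correction. The sketch below assumes the families are fit by negative log-likelihood (squared error for linear); that is an assumption, and the helper names `linear_loss`, `logit_loss`, `poisson_loss`, `score`, and `numeric_score` are hypothetical, not part of `deep_inference`. It writes three losses and checks their analytic scores against finite differences.

```python
import numpy as np

# Hypothetical per-observation losses l(y, t, theta), theta = (alpha, beta).
# Assumption: families are fit by negative log-likelihood (squared error for linear).

def linear_loss(y, t, theta):
    eta = theta[0] + theta[1] * t
    return 0.5 * (y - eta) ** 2

def logit_loss(y, t, theta):
    eta = theta[0] + theta[1] * t
    return np.logaddexp(0.0, eta) - y * eta   # log(1 + e^eta) - y*eta, computed stably

def poisson_loss(y, t, theta):
    eta = theta[0] + theta[1] * t
    return np.exp(eta) - y * eta              # drops the constant log(y!)

def score(family, y, t, theta):
    """Analytic gradient w.r.t. theta: (dl/d eta) * (1, t)."""
    eta = theta[0] + theta[1] * t
    if family == "linear":
        d_eta = eta - y
    elif family == "logit":
        d_eta = 1.0 / (1.0 + np.exp(-eta)) - y
    elif family == "poisson":
        d_eta = np.exp(eta) - y
    else:
        raise ValueError(f"unknown family {family!r}")
    return d_eta * np.array([1.0, t])

def numeric_score(loss, y, t, theta, eps=1e-6):
    """Central finite differences, used only to check the analytic score."""
    grad = np.zeros(2)
    for j in range(2):
        step = np.zeros(2)
        step[j] = eps
        grad[j] = (loss(y, t, theta + step) - loss(y, t, theta - step)) / (2 * eps)
    return grad

theta = np.array([0.3, -0.7])
cases = [("linear", linear_loss, 1.2), ("logit", logit_loss, 1.0), ("poisson", poisson_loss, 3.0)]
for name, loss, y in cases:
    gap = np.abs(score(name, y, 0.8, theta) - numeric_score(loss, y, 0.8, theta)).max()
    print(f"{name}: max |analytic - numeric| = {gap:.2e}")
```

Every score here has the form (dℓ/dη)·(1, t), which is why θ_dim is 2 for the single-index families.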
### Example
```python
import numpy as np
from deep_inference import structural_dml

# Generate data
np.random.seed(42)
n = 2000
X = np.random.randn(n, 10)
T = np.random.randn(n)
Y = X[:, 0] + 0.5 * T + np.random.randn(n)

# Run inference
result = structural_dml(
    Y=Y, T=T, X=X,
    family='linear',
    n_folds=50,
    epochs=100,
)
print(result.summary())
```
Output:

```text
==============================================================================
                         Structural DML Results
==============================================================================
Family:                Linear          Target:                E[beta]
No. Observations:      2000            No. Folds:             50
==============================================================================
                 coef    std err          z      P>|z|     [0.025     0.975]
------------------------------------------------------------------------------
E[beta]        0.5012     0.0234     21.410      0.000     0.4553     0.5471
==============================================================================
Diagnostics:
  Min Lambda eigenvalue:   0.998765
  Mean condition number:   1.00
  Correction ratio:        0.0523
------------------------------------------------------------------------------
```
### Viewing Results with `summary()`
Both `DMLResult` and `InferenceResult` provide a `summary()` method that produces a statsmodels-style table with:

- Header: family, target, sample size, number of folds, date/time
- Coefficient table: estimate, standard error, z-statistic, p-value, 95% CI
- Diagnostics: Lambda eigenvalue, condition number, correction ratio
```python
# Get formatted summary string
print(result.summary())

# Individual components are still accessible
print(result.mu_hat)     # Point estimate
print(result.se)         # Standard error
print(result.ci_lower)   # Lower 95% CI
print(result.ci_upper)   # Upper 95% CI
```
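The z-statistic, p-value, and interval in the summary table can be derived from `mu_hat` and `se` alone. The sketch below assumes a normal (Wald) approximation with a two-sided test of zero and the exact 0.975 normal quantile; `wald_summary` is a hypothetical helper, not a library function. Fed the estimate and SE from the example output, it reproduces the printed interval.

```python
import math
from statistics import NormalDist

def wald_summary(mu_hat, se, level=0.95):
    """Hypothetical helper: z, two-sided p-value, and CI from an estimate and its SE."""
    z = mu_hat / se
    p_value = math.erfc(abs(z) / math.sqrt(2.0))   # equals 2 * (1 - Phi(|z|))
    q = NormalDist().inv_cdf(0.5 + level / 2.0)     # about 1.96 for level=0.95
    return z, p_value, (mu_hat - q * se, mu_hat + q * se)

z, p, (lo, hi) = wald_summary(0.5012, 0.0234)
print(f"z={z:.3f}  p={p:.3g}  CI=[{lo:.4f}, {hi:.4f}]")
```

Small differences in z relative to the table can arise because the displayed SE is rounded to four decimals.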
### `DMLResult` Object
| Attribute | Type | Description |
|---|---|---|
| `mu_hat` | float | Debiased estimate of E[β(X)] |
|  | float | Naive (biased) estimate |
| `se` | float | Standard error |
| `ci_lower` | float | Lower 95% CI bound |
| `ci_upper` | float | Upper 95% CI bound |
|  | ndarray | Estimated θ(x) for all observations |
|  | ndarray | Influence function values |
|  | dict | Training diagnostics |
## `inference()` - New Flexible API
The new API with flexible targets and automatic regime detection.
### Signature
```python
from deep_inference import inference

result = inference(
    Y,                       # (n,) outcomes
    T,                       # (n,) treatments
    X,                       # (n, d) covariates

    # Model specification (choose one):
    model='logit',           # Built-in: 'linear', 'logit', 'multinomial_logit'
    loss=None,               # OR custom loss function
    theta_dim=None,          # Required if custom loss

    # Target specification (choose one):
    target='beta',           # Built-in: 'beta', 'ame'
    target_fn=None,          # OR custom target function
    t_tilde=None,            # Evaluation point (default: mean(T))

    # Regime settings:
    is_randomized=False,     # True for RCTs
    treatment_dist=None,     # Known F_T (e.g., Normal(0, 1))
    lambda_method=None,      # Override auto-detection

    # Cross-fitting:
    n_folds=50,

    # Network:
    hidden_dims=[64, 32],
    epochs=100,
    lr=0.01,
    ridge=1e-4,
    verbose=False,
)
```
### Built-in Targets
| Target | Formula | Use Case |
|---|---|---|
| `'beta'` | E[β(X)] | Average treatment effect (log-odds for logit) |
| `'ame'` | E[p(1-p)β] | Average marginal effect (probability scale) |
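Given fitted per-observation parameters θ̂(x) = (α̂(x), β̂(x)), both built-in targets reduce to simple per-observation quantities that are then averaged. This sketch rests on an assumption the table does not pin down: for `'ame'`, p is taken as σ(α(x) + β(x)·t̃) at the evaluation point t̃ (if p is instead evaluated at each observed T, substitute T for t̃). The function `target_values` is hypothetical, and the averages it yields are plug-in values, before the influence-function correction.

```python
import numpy as np

def target_values(alpha, beta, target="beta", t_tilde=0.0):
    """Hypothetical per-observation h(x, theta(x)) for the built-in logit targets."""
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    if target == "beta":
        return beta                                   # log-odds effect beta(x)
    if target == "ame":
        p = 1.0 / (1.0 + np.exp(-(alpha + beta * t_tilde)))
        return p * (1.0 - p) * beta                   # dP(Y=1)/dt at t_tilde
    raise ValueError(f"unknown target {target!r}")

alpha = np.array([0.0, 0.0, 1.0])
beta = np.array([0.4, -0.8, 0.4])
print(target_values(alpha, beta, "beta").mean())      # plug-in E[beta(X)]
print(target_values(alpha, beta, "ame").mean())       # plug-in E[p(1-p)beta]
```

The factor p(1-p) is the derivative of the sigmoid, which is what converts a log-odds effect into a probability-scale marginal effect.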
### Custom Target Functions
Define any target h(x, θ, t̃); its Jacobian with respect to θ is computed by autodiff:
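```python
import torch
from deep_inference import inference

def my_target(x, theta, t_tilde):
    """Predicted P(Y=1) at treatment level t_tilde; the estimate averages it over X."""
    alpha, beta = theta[0], theta[1]
    return torch.sigmoid(alpha + beta * t_tilde)

# Y, T, X: arrays shaped as in the signature above (binary Y for model='logit')
result = inference(
    Y, T, X,
    model='logit',
    target_fn=my_target,
    t_tilde=0.0,
)
```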
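The correction term needs H_θ, the Jacobian of the target with respect to θ. For `my_target` it has a closed form, H_θ = p(1-p)·(1, t̃) with p = σ(α + βt̃), which is what autodiff should return. The NumPy sketch below does not call the library; the helpers `my_target_np`, `my_target_jacobian`, and `finite_diff_jacobian` are hypothetical and exist only to check that closed form against central finite differences.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def my_target_np(theta, t_tilde):
    """NumPy version of my_target: P(Y=1) at treatment level t_tilde."""
    return sigmoid(theta[0] + theta[1] * t_tilde)

def my_target_jacobian(theta, t_tilde):
    """Closed-form dh/dtheta = p(1 - p) * (1, t_tilde)."""
    p = my_target_np(theta, t_tilde)
    return p * (1.0 - p) * np.array([1.0, t_tilde])

def finite_diff_jacobian(h, theta, t_tilde, eps=1e-6):
    """Central differences of h(theta, t_tilde) in each coordinate of theta."""
    jac = np.zeros_like(theta)
    for j in range(theta.size):
        step = np.zeros_like(theta)
        step[j] = eps
        jac[j] = (h(theta + step, t_tilde) - h(theta - step, t_tilde)) / (2 * eps)
    return jac

theta = np.array([0.2, -1.1])
print(my_target_jacobian(theta, 0.5))
print(finite_diff_jacobian(my_target_np, theta, 0.5))
```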
### Three Regimes
| Regime | Condition | Lambda Method | Cross-Fitting |
|---|---|---|---|
| A | RCT + known F_T | Compute (MC integration) | 2-way |
| B | Linear model | Analytic (closed-form) | 2-way |
| C | Observational + nonlinear | Estimate (neural net) | 3-way |
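The table implies a simple decision rule. The sketch below encodes it as a hypothetical helper, `select_regime`; the precedence (A checked before B) and the absence of any `lambda_method` override are assumptions, not the library's documented logic.

```python
def select_regime(model, is_randomized=False, treatment_dist=None):
    """Hypothetical rule read off the regime table; precedence A > B > C is assumed.

    Returns (regime, lambda method, number of cross-fitting splits).
    """
    if is_randomized and treatment_dist is not None:
        return "A", "compute (Monte Carlo over known F_T)", 2
    if model == "linear":
        return "B", "analytic (closed form)", 2
    return "C", "estimate (neural net)", 3

print(select_regime("logit", is_randomized=True, treatment_dist="Normal(0, 1)"))
print(select_regime("linear"))
print(select_regime("logit"))
```

Note that an RCT without a specified `treatment_dist` falls through to B or C under this rule, since F_T must be known for the Monte Carlo computation.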
```python
from deep_inference import inference
from deep_inference.lambda_.compute import Normal

# Regime A: randomized experiment with a known treatment distribution
result = inference(
    Y, T, X,
    model='logit',
    target='beta',
    is_randomized=True,
    treatment_dist=Normal(mean=0.0, std=1.0),
)
print(f"Regime: {result.diagnostics['regime']}")  # 'A'
```
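In Regime A, Λ(x) = E[ℓ_θθ | X = x] can be computed rather than estimated. In an RCT, T is independent of X, so for the logit loss Λ(x) = E_T[σ(1-σ)·(1, T)(1, T)ᵀ], with σ evaluated at α(x) + β(x)T and T drawn from the known F_T. The Monte Carlo sketch below assumes F_T = Normal(0, 1) and the negative log-likelihood logit loss; `lambda_mc_logit` is a hypothetical name, not the library's routine.

```python
import numpy as np

def lambda_mc_logit(alpha, beta, n_draws=200_000, seed=0):
    """Monte Carlo estimate of E_T[sigma(1 - sigma) (1, T)(1, T)^T] with T ~ Normal(0, 1).

    alpha, beta: fitted theta(x) for a single observation x (scalars).
    """
    rng = np.random.default_rng(seed)
    t = rng.standard_normal(n_draws)
    p = 1.0 / (1.0 + np.exp(-(alpha + beta * t)))
    w = p * (1.0 - p)                                  # logit Hessian weight
    design = np.stack([np.ones_like(t), t], axis=1)    # rows (1, T)
    return (design * w[:, None]).T @ design / n_draws  # weighted average of outer products

print(lambda_mc_logit(0.0, 0.0))   # close to 0.25 * identity
print(lambda_mc_logit(0.5, 1.0))
```

Because no outcome data enter this computation, Λ(x) depends on x only through θ̂(x), which is why Regime A needs no separate fold for Λ.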
### `InferenceResult` Object
| Attribute | Type | Description |
|---|---|---|
|  | float | Point estimate |
|  | float | Standard error |
|  | float | Lower 95% CI |
|  | float | Upper 95% CI |
|  | Tensor | Influence function values |
|  | Tensor | Estimated θ(x) |
| `diagnostics` | dict | Regime, lambda method, etc. |
## Configuration Guidelines
### Network Architecture
| Sample Size | Recommended |
|---|---|
| n < 1,000 |  |
| 1,000 - 10,000 |  |
| 10,000 - 100,000 |  |
| n > 100,000 |  |
### Cross-Fitting Folds
| Use Case | K (`n_folds`) |
|---|---|
| Quick exploration | 10-20 |
| Production | 50 |
| Very large data | 20-50 |
### Lambda Method (for `structural_dml`)
| Method | Coverage | Default | Notes |
|---|---|---|---|
| `'ridge'` | 96% | Yes | Safe default, validated coverage |
| `'aggregate'` | 95% | No | Ignores X-dependence |
| `'lgbm'` | 96% | No | High-accuracy alternative |
|  | 67% | No | AVOID: invalid SEs |
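The summary's "Min Lambda eigenvalue" and "Mean condition number" diagnostics flag a near-singular Λ̂(x), which would make Λ̂(x)⁻¹ in the correction unstable. The sketch below shows one way such diagnostics could be computed over a stack of estimated matrices, assuming "condition number" means the ratio of largest to smallest eigenvalue; `lambda_diagnostics` is hypothetical, and the library's exact definitions may differ.

```python
import numpy as np

def lambda_diagnostics(lams):
    """lams: array of shape (n, d, d), one symmetric Lambda(x_i) per observation."""
    lams = 0.5 * (lams + np.transpose(lams, (0, 2, 1)))   # enforce symmetry numerically
    eig = np.linalg.eigvalsh(lams)                         # ascending eigenvalues, (n, d)
    cond = eig[:, -1] / eig[:, 0]                          # largest / smallest per observation
    return {"min_eigenvalue": float(eig[:, 0].min()),
            "mean_condition_number": float(cond.mean())}

# Illustration: linear-model Hessians averaged over T draws, E[(1, T)(1, T)^T] ~ identity
rng = np.random.default_rng(1)
n = 500
t_draws = rng.standard_normal((n, 400))
m1, m2 = t_draws.mean(axis=1), (t_draws ** 2).mean(axis=1)
row0 = np.stack([np.ones(n), m1], axis=1)
row1 = np.stack([m1, m2], axis=1)
lams = np.stack([row0, row1], axis=1)                      # shape (n, 2, 2)
print(lambda_diagnostics(lams))
```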
## Algorithm Overview
### Cross-Fitting (K-Fold)
```text
For k = 1 to K:
    Train: fit θ̂(x) on folds ≠ k
    [If 3-way: fit Λ̂(x) on a separate fold]
    Eval:  compute ψ on fold k
Aggregate: μ̂ = mean(ψ), SE = std(ψ)/√n
```
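The loop above can be sketched end to end for the linear family. This is a simplified stand-in, not the library's implementation: θ(x) = (α(x), β(x)) is fit by least squares with coefficients linear in X instead of a neural network, and Λ is estimated on the training folds as the sample mean of the Hessian (1, T)(1, T)ᵀ. That Λ ignores X-dependence, which is reasonable here only because T is independent of X in the simulated data. The target is E[β(X)], so H = β and H_θ = (0, 1); the function names are hypothetical.

```python
import numpy as np

def fit_varying_coef(Y, T, X):
    """Least-squares fit of Y ~ alpha(X) + beta(X) * T with alpha, beta linear in X."""
    Z = np.column_stack([np.ones(len(Y)), X])          # features for alpha(x) and beta(x)
    design = np.column_stack([Z, Z * T[:, None]])
    coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
    k = Z.shape[1]
    return coef[:k], coef[k:]                          # (alpha coefs, beta coefs)

def cross_fit_linear(Y, T, X, n_folds=10, seed=0):
    n = len(Y)
    folds = np.random.default_rng(seed).permutation(n) % n_folds   # balanced fold labels
    psi = np.empty(n)
    H_theta = np.array([0.0, 1.0])                     # Jacobian of h(theta) = beta
    for k in range(n_folds):
        train, test = folds != k, folds == k
        a_coef, b_coef = fit_varying_coef(Y[train], T[train], X[train])
        # Lambda: mean of (1, T)(1, T)^T over the training folds (X-independent here)
        D = np.column_stack([np.ones(train.sum()), T[train]])
        Lam = D.T @ D / train.sum()
        Zt = np.column_stack([np.ones(test.sum()), X[test]])
        alpha, beta = Zt @ a_coef, Zt @ b_coef
        resid = Y[test] - alpha - beta * T[test]
        # Score of 0.5 * resid^2 w.r.t. (alpha, beta): -resid * (1, T)
        score = -resid[:, None] * np.column_stack([np.ones(test.sum()), T[test]])
        # psi = h - H_theta Lambda^{-1} l_theta (Lambda is symmetric)
        psi[test] = beta - score @ np.linalg.solve(Lam, H_theta)
    return psi.mean(), psi.std(ddof=1) / np.sqrt(n)

rng = np.random.default_rng(42)
n = 2000
X = rng.standard_normal((n, 10))
T = rng.standard_normal(n)
Y = X[:, 0] + 0.5 * T + rng.standard_normal(n)
mu_hat, se = cross_fit_linear(Y, T, X)
print(f"mu_hat={mu_hat:.4f}  se={se:.4f}")
```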
### Influence Function
The influence function corrects for regularization bias:
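ψ(z) = H(θ̂) - H_θ · Λ(x)⁻¹ · ℓ_θ(z, θ̂)

where:

- H(θ) = target evaluated at θ = θ̂(x) for one observation (e.g., β̂(x); averaging ψ gives E[β])
- H_θ = Jacobian of the target with respect to θ
- Λ(x) = E[ℓ_θθ | X=x] = conditional Hessian of the loss
- ℓ_θ = score (gradient of the loss with respect to θ)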
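For the logit family with target E[β(X)], every term in ψ has a closed form: ℓ_θ = (σ - y)(1, t), ℓ_θθ = σ(1 - σ)(1, t)(1, t)ᵀ, H = β, and H_θ = (0, 1), with σ = σ(α(x) + β(x)t). The vectorized sketch below assumes the negative log-likelihood logit loss; `logit_influence_beta` is a hypothetical helper, and Λ(x) is passed in (for example from a Monte Carlo or fitted estimate) rather than computed.

```python
import numpy as np

def logit_influence_beta(y, t, alpha, beta, lam):
    """psi_i = beta_i - H_theta Lambda_i^{-1} l_theta_i for the target E[beta(X)].

    y, t, alpha, beta: arrays of shape (n,); lam: array of shape (n, 2, 2).
    """
    sigma = 1.0 / (1.0 + np.exp(-(alpha + beta * t)))
    score = (sigma - y)[:, None] * np.column_stack([np.ones_like(t), t])  # l_theta, (n, 2)
    h_theta = np.array([0.0, 1.0])                                         # d beta / d theta
    # Solve Lambda_i v_i = l_theta_i for every i, then project onto H_theta
    v = np.linalg.solve(lam, score[:, :, None])[:, :, 0]
    return beta - v @ h_theta

y = np.array([1.0, 0.0, 1.0])
t = np.array([0.5, -1.0, 2.0])
alpha = np.array([0.1, 0.1, -0.2])
beta = np.array([0.8, 0.8, 0.3])
lam = np.repeat(0.25 * np.eye(2)[None], 3, axis=0)     # illustrative Lambda(x_i) values
psi = logit_influence_beta(y, t, alpha, beta, lam)
print(psi.mean(), psi.std(ddof=1) / np.sqrt(len(psi)))  # mu_hat and SE, as in the aggregate step
```

When the score is zero (y equals the fitted probability), ψ reduces to the plug-in β̂(x); the correction only moves the estimate where the fitted model disagrees with the data.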
### Expected Performance
| Method | Coverage | SE Ratio |
|---|---|---|
| Naive | ~10-30% | << 1 |
| Influence | ~95% | ~1.0 |