# API Reference
Complete API documentation for deep-inference.
## Quick Reference

### Two APIs

| API | Use case |
|---|---|
| `structural_dml()` | Production-ready, 13 families, fixed target E[β] |
| `inference()` | Flexible targets, regime detection, RCT support |

### Main Entry Points

```python
from deep_inference import structural_dml, inference
from deep_inference.lambda_.compute import Normal

# Y: outcome (n,), T: treatment (n,), X: covariates (n, d)

# Legacy API (production-ready)
result = structural_dml(
    Y=Y, T=T, X=X,
    family='linear',
    hidden_dims=[64, 32],
    epochs=100,
    n_folds=50,
)

# New API (flexible)
result = inference(
    Y=Y, T=T, X=X,
    model='logit',
    target='ame',                 # flexible target: average marginal effect
    is_randomized=True,           # Regime A: randomized treatment
    treatment_dist=Normal(0, 1),  # known treatment distribution
)
```
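
The snippets above assume that `Y`, `T`, and `X` already exist. The sketch below builds a synthetic randomized logit design to try them on. It assumes the functions accept NumPy arrays with the shapes listed under structural_dml: (n,) for `Y` and `T`, and (n, d) for `X`. The helper `make_randomized_logit_data` and its coefficients are hypothetical. The true E[β] in this design is 1.0.

```python
import numpy as np


def make_randomized_logit_data(n=2000, d=5, seed=0):
    """Hypothetical helper: simulate (Y, T, X) for a randomized logit design.

    T ~ N(0, 1) is drawn independently of X (matching is_randomized=True and
    treatment_dist=Normal(0, 1)), and P(Y = 1 | T, X) = sigmoid(alpha(X) + beta(X) * T).
    """
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    T = rng.normal(size=n)                  # randomized: independent of X
    alpha = 0.5 * X[:, 0]                   # illustrative heterogeneous intercept
    beta = 1.0 + 0.5 * X[:, 1]              # illustrative heterogeneous slope, E[beta] = 1
    p = 1.0 / (1.0 + np.exp(-(alpha + beta * T)))
    Y = rng.binomial(1, p).astype(float)    # binary outcome
    return Y, T, X, beta


Y, T, X, beta_true = make_randomized_logit_data()
print(Y.shape, T.shape, X.shape)  # (2000,) (2000,) (2000, 5)
```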
### Available Families

```python
from deep_inference import FAMILY_REGISTRY

print(list(FAMILY_REGISTRY.keys()))
# ['linear', 'logit', 'poisson', 'tobit', 'negbin', 'gamma', 'gumbel', 'weibull',
#  'gaussian', 'probit', 'beta', 'zip', 'multinomial_logit']
```
### Family Classes

```python
from deep_inference import (
    LinearFamily, LogitFamily, PoissonFamily, TobitFamily,
    NegBinFamily, GammaFamily, GumbelFamily, WeibullFamily,
    MultinomialLogitFamily,
)
from deep_inference.families import (
    GaussianFamily, ProbitFamily, BetaFamily, ZIPFamily,
)
```
## Module Overview

### structural_dml
The main entry point of the legacy API. Trains a structural neural network and performs influence-function-based inference on E[β(X)].

```python
from deep_inference import structural_dml

result = structural_dml(
    Y,                      # Outcome variable (n,)
    T,                      # Treatment variable (n,)
    X,                      # Covariates (n, d)
    family='linear',        # Statistical family
    hidden_dims=[64, 32],   # Network architecture
    epochs=100,             # Training epochs
    n_folds=50,             # Cross-fitting folds
    lr=0.01,                # Learning rate
    batch_size=64,          # Mini-batch size
    weight_decay=1e-4,      # L2 regularization
    verbose=False,          # Print progress
)
```
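
The `n_folds` argument controls cross-fitting: each observation's parameter estimate θ̂(X_i) should come from a model that never saw observation i. The sketch below shows that bookkeeping for the linear family. The neural network is replaced by a hypothetical least-squares stand-in that assumes α(x) and β(x) are linear in x. `cross_fit_theta` is not part of deep-inference.

```python
import numpy as np


def cross_fit_theta(Y, T, X, n_folds=5, seed=0):
    """Out-of-fold estimates of theta(X_i) = (alpha(X_i), beta(X_i)), shape (n, 2).

    Hypothetical stand-in for the network: alpha(x) and beta(x) are assumed
    linear in x and fitted jointly by least squares of Y on [1, x, t, t * x].
    """
    n, d = X.shape
    rng = np.random.default_rng(seed)
    folds = rng.permutation(n) % n_folds           # balanced random fold labels
    Z = np.column_stack([np.ones(n), X])           # basis shared by alpha(x) and beta(x)
    design = np.column_stack([Z, T[:, None] * Z])  # columns for alpha, then for beta * t
    theta = np.empty((n, 2))
    for k in range(n_folds):
        train, held_out = folds != k, folds == k
        coef, *_ = np.linalg.lstsq(design[train], Y[train], rcond=None)
        # Predict only on the held-out fold, so theta_i never depends on observation i
        theta[held_out, 0] = Z[held_out] @ coef[: d + 1]
        theta[held_out, 1] = Z[held_out] @ coef[d + 1 :]
    return theta


# Usage on simulated linear data with heterogeneous slope beta(x) = 1 + 0.5 * x_0
rng = np.random.default_rng(1)
n, d = 3000, 3
X = rng.normal(size=(n, d))
T = rng.normal(size=n)
beta = 1.0 + 0.5 * X[:, 0]
Y = 0.3 * X[:, 1] + beta * T + rng.normal(size=n)
theta_hat = cross_fit_theta(Y, T, X, n_folds=5)
print(theta_hat.shape)  # (3000, 2)
```

With many folds (the examples above use `n_folds=50`), each training split keeps nearly all of the data. The cost is that one model must be trained per fold.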
### DMLResult

The result object returned by `structural_dml`:

| Attribute | Description |
|---|---|
| | Debiased point estimate of E[β(X)] |
| | Naive (biased) estimate |
| | Standard error |
| | Lower bound of the 95% CI |
| | Upper bound of the 95% CI |
| | Estimated parameters, shape (n, theta_dim) |
| | Influence scores, shape (n,) |
| | Dict of training diagnostics |

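The debiased estimate, standard error, and confidence interval in the table are all functions of the influence scores. The sketch below works this out for the linear family Y = α(X) + β(X)T + ε under a randomized design. It makes two assumptions:

- T ~ N(0, 1), independent of X.
- The usual correction ψ = H(θ̂) − H_θ Λ⁻¹ ℓ_θ applies, with target H(θ) = β.

With squared loss, Λ = E[(1, T)(1, T)ᵀ] = I, so the score reduces to ψ_i = β̂(X_i) + T_i (Y_i − α̂(X_i) − β̂(X_i) T_i). The function name and dict keys are hypothetical and need not match the DMLResult attributes.

```python
import numpy as np


def summarize_linear_regime_a(Y, T, theta_hat, z=1.96):
    """Debiased estimate of E[beta(X)] for the linear family (sketch only).

    Assumes T ~ N(0, 1) independent of X, so Lambda = E[(1, T)(1, T)^T] = I and the
    influence score is psi_i = beta_i + T_i * (Y_i - alpha_i - beta_i * T_i).
    """
    alpha, beta = theta_hat[:, 0], theta_hat[:, 1]
    psi = beta + T * (Y - alpha - beta * T)        # influence scores, shape (n,)
    estimate = psi.mean()                          # debiased point estimate
    se = psi.std(ddof=1) / np.sqrt(len(psi))       # standard error of the mean of psi
    return {
        "estimate": estimate,
        "naive": beta.mean(),                      # plug-in average, no correction
        "se": se,
        "ci_lower": estimate - z * se,             # 95% CI when z = 1.96
        "ci_upper": estimate + z * se,
        "psi": psi,
    }


# Usage: a first stage whose beta is off by +0.3 everywhere
rng = np.random.default_rng(2)
n = 20000
x = rng.normal(size=n)
T = rng.normal(size=n)
alpha, beta = 0.5 * x, 1.0 + 0.5 * x               # true E[beta] = 1.0
Y = alpha + beta * T + rng.normal(size=n)
out = summarize_linear_regime_a(Y, T, np.column_stack([alpha, beta + 0.3]))
# out["naive"] inherits the +0.3 error; out["estimate"] removes it up to sampling noise
```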
### families
Statistical families defining loss functions, gradients, Hessians, and influence scores.
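
To make "loss, gradient, Hessian" concrete, here is what a family supplies for the binary logit model P(Y = 1 | T, X) = σ(α(X) + β(X)T) with θ = (α, β). This is a hypothetical sketch: `LogitFamilySketch` and its method names are illustrative, and the library's family classes (such as `LogitFamily`) may use different signatures.

```python
import numpy as np


class LogitFamilySketch:
    """Hypothetical per-observation loss, gradient, and Hessian for a logit family."""

    @staticmethod
    def _features(t):
        # Regressor vector (1, t) for each observation, shape (n, 2)
        return np.column_stack([np.ones_like(t), t])

    @staticmethod
    def _prob(t, theta):
        return 1.0 / (1.0 + np.exp(-(theta[:, 0] + theta[:, 1] * t)))

    def loss(self, y, t, theta):
        # Negative log-likelihood: log(1 + e^eta) - y * eta, shape (n,)
        eta = theta[:, 0] + theta[:, 1] * t
        return np.logaddexp(0.0, eta) - y * eta

    def gradient(self, y, t, theta):
        # d loss / d theta = (p - y) * (1, t), shape (n, 2)
        p = self._prob(t, theta)
        return (p - y)[:, None] * self._features(t)

    def hessian(self, y, t, theta):
        # d^2 loss / d theta^2 = p (1 - p) (1, t)(1, t)^T, shape (n, 2, 2); free of y
        p = self._prob(t, theta)
        f = self._features(t)
        return (p * (1 - p))[:, None, None] * f[:, :, None] * f[:, None, :]


fam = LogitFamilySketch()
y = np.array([0.0, 1.0, 1.0])
t = np.array([-1.0, 0.5, 2.0])
theta = np.array([[0.2, 1.0], [0.0, -0.5], [1.0, 0.3]])
print(fam.hessian(y, t, theta).shape)  # (3, 2, 2)
```

The logit Hessian does not depend on y. That is what allows Λ(x), the conditional expected Hessian, to be computed without outcome data once the treatment distribution is known.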
### targets

Target functionals for inference: `AverageParameter`, `AME`, `CustomTarget`.
### lambda_

Lambda estimation strategies: `ComputeLambda` (Regime A), `AnalyticLambda` (Regime B), `EstimateLambda` (Regime C).
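
In Regime A the treatment is randomized, so its distribution is known and independent of X. If Λ(x) is taken to be the conditional expected Hessian E[ℓ_θθ | X = x], it becomes a one-dimensional integral over T for each observation. The sketch below computes it for the logit family with T ~ N(mean, sd²) using Gauss–Hermite quadrature (`numpy.polynomial.hermite_e.hermegauss`). This is an assumption about what such a computation could look like, not the library's `ComputeLambda`; `compute_lambda_logit` is a hypothetical name.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def compute_lambda_logit(theta, mean=0.0, sd=1.0, n_nodes=40):
    """Lambda_i = E_T[sigmoid'(alpha_i + beta_i T) (1, T)(1, T)^T] for T ~ N(mean, sd^2).

    Returns an array of shape (n, 2, 2), one symmetric matrix per observation.
    """
    nodes, weights = hermegauss(n_nodes)       # weight function exp(-z^2 / 2)
    weights = weights / np.sqrt(2 * np.pi)     # now sums to 1: E[f(Z)], Z ~ N(0, 1)
    t = mean + sd * nodes                      # quadrature points for T, shape (m,)
    eta = theta[:, [0]] + theta[:, [1]] * t    # shape (n, m)
    s = 1.0 / (1.0 + np.exp(-eta))
    w = weights * s * (1 - s)                  # weighted sigmoid'(eta), shape (n, m)
    lam = np.empty((theta.shape[0], 2, 2))
    lam[:, 0, 0] = w.sum(axis=1)
    lam[:, 0, 1] = lam[:, 1, 0] = (w * t).sum(axis=1)
    lam[:, 1, 1] = (w * t**2).sum(axis=1)
    return lam


lam = compute_lambda_logit(np.array([[0.0, 1.0]]))
print(lam.shape)  # (1, 2, 2)
```

Quadrature is used here because the expectation is a smooth one-dimensional integral. A Monte Carlo average over draws of T would also work, but it converges more slowly.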
### models

Neural network architectures: `StructuralNet` for parameter estimation.
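
A structural network maps covariates to the full parameter vector, which matches the (n, theta_dim) shape of the estimated parameters in DMLResult. The sketch below is a minimal NumPy forward pass of that kind. It assumes fully connected ReLU layers sized by `hidden_dims` and He initialization. The helper names, activation, and initialization are hypothetical, and `StructuralNet` itself may differ.

```python
import numpy as np


def init_structural_net(d, theta_dim, hidden_dims=(64, 32), seed=0):
    """Hypothetical weights for an MLP mapping covariates (d,) to parameters (theta_dim,)."""
    rng = np.random.default_rng(seed)
    sizes = [d, *hidden_dims, theta_dim]
    return [
        (rng.normal(scale=np.sqrt(2.0 / m), size=(m, k)), np.zeros(k))  # He-style init
        for m, k in zip(sizes[:-1], sizes[1:])
    ]


def structural_net_forward(params, X):
    """theta(X): ReLU hidden layers, then a linear output layer. Shape (n, theta_dim)."""
    h = X
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)
    W, b = params[-1]
    return h @ W + b


X = np.random.default_rng(1).normal(size=(10, 5))
params = init_structural_net(d=5, theta_dim=2)
theta = structural_net_forward(params, X)   # column 0 ~ alpha(X), column 1 ~ beta(X)
print(theta.shape)  # (10, 2)
```

The output layer is linear, so each parameter can take any real value. A family that needs a constrained parameter (for example, a positive scale) would have to transform the corresponding column.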
### metrics
Helper functions for computing coverage and SE ratios.
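
In simulation studies, coverage is the share of replications whose confidence interval contains the true value. The SE ratio compares the reported standard errors with the actual spread of the estimates. The sketch below implements both under the common convention SE ratio = mean reported SE / empirical SD of the estimates. That convention is an assumption, and the function names are hypothetical, not the library's helpers.

```python
import numpy as np


def coverage(ci_lower, ci_upper, true_value):
    """Fraction of replications whose CI contains the true value."""
    ci_lower, ci_upper = np.asarray(ci_lower), np.asarray(ci_upper)
    return float(np.mean((ci_lower <= true_value) & (true_value <= ci_upper)))


def se_ratio(reported_se, estimates):
    """Mean reported SE over the empirical SD of the estimates across replications.

    Near 1: SEs are calibrated. Below 1: SEs are too small (CIs too narrow).
    """
    return float(np.mean(reported_se) / np.std(estimates, ddof=1))


lo = [0.9, 1.1, 0.8, 0.95]
hi = [1.2, 1.3, 1.05, 1.1]
print(coverage(lo, hi, 1.0))  # 0.75
```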