A Model ties together regimes, an age grid, and a regime ID class into a solvable
lifecycle model.
The Model Constructor
```python
from lcm import Model

model = Model(
    regimes=regimes,  # dict mapping names to Regime instances
    ages=ages,  # AgeGrid defining the lifecycle timeline
    regime_id_class=RegimeId,  # @categorical dataclass mapping names to ScalarInt indices
    enable_jit=True,  # controls JAX compilation (default: True)
    fixed_params={},  # optional params baked in at init time
    description="",  # optional description string
)
```

All arguments are keyword-only. The three required arguments are `regimes`, `ages`, and `regime_id_class`. The finalized regimes are stored as `model.user_regimes` (plain `Regime` instances in user vocabulary). The processed canonical form is the engine-internal `model._regimes`.
Model-Level Regime Slots
When several regimes share functions, states, or actions, declare the shared structure once at the model level instead of repeating it in each regime. A lifecycle model with a couple of dozen shared functions and a handful of shared states then has a single declaration site for all of them:
```python
model = Model(
    regimes={"working": working, "retired": retired, "dead": dead},
    ages=ages,
    regime_id_class=RegimeId,
    # Everything below is broadcast into every regime (subject to the rules below).
    functions={"taxes": taxes, "net_income": net_income},
    constraints={"budget": budget_constraint},
    states={"wealth": LinSpacedGrid(start=1, stop=100, n_points=50)},
    state_transitions={"wealth": next_wealth},
    actions={"consumption": LinSpacedGrid(start=1, stop=50, n_points=30)},
)
```

Each model-level slot accepts exactly what the regime-level slot accepts, including `Phased`, stochastic processes, per-target dicts, and `fixed_transition`. The entries are merged into every regime under three rules:
- Exactly one level. A name is defined at the model level or the regime level, never both. A duplicate raises an ambiguity error at model build.
- `None` masks. A regime-level `None` removes the model-level entry for that regime (masking a state also drops its broadcast law of motion). A `None` with no model-level entry behind it is an error.
- DAG pruning. Broadcast states and actions are pruned per regime by reachability. A broadcast variable survives in a regime only if a root computation transitively reads it, in either phase. The root computations are utility, `H`, constraints, derived categoricals, the regime transition, and any law of motion toward a reachable target that keeps the state. Regime-level declarations are never pruned. `model.pruned_variables` records, per regime, which broadcast names were pruned.
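The first two rules can be pictured as a dict merge for a single slot. This is a minimal pure-Python sketch that assumes each slot behaves like a plain `name -> value` dict. `merge_slot` is a hypothetical helper, not pylcm API. It also leaves out DAG pruning and the coupling between a masked state and its law of motion.

```python
def merge_slot(model_level, regime_level, regime_name):
    """Hypothetical merge of one model-level slot into one regime's slot.

    Rule 1: a name is defined at exactly one level.
    Rule 2: a regime-level None masks the model-level entry for this regime.
    """
    merged = dict(model_level)  # copy so other regimes see the original entries
    for name, value in regime_level.items():
        if value is None:
            if name not in model_level:
                raise ValueError(f"{regime_name!r}: None for {name!r} masks nothing")
            del merged[name]  # masked in this regime only
        elif name in model_level:
            raise ValueError(
                f"{regime_name!r}: {name!r} is defined at model and regime level"
            )
        else:
            merged[name] = value
    return merged


model_states = {"wealth": "wealth_grid", "health": "health_grid"}
retired_states = merge_slot(model_states, {"health": None}, "retired")
```

Returning a fresh dict per regime keeps the masking local. A `None` in one regime never affects what the other regimes inherit.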
Pruning means a model-level state costs nothing in regimes that never touch it, because its grid axis simply does not appear there. Two restrictions keep the device layout coherent:

- `distributed=True` (sharding) is legal only on model-level states.
- A sharded state pruned from a non-terminal regime is an error. Either unshard it or make the regime use it.
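Reachability pruning amounts to a graph walk from the root computations. The following is a simplified sketch. It assumes dependencies are read from function parameter names, in the spirit of the function DAGs described in Writing Economics. It ignores phases and the special rule for laws of motion. `reachable_names` and `prune_broadcast` are hypothetical helpers, not pylcm internals.

```python
import inspect


def reachable_names(functions, roots):
    """All names transitively read by the root functions.

    A function's inputs are its parameter names. Any name that is not a key
    of `functions` is treated as a leaf (a state, an action, or a parameter).
    """
    seen = set()
    stack = list(roots)
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        func = functions.get(name)
        if func is not None:
            stack.extend(inspect.signature(func).parameters)
    return seen


def prune_broadcast(broadcast, functions, roots):
    """Split broadcast variable names into (kept, pruned) for one regime."""
    reached = reachable_names(functions, roots)
    kept = {name for name in broadcast if name in reached}
    return kept, set(broadcast) - kept


def utility(consumption, leisure):
    return consumption + leisure


def leisure(labor_supply):
    return 1 - labor_supply


kept, pruned = prune_broadcast(
    {"consumption", "labor_supply", "health"},
    functions={"utility": utility, "leisure": leisure},
    roots=["utility"],
)
```

In this toy regime, `health` is never read on any path from `utility`, so it lands in `pruned`. That mirrors the idea that its grid axis would not appear in the regime.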
Regime ID Classes
The regime_id_class maps regime names to integer indices. Use the @categorical
decorator to create it:
```python
from lcm import categorical
from lcm.typing import ScalarInt


@categorical(ordered=False)
class RegimeId:
    retired: ScalarInt
    working: ScalarInt
```

Rules:

- Fields must be annotated as `ScalarInt`, the 0-d `jnp.int32` scalar pylcm produces for category codes. Other annotations raise `CategoricalDefinitionError` at decoration time.
- Fields must match the keys of the `regimes` dict exactly (sorted alphabetically).
- Values are auto-assigned as consecutive `jnp.int32` scalars starting from 0.
- Use `RegimeId.working` (class attribute access) to reference regime IDs in transition functions.
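The field-matching and code-assignment rules can be mirrored in plain Python to see which integers the regimes receive. This is a hypothetical sketch with two simplifications. Plain `int`s stand in for the `jnp.int32` scalars pylcm produces. Codes are assumed to follow alphabetical order, as the rules and the example above suggest. `assign_regime_ids` is not pylcm API.

```python
def assign_regime_ids(id_fields, regimes):
    """Hypothetical: map regime names to consecutive integer codes from 0."""
    if set(id_fields) != set(regimes):
        missing = sorted(set(regimes) - set(id_fields))
        unknown = sorted(set(id_fields) - set(regimes))
        raise ValueError(f"missing fields: {missing}, unknown fields: {unknown}")
    # Assumption: codes are assigned in alphabetical order of the field names.
    return {name: code for code, name in enumerate(sorted(id_fields))}


ids = assign_regime_ids(["retired", "working"], {"working": None, "retired": None})
```

The result has the shape of the name-to-integer mapping that `model.regime_names_to_ids` is described as exposing.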
Age Grids
The ages argument defines the lifecycle timeline. There are two construction modes:
Range-based
```python
from lcm import AgeGrid

ages = AgeGrid(start=25, stop=75, step="Y")  # annual steps, ages 25 to 75
```

Step formats:

- `"Y"`: 1 year
- `"2Y"`: 2 years
- `"Q"`: quarter (0.25 years)
- `"M"`: month (1/12 year)
- `"3M"`: 3 months

The `stop` value is inclusive if `stop - start` is exactly divisible by the step size.
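Here is a small sketch of how a step string could translate into step sizes and ages. It uses `fractions.Fraction` so that monthly steps do not accumulate float error. Two assumptions apply. First, when `stop - start` is not divisible by the step, the sketch simply ends at the last age below `stop`. Second, the sketch accepts a numeric prefix on any unit, which may be more permissive than pylcm. `parse_step` and `range_ages` are hypothetical helpers, not pylcm API.

```python
import re
from fractions import Fraction

UNIT_YEARS = {"Y": Fraction(1), "Q": Fraction(1, 4), "M": Fraction(1, 12)}


def parse_step(step):
    """Hypothetical: 'Y', '2Y', 'Q', 'M', '3M' -> step size in years."""
    match = re.fullmatch(r"(\d*)([YQM])", step)
    if match is None:
        raise ValueError(f"unsupported step format: {step!r}")
    count = int(match.group(1) or 1)
    if count == 0:
        raise ValueError("step must be positive")
    return count * UNIT_YEARS[match.group(2)]


def range_ages(start, stop, step):
    """Ages start, start + step, ...; includes stop when it lies on the grid."""
    size = parse_step(step)
    n_steps = (Fraction(stop) - Fraction(start)) // size  # floor division -> int
    return [Fraction(start) + i * size for i in range(n_steps + 1)]


annual = range_ages(25, 75, "Y")
```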
Exact values
```python
ages = AgeGrid(exact_values=[25, 35, 45, 55, 65, 75])
```

Use this for irregular age spacing.
Key properties
- `ages.values`: JAX array of ages, indexed by period
- `ages.n_periods`: number of periods
- `ages.step_size`: step size in years (or `None` for exact values)
- `ages.period_to_age(period)`: convert a period index to an age
- `ages.get_periods_where(predicate)`: get the periods matching a condition
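To make the relationship between periods and ages concrete, here is a stand-in class with the same property names. It is not pylcm code. It stores plain Python values instead of a JAX array. It assumes that `get_periods_where` passes each age to the predicate and returns matching period indices as a list.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class AgeGridSketch:
    """Hypothetical stand-in mirroring AgeGrid's documented properties."""

    values: Tuple[float, ...]
    step_size: Optional[float] = None  # None, as for exact-value grids

    @property
    def n_periods(self):
        return len(self.values)

    def period_to_age(self, period):
        return self.values[period]

    def get_periods_where(self, predicate):
        # Assumption: the predicate receives an age; period indices are returned.
        return [period for period, age in enumerate(self.values) if predicate(age)]


grid = AgeGridSketch(values=(25, 35, 45, 55, 65, 75))
retirement_periods = grid.get_periods_where(lambda age: age >= 65)
```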
Model Validation Rules
The `Model` constructor validates the following:

- At least one terminal regime and one non-terminal regime must be provided.
- Regime names cannot contain `__` (reserved separator).
- `regime_id_class` fields must exactly match the `regimes` dict keys.
- All states and actions must be used by at least one function (utility, constraints, or transitions).
- The age grid must have at least 2 periods.
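Several of these checks are simple set and count comparisons. The sketch below mirrors them over plain data: a dict of regime name to an "is terminal" flag, the ID class field names, and a period count. It skips the "states and actions must be used" check, which needs the function DAG. `validate_model_sketch` is hypothetical and not how pylcm reports errors.

```python
def validate_model_sketch(regimes, id_fields, n_periods):
    """Hypothetical checks mirroring the validation rules above.

    `regimes` maps a regime name to True if that regime is terminal.
    Returns a list of error messages; an empty list means the inputs pass.
    """
    errors = []
    n_terminal = sum(1 for is_terminal in regimes.values() if is_terminal)
    if n_terminal == 0 or n_terminal == len(regimes):
        errors.append("need at least one terminal and one non-terminal regime")
    errors += [f"regime name {name!r} contains '__'" for name in regimes if "__" in name]
    if set(id_fields) != set(regimes):
        errors.append("regime_id_class fields do not match the regimes keys")
    if n_periods < 2:
        errors.append("the age grid must have at least 2 periods")
    return errors


ok = validate_model_sketch({"working": False, "retired": True}, ["retired", "working"], 51)
```

Collecting every violation, instead of stopping at the first one, is a design choice of this sketch. It makes all problems visible in a single pass.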
Inspecting a Model
After construction, the model exposes several useful attributes:
```python
model.user_regimes  # immutable mapping of finalized `Regime` objects
model.pruned_variables  # per regime, the broadcast names pruned by DAG reachability
model.n_periods  # number of periods
model.regime_names_to_ids  # name -> integer mapping
model.get_params_template()  # mutable copy of the parameter template
```

Use `model.get_params_template()` to get a mutable copy of the parameter template; see Parameters.
Complete Example
```python
import jax.numpy as jnp

from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical
from lcm.typing import ScalarInt


@categorical(ordered=False)
class RegimeId:
    retired: ScalarInt
    working: ScalarInt


@categorical(ordered=True)
class LaborSupply:
    do_not_work: ScalarInt
    work: ScalarInt


def next_wealth(wealth, consumption, interest_rate):
    # Wealth that is not consumed earns interest until next period.
    return (wealth - consumption) * (1 + interest_rate)


def next_regime(labor_supply):
    # Choosing to work keeps the agent in the working regime; otherwise retire.
    return jnp.where(
        labor_supply == LaborSupply.work, RegimeId.working, RegimeId.retired
    )


def utility(consumption, labor_supply, disutility_of_work):
    return jnp.log(consumption) - disutility_of_work * labor_supply


def terminal_utility(wealth):
    return jnp.log(wealth)


def budget_constraint(consumption, wealth):
    # Rule out consuming more than current wealth.
    return consumption <= wealth


working = Regime(
    transition=next_regime,
    states={"wealth": LinSpacedGrid(start=1, stop=100, n_points=50)},
    state_transitions={"wealth": next_wealth},
    actions={
        "consumption": LinSpacedGrid(start=1, stop=50, n_points=30),
        "labor_supply": DiscreteGrid(LaborSupply),
    },
    constraints={"budget": budget_constraint},
    functions={"utility": utility},
)

retired = Regime(
    transition=None,  # terminal regime
    states={"wealth": LinSpacedGrid(start=1, stop=100, n_points=50)},
    functions={"utility": terminal_utility},
)

model = Model(
    regimes={"working": working, "retired": retired},
    ages=AgeGrid(start=25, stop=75, step="Y"),
    regime_id_class=RegimeId,
)
```

See Also
Writing Economics — function DAGs and regime design
Regimes — detailed guide to defining regimes
Parameters — constructing the params dict
Solving and Simulating — running the model