# pyright: basic
'''Optional snapshot/restore helpers for pysm state machines.'''
from collections import Counter, deque
from .pysm import StateMachine, StateMachineException
# Format version stamped into every snapshot by snapshot(); restore() rejects
# any snapshot whose 'version' entry differs from this value.
SNAPSHOT_VERSION = 1
[docs]
def snapshot(machine, metadata=None):
    '''Describe the whole graph of ``machine`` as plain data.

    The description is taken from ``machine.root_machine``. Every state is
    identified by its path (the list of names from the root down to it).

    The returned dict holds, in this key order:

    * ``'version'``: :data:`SNAPSHOT_VERSION`.
    * ``'root'``: the root path.
    * ``'states'``: all state paths, sorted.
    * ``'machines'``: one entry per machine, sorted by ``'path'``, each with its
      current ``'state'`` path (or ``None``) and its ``'state_stack'`` paths.
    * ``'leaf_state'``: the root's leaf-state path (or ``None``).
    * ``'leaf_state_stack'``: the root's leaf-state stack paths.
    * ``'metadata'``: present only when ``metadata`` is not ``None``; stored
      as given, not copied.

    :raises StateMachineException: if any machine has two children with the
        same name, because such paths would be ambiguous.
    '''
    root = machine.root_machine
    _validate_unique_sibling_names(root)

    def as_path(state):
        # Optional references (no current state / no leaf) are stored as None.
        return None if state is None else _state_path(state)

    def stack_paths(stack):
        return [_state_path(entry) for entry in stack.deque]

    machine_entries = [
        {
            'path': _state_path(item),
            'state': as_path(item.state),
            'state_stack': stack_paths(item.state_stack),
        }
        for item in _iter_machines(root)
    ]
    machine_entries.sort(key=lambda entry: entry['path'])

    result = {
        'version': SNAPSHOT_VERSION,
        'root': _state_path(root),
        'states': sorted(_state_path(node) for node in _iter_states(root)),
        'machines': machine_entries,
        'leaf_state': as_path(root.leaf_state),
        'leaf_state_stack': stack_paths(root.leaf_state_stack),
    }
    if metadata is not None:
        result['metadata'] = metadata
    return result
[docs]
def restore(machine, data):
    '''Restore ``machine`` from a snapshot created by :func:`snapshot`.

    The machine graph must already be constructed and initialized. Restore is
    strict about topology so stale snapshots fail loudly instead of silently
    restoring to the wrong state.

    The restore runs in two phases. First, every path in the snapshot is
    validated and resolved to a live state. Only then are the machines
    modified. An unresolvable path is therefore reported before any
    ``state``, stack or leaf attribute has been changed.

    :param machine: Any state machine of the target graph. Its
        ``root_machine`` is the one whose graph is restored.
    :param data: A dict as produced by :func:`snapshot`.
    :returns: ``machine`` (the argument as passed, not the root).
    :raises StateMachineException: if the version or root differs, the
        topology does not match, or any recorded path is invalid.
    '''
    root = machine.root_machine
    if data.get('version') != SNAPSHOT_VERSION:
        raise StateMachineException('Unsupported snapshot version: {0}'.format(
            data.get('version')))
    if data.get('root') != _state_path(root):
        raise StateMachineException('Snapshot root does not match machine root')
    _validate_topology(root, data)
    _validate_active_state_paths(root, data)

    def resolve_optional(path):
        return _resolve_state_path(root, path) if path is not None else None

    def resolve_all(paths):
        return [_resolve_state_path(root, path) for path in paths]

    # Resolve phase: nothing has been mutated yet. The validators above do not
    # look at 'leaf_state_stack', so resolving it here, before any assignment,
    # keeps a bad path from leaving the graph half-restored.
    machine_plans = [
        (
            _resolve_machine_path(root, machine_data['path']),
            resolve_optional(machine_data.get('state')),
            resolve_all(machine_data.get('state_stack', [])),
        )
        for machine_data in data.get('machines', [])
    ]
    leaf_state = resolve_optional(data.get('leaf_state'))
    leaf_state_stack = resolve_all(data.get('leaf_state_stack', []))

    # Apply phase: only assignments and stack replacement from here on.
    for item, state, state_stack in machine_plans:
        item.state = state
        _replace_stack(item.state_stack, state_stack)
    root._leaf_state = leaf_state
    _replace_stack(root.leaf_state_stack, leaf_state_stack)
    return machine
def _iter_states(root):
    '''Yield ``root`` and every state below it, breadth-first.

    Children of each machine are queued in ascending ``name`` order. The
    traversal order is therefore deterministic as long as sibling names are
    unique.
    '''
    pending = deque()
    pending.append(root)
    while pending:
        current = pending.popleft()
        yield current
        if not isinstance(current, StateMachine):
            continue
        pending.extend(sorted(current.states, key=lambda child: child.name))
def _iter_machines(root):
    '''Return every StateMachine in ``root``'s graph as a list, in the
    breadth-first order produced by :func:`_iter_states`.'''
    def is_machine(node):
        return isinstance(node, StateMachine)

    return list(filter(is_machine, _iter_states(root)))
def _state_path(state):
    '''Return the names from the topmost ancestor down to ``state``.

    Follows ``parent`` links until ``None``. Passing ``None`` yields ``[]``.
    '''
    chain = []
    node = state
    while node is not None:
        chain.append(node.name)
        node = node.parent
    chain.reverse()
    return chain
def _resolve_machine_path(root, path):
    '''Resolve ``path`` as :func:`_resolve_state_path` does, and additionally
    require that the resolved state is a StateMachine.

    :raises StateMachineException: if the path does not resolve, or resolves
        to a state that is not a StateMachine.
    '''
    resolved = _resolve_state_path(root, path)
    if isinstance(resolved, StateMachine):
        return resolved
    raise StateMachineException(
        'Path does not point to a state machine: {0}'.format(path))
def _resolve_state_path(root, path):
    '''Look up the state named by ``path`` in the graph under ``root``.

    ``path`` is a sequence of names. ``path[0]`` must equal ``root.name``.
    Each following name selects the child, among the current machine's
    ``states``, whose ``name`` matches.

    :raises StateMachineException: if ``path`` is empty, does not start at
        ``root``, tries to descend through a non-machine state, names a child
        that does not exist, or names a child shared by several siblings.
    '''
    if not path:
        raise StateMachineException('State path cannot be empty')
    if path[0] != root.name:
        raise StateMachineException(
            'State path {0} does not start at root {1}'.format(
                path, root.name))

    node = root
    for segment in path[1:]:
        if not isinstance(node, StateMachine):
            raise StateMachineException(
                'State path descends through non-machine state: {0}'.format(
                    path))
        hits = [child for child in node.states if child.name == segment]
        if len(hits) == 1:
            node = hits[0]
        elif hits:
            raise StateMachineException('Ambiguous state path: {0}'.format(
                path))
        else:
            raise StateMachineException('Unknown state path: {0}'.format(path))
    return node
def _replace_stack(stack, states):
    '''Give ``stack`` fresh storage containing ``states``.

    The new deque copies the old storage's ``maxlen``. If the old storage has
    no ``maxlen`` attribute, ``StateMachine.STACK_SIZE`` is used instead.
    The states are then added one by one through ``stack.push``, in list
    order.

    NOTE(review): if ``states`` is longer than a bounded ``maxlen``, the
    oldest entries are presumably dropped. Confirm that ``push`` appends to
    ``stack.deque``.
    '''
    previous = stack.deque
    fallback = StateMachine.STACK_SIZE
    stack.deque = deque([], maxlen=getattr(previous, 'maxlen', fallback))
    for entry in states:
        stack.push(entry)
def _validate_topology(root, data):
    '''Check that the snapshot describes exactly the graph under ``root``.

    The snapshot's state paths and machine paths are compared, ignoring
    order, with those of the live graph. Sibling-name uniqueness is then
    re-checked.

    :raises StateMachineException: on any mismatch or ambiguous sibling names.
    '''
    snapshot_states = sorted(data.get('states', []))
    live_states = sorted(map(_state_path, _iter_states(root)))
    if snapshot_states != live_states:
        raise StateMachineException('Snapshot topology does not match machine')

    snapshot_machines = sorted(
        entry['path'] for entry in data.get('machines', []))
    live_machines = sorted(map(_state_path, _iter_machines(root)))
    if snapshot_machines != live_machines:
        raise StateMachineException('Snapshot machines do not match machine')

    _validate_unique_sibling_names(root)
def _validate_active_state_paths(root, data):
    '''Check each machine's recorded state and stack, and the leaf state.

    Every ``'state'`` and ``'state_stack'`` entry must resolve to a direct
    child of the machine it is recorded for. The snapshot's ``'leaf_state'``
    must equal the path reached by following current states down from the
    root (see :func:`_expected_leaf_path_from_machine_states`).

    :raises StateMachineException: if a path does not resolve, a recorded
        state is not a child of its machine, or the leaf state disagrees.
    '''
    by_path = {}
    for entry in data.get('machines', []):
        key = tuple(entry['path'])
        owner = _resolve_machine_path(root, entry['path'])
        by_path[key] = entry

        current_path = entry.get('state')
        if current_path is not None:
            if _resolve_state_path(root, current_path).parent is not owner:
                raise StateMachineException(
                    'Snapshot state is not a child of machine: {0}'.format(
                        current_path))

        for stacked_path in entry.get('state_stack', []):
            if _resolve_state_path(root, stacked_path).parent is not owner:
                raise StateMachineException(
                    'Snapshot stack state is not a child of machine: {0}'
                    .format(stacked_path))

    expected = _expected_leaf_path_from_machine_states(
        tuple(data['root']), by_path)
    recorded = data.get('leaf_state')
    if recorded is None:
        consistent = expected is None
    else:
        consistent = tuple(recorded) == expected
    if not consistent:
        raise StateMachineException(
            'Snapshot leaf state does not match active machine states')
def _expected_leaf_path_from_machine_states(root_path, machines):
    '''Follow current states from ``root_path`` down to the deepest one.

    :param root_path: Tuple path of the root machine. It must be a key of
        ``machines``, otherwise ``KeyError`` is raised.
    :param machines: Mapping from tuple machine path to that machine's
        snapshot entry (a dict with an optional ``'state'`` path).
    :returns: Tuple path of the deepest current state, or ``None`` if the root
        entry has no current state. Descent stops at a state that is not a key
        of ``machines``, or whose entry has no current state.
    '''
    leaf = None
    entry = machines[root_path]
    while True:
        step = entry.get('state')
        if step is None:
            return leaf
        leaf = tuple(step)
        if leaf not in machines:
            return leaf
        entry = machines[leaf]
def _validate_unique_sibling_names(root):
    '''Ensure no machine in ``root``'s graph has two children with one name.

    Snapshot paths identify states by name, so duplicate sibling names would
    make those paths ambiguous.

    :raises StateMachineException: for the first offending machine in
        :func:`_iter_states` order. The message includes that machine's path
        and its duplicated child names, sorted.
    '''
    for state in _iter_states(root):
        if not isinstance(state, StateMachine):
            continue
        counts = Counter(child.name for child in state.states)
        # Sort so the message does not depend on the iteration order of
        # ``state.states``, matching the name-sorted traversal elsewhere.
        duplicates = sorted(
            name for name, count in counts.items() if count > 1)
        if duplicates:
            raise StateMachineException(
                'Ambiguous sibling state names under {0}: {1}'.format(
                    _state_path(state), duplicates))