# Layer API Design Space
## Overview
This document describes the high-level requirements and tradeoffs for a successful layer API design. It also walks through a couple of conceptual approaches to the API design.
## Definitions
Neural Network - a collection of weight values and a differentiable execution function
Layer - a differentiable function from (weights & input) to output, and a default function from hyperparameters to weights
Initialization Function - an optionally differentiable function mapping hyperparameters of a layer to a new instance of the weights or an existing weight store that has been mutated
Model - a mapping of a set of layers to a set of weight instances (bipartite; many-to-many)
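As a rough illustration of how these definitions might map onto Swift types, here is a minimal sketch; the protocol, its requirements, and the toy `Tensor` stand-in are hypothetical, not an existing API:

```swift
// Toy stand-in for a real tensor type (illustration only).
struct Tensor { var values: [Float] }

// A layer in the sense above: a function from (weights, input) to output,
// plus a default initialization function from hyperparameters to weights.
protocol Layer {
    associatedtype Weights
    associatedtype Hyperparameters

    // Execution: apply the given weights to the input.
    func callAsFunction(_ weights: Weights, _ input: Tensor) -> Tensor

    // Initialization function: build fresh weights from hyperparameters.
    static func initialWeights(from hyperparameters: Hyperparameters) -> Weights
}
```

A model would then pair a set of such layers with the weight instances they operate on.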
## API Requirements
### Layer Composition
- Any layer, combination of layers, or trained model should be usable as a layer in another model
- Scales to complex architectures; no need to rewrite the model to use a different API for advanced graph types
- No boilerplate:
  - No duplicate shapes
  - No duplicate input/output/scalar types (use type inference; see the sketch after this list)
  - No redundant definition of weights and execution in default use cases
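To make the type-inference point concrete, here is a small, hypothetical sketch using plain functions rather than real layers; the `compose` helper is an illustrative name, not part of any existing API:

```swift
// Compose two functions; the intermediate type `B` is inferred by the
// compiler instead of being written out twice by the user.
func compose<A, B, C>(_ f: @escaping (A) -> B, _ g: @escaping (B) -> C) -> (A) -> C {
    { input in g(f(input)) }
}

let double: (Int) -> Int = { $0 * 2 }
let describe: (Int) -> String = { "value: \($0)" }

let pipeline = compose(double, describe)  // (Int) -> String, fully inferred
print(pipeline(21))                       // prints "value: 42"
```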
### Complex Architectures
- Skip-connections (use results of a layer multiple times)
- Shared layers (reuse the same weights multiple times, usually but not always with the same execution function, at different points of the graph)
- Support dynamic architectures, with layers and connections generated from runtime configuration (see the sketch after this list)
- Also support reconfiguration of models to use different hyperparameters at runtime
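A minimal sketch of the runtime-configuration point from the list above; `DenseSpec` and `makeSpecs` are hypothetical names used purely for illustration:

```swift
// Describe a dense layer purely by its hyperparameters.
struct DenseSpec {
    var inputSize: Int
    var outputSize: Int
}

// Generate a stack of layer specs from runtime configuration; only the
// hidden sizes are supplied, and the input sizes are chained automatically.
func makeSpecs(inputSize: Int, hiddenSizes: [Int]) -> [DenseSpec] {
    var specs: [DenseSpec] = []
    var previousSize = inputSize
    for size in hiddenSizes {
        specs.append(DenseSpec(inputSize: previousSize, outputSize: size))
        previousSize = size
    }
    return specs
}

let mlp = makeSpecs(inputSize: 784, hiddenSizes: [512, 256, 10])
```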
### State Management
- Weight access should be type-safe (no casting into the specific type)
- All weights should have associated names (variables, subscripts), and those names should be usable to access the current value (see the sketch after this list)
- Weights should be groupable, e.g. for advanced optimizers that use multiple learning rates or for partially "freezing" a model
- Weights should be loadable from checkpoint files and support mapping weights from equivalent models
- Weight manipulation should be handled in a value-semantic way to prevent unexpected changes
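One way to satisfy the naming and type-safety requirements above is Swift key paths; the following is only a sketch under that assumption, and `MLPWeights` and its fields are hypothetical names:

```swift
// A weight store with one field per named weight.
struct MLPWeights {
    var dense1Weight: [Float]
    var dense1Bias: [Float]
}

// A key path is a compiler-checked name for a weight; reading through it
// yields the concrete type with no casting.
let biasName: WritableKeyPath<MLPWeights, [Float]> = \MLPWeights.dense1Bias

var weights = MLPWeights(dense1Weight: [0.1, 0.2], dense1Bias: [0.0])
weights[keyPath: biasName] = [0.5]   // value-semantic update: mutates this copy only

// Grouping for advanced optimizers or partial freezing: a plain list of names.
let frozen: [PartialKeyPath<MLPWeights>] = [\MLPWeights.dense1Weight]
```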
### Execution Debugging
- Access to the values of intermediate tensors within the graph (inputs/outputs to layers)
  - Not stored by default, should be opt-in
- Insert debugging “layers” (e.g. ones that print out their input, or compute arbitrary other data-dependent information); see the sketch after this list
- Display the final model architecture in a graphical format
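As a sketch of the debugging-layer idea referenced in the list above (the `PrintLayer` name and the plain-array "tensor" are hypothetical):

```swift
// A pass-through debugging "layer": prints a summary of its input and
// returns it unchanged, so it can be dropped anywhere into a model.
struct PrintLayer {
    var label: String

    func callAsFunction(_ input: [Float]) -> [Float] {
        print("\(label): count=\(input.count), max=\(input.max() ?? 0)")
        return input
    }
}

let debug = PrintLayer(label: "after dense1")
let activations = debug([0.3, 0.7, 0.1])   // prints "after dense1: count=3, max=0.7"
```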
### Type-Safety
- No stringly-typed APIs
  - Access to weights/intermediate tensors must be type-safe
- Rank-safe computation - track the number of dimensions of data
- Track the meaning of each channel (differentiate “CHW” vs “HWC” images; see the sketch after this list)
- Take advantage of any other type-safety opportunities that are reasonably accessible
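A small sketch of how channel meaning could be tracked in the type system; the phantom-type encoding and all names here are hypothetical:

```swift
// Phantom types that record the channel layout of an image.
enum CHW {}
enum HWC {}

struct Image<Layout> {
    var pixels: [Float]
    var shape: (Int, Int, Int)
}

// This function only accepts HWC images; passing a CHW image is a compile error.
func normalize(_ image: Image<HWC>) -> Image<HWC> {
    Image(pixels: image.pixels.map { $0 / 255.0 }, shape: image.shape)
}

let photo = Image<HWC>(pixels: [12, 240, 8], shape: (1, 1, 3))
let normalized = normalize(photo)
// normalize(Image<CHW>(pixels: [], shape: (3, 0, 0)))  // would not compile
```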
## Design Approaches
### Weights vs. Execution vs. Layers
One of the key insights resulting from our discussions was the separation between weights and model execution. In the current S4TF layer API, these are combined by defining layers as stateful functions that both capture the current weight values and define how the layer should be applied to input data. While this works well for simple systems with a bijection between weights and execution functions, it is harder to adapt to systems that require the same layer with the same weights to be applied at multiple locations in the model (layer sharing). Implementing such architectures with packaged weights and execution results in referential semantics, since multiple nodes in the graph would need to refer to the same underlying layer.
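For concreteness, here is a much-simplified sketch of that combined style, using a plain array in place of a real tensor type; `DenseLayer` is illustrative and not the actual S4TF definition:

```swift
// The layer *stores* its weights and also defines how to apply them.
struct DenseLayer {
    var weight: [[Float]]   // one row per output unit
    var bias: [Float]

    func callAsFunction(_ input: [Float]) -> [Float] {
        zip(weight, bias).map { row, b in
            zip(row, input).map { w, x in w * x }.reduce(b, +)
        }
    }
}
```

Sharing this layer at two points in a model means two graph nodes must refer to the same mutable instance, which is where the referential semantics come in.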
If we take a more functional approach, however, where weights do not make up the state of a function but instead are just an additional parameter for execution, this becomes more straightforward to handle. Instead of having to refer to a shared mutable state of a specific layer, the execution functions instead take the entire set of weights of the model and use the weights that are relevant to the current node. As a result, we effectively bubble up individual weights to the model level and eliminate the referential semantics needed when execution functions are tied to mutable state.
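A sketch of that functional style under the same toy representation as above; `ModelWeights`, `dense`, and `forward` are hypothetical names:

```swift
// Weights live at the model level, not inside any layer.
struct ModelWeights {
    var shared: [[Float]]   // a single weight matrix reused at two points
    var sharedBias: [Float]
}

// Execution is a plain function over (weights, input); it holds no state.
func dense(_ weight: [[Float]], _ bias: [Float], _ input: [Float]) -> [Float] {
    zip(weight, bias).map { row, b in
        zip(row, input).map { w, x in w * x }.reduce(b, +)
    }
}

// The forward pass receives the whole weight set and picks out what it needs;
// reusing `weights.shared` twice requires no reference to a shared object.
func forward(_ weights: ModelWeights, _ input: [Float]) -> [Float] {
    let hidden = dense(weights.shared, weights.sharedBias, input)
    return dense(weights.shared, weights.sharedBias, hidden)
}
```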
When separating weights from execution, however, we must be careful not to introduce boilerplate that forces duplicate definitions of weights and execution when one can be inferred from the other. Although layers are not a core component of the final model, they can exist as helpers that associate weights with default execution functions in order to eliminate boilerplate.
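In that framing, a "layer" might look something like the following hypothetical helper, which merely pairs a default weight initializer with a default execution function:

```swift
// A stateless helper: it owns no weights, it just knows how to create
// default weights and how to apply any compatible weights to an input.
struct DenseHelper {
    var inputSize: Int
    var outputSize: Int

    func initialWeights() -> [[Float]] {
        Array(repeating: Array(repeating: Float(0), count: inputSize), count: outputSize)
    }

    func apply(_ weights: [[Float]], to input: [Float]) -> [Float] {
        weights.map { row in zip(row, input).map { w, x in w * x }.reduce(0, +) }
    }
}
```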
### Explicit Graphs vs. Layers as Tensors
In our prototypes for layer APIs, we settled on two primary strategies for combining layers into more complex architectures: building an explicit graph or inferring the graph from layer dependencies.
In the first strategy, the user directly builds the entire graph in a way that requires no additional computation to determine the dependents of any layer. This requires the user to be aware of both incoming and outgoing edges from every layer “node” when constructing the model, but makes it easy to get high performance since there is a direct mapping from the composed layers to the weights and functions to execute. For example, when implementing a skip connection, the user would need to specify both a “fan-out” for the layer whose results will be used along multiple paths and a “fan-in” to combine the results.
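A hypothetical sketch of what the explicit style might look like for a skip connection; the node names and structure are illustrative only:

```swift
// Each node spells out its role in the graph, including explicit
// fan-out and fan-in around the skip connection.
enum Node {
    case dense(outputSize: Int)
    case fanOut(branches: Int)   // this result is consumed along multiple paths
    case fanIn                   // combine the incoming branches (e.g. by adding)
}

// A residual block written out node by node.
let residualBlock: [Node] = [
    .fanOut(branches: 2),
    .dense(outputSize: 64),
    .dense(outputSize: 64),
    .fanIn,
]
```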
For a simpler user experience, we can infer the graph based on dependencies, which eliminates the need to specify the “fan-out” of skip connections since we can detect layers that have multiple dependents. When designing models with this style, every layer tracks its dependencies. In this way, users can manipulate layers just like lazy tensors, since they accumulate a trace of dependencies and can be used as dependencies of other layers to define connections.
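A sketch of this dependency-tracking style, with hypothetical names; each node simply records its inputs, and the fan-out of a skip connection falls out of the recorded dependencies:

```swift
// A lazily-traced "layer": it records which nodes it depends on rather
// than being wired into an explicit graph up front.
final class TracedLayer {
    let name: String
    let inputs: [TracedLayer]

    init(_ name: String, inputs: [TracedLayer] = []) {
        self.name = name
        self.inputs = inputs
    }
}

let input = TracedLayer("input")
let dense1 = TracedLayer("dense1", inputs: [input])
let dense2 = TracedLayer("dense2", inputs: [dense1])

// `dense1` appears as a dependency twice; the framework can infer the
// fan-out of the skip connection without the user declaring it.
let skipSum = TracedLayer("add", inputs: [dense1, dense2])
```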
