on residual streams
i spend all day making neural networks run fast. i don't really understand what they're doing.
that's not a contradiction. optimization doesn't require understanding. i need to know the computational graph, which operators run in what order, where the memory accesses are, where the bottlenecks live. i don't need to know what the model "thinks." i can make it 30% faster without understanding what any individual neuron does.
this is fine for my job. it's not fine for my curiosity.
i've been reading mechanistic interpretability work for the past few months. the framing that got me hooked is the "residual stream" view of transformers, from Elhage et al.'s "A Mathematical Framework for Transformer Circuits" (Anthropic, 2021).
the basic idea: in a transformer, there's a vector that flows through the network and gets added to at each layer. attention heads read from it and write to it. MLP layers read from it and write to it. the residual stream is the central communication channel. everything in the network communicates through this shared workspace.
what makes this interesting is the reframing. instead of "layer 1 processes the input, then layer 2 processes layer 1's output," it becomes: "there's a shared workspace, and each component contributes something." different attention heads might be doing completely different jobs. some track syntactic structure, some copy information forward from earlier positions, some do basic pattern matching. they all read and write to the same stream.
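the additive-workspace picture is easy to sketch in code. this is a toy, not a real transformer: the "components" below are random projections rather than attention heads or MLPs, there's no layernorm, and the dimensions are made up. the only thing it's faithful to is the structure: every component reads from the same vector and writes back into it by addition.

```python
# toy sketch of the residual-stream view: many components read from and
# write into one shared vector. shapes and components are illustrative,
# not from any real model.
import numpy as np

rng = np.random.default_rng(0)
d_model = 16

# the shared workspace: one vector (one token position, for simplicity)
stream = rng.normal(size=d_model)

def component(stream, W_read, W_write):
    # each component reads a low-dimensional view of the stream...
    read = W_read @ stream
    # ...computes something, and writes its output back additively
    return W_write @ np.tanh(read)

# several "heads" with different read/write subspaces, all contributing
# to the same stream. nothing overwrites anything; it's all addition.
for _ in range(4):
    W_read = rng.normal(size=(4, d_model)) / np.sqrt(d_model)
    W_write = rng.normal(size=(d_model, 4)) / 2.0
    stream = stream + component(stream, W_read, W_write)
```

the point of the additive structure: each component's contribution survives in the stream for every later component to read, which is what makes "different heads doing different jobs in a shared workspace" coherent.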
Anthropic pushed this further with their "Scaling Monosemanticity" work in 2024. they trained sparse autoencoders to decompose the residual stream of Claude 3 Sonnet into interpretable directions: features that correspond to recognizable concepts. they found features for things like the Golden Gate Bridge, code written in Python, deceptive behavior. individual neurons are polysemantic (they activate for multiple unrelated things), but these learned directions are closer to monosemantic. one feature, one concept, or at least an approximation of that.
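the sparse autoencoder idea also fits in a few lines. a minimal sketch, with random placeholder weights: in the real work the weights are trained so that reconstructions are accurate and activations are sparse (via an L1 penalty), and the feature count is vastly larger than this. all dimensions here are invented.

```python
# minimal sparse-autoencoder sketch: decompose a residual-stream vector
# into a sparse set of feature activations. weights are untrained
# placeholders; real SAEs learn them.
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features = 16, 128   # features >> dimensions: overcomplete basis

W_enc = rng.normal(size=(n_features, d_model)) / np.sqrt(d_model)
b_enc = -0.5 * np.ones(n_features)   # negative bias pushes toward sparsity
W_dec = rng.normal(size=(d_model, n_features)) / np.sqrt(n_features)

def sae(x):
    # encode: each feature activation is a ReLU'd projection of the stream
    acts = np.maximum(0.0, W_enc @ x + b_enc)
    # decode: reconstruct the stream as a sparse sum of feature directions
    recon = W_dec @ acts
    return acts, recon

x = rng.normal(size=d_model)
acts, recon = sae(x)
# most features stay silent; the active ones name the concepts present
print(f"{(acts > 0).sum()} of {n_features} features active")
```

the decoder columns are the interesting part: each one is a direction in the residual stream, and "the Golden Gate Bridge feature" is one such column plus whatever makes its activation interpretable.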
here's what i keep thinking about: what would we find if we did this to a driving model?
the FSD network takes camera images and outputs driving commands. somewhere between input and output, it has to represent things like "there's a cyclist in my blind spot" and "this intersection has no protected left" and "the car ahead is decelerating but hasn't hit the brakes yet." are those represented as clean features in the residual stream? or is it more tangled, distributed across dimensions in ways we can't easily decompose?
i don't know. as far as i can tell, nobody's published serious mechanistic interpretability work on end-to-end driving models. the models are proprietary, the stakes are high, and the input modality (multi-camera video) makes the analysis harder than text.
but the safety case is obvious. if you could find a "this situation is dangerous" feature in the driving model's residual stream, and verify it fires reliably, that would mean something concrete. more than any aggregate metric.
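the verification step has a concrete shape, even if nobody has done it on a driving model. a hypothetical sketch: given a candidate "dangerous situation" direction in the residual stream (found by an SAE or a linear probe), measure how reliably it fires on labeled scenarios. everything below, the direction, the activations, the labels, the threshold, is fabricated for illustration; no real model is involved.

```python
# hypothetical check of a candidate "danger" feature: project residual
# activations onto the direction and compare firing rates on labeled
# dangerous vs. safe scenarios. all data here is synthetic.
import numpy as np

rng = np.random.default_rng(1)
d_model = 16

danger_dir = rng.normal(size=d_model)
danger_dir /= np.linalg.norm(danger_dir)

def feature_fires(stream_vec, threshold=1.0):
    # "firing" = projection onto the direction above a calibrated threshold
    return float(stream_vec @ danger_dir) > threshold

# fake evaluation set: dangerous scenarios get a bump along the direction
safe = [rng.normal(size=d_model) for _ in range(100)]
dangerous = [rng.normal(size=d_model) + 3.0 * danger_dir for _ in range(100)]

tpr = np.mean([feature_fires(v) for v in dangerous])
fpr = np.mean([feature_fires(v) for v in safe])
print(f"fires on {tpr:.0%} of dangerous, {fpr:.0%} of safe scenarios")
```

that pair of numbers is the kind of claim i mean: "this internal feature fires on X% of known-dangerous scenarios and Y% of safe ones" says something about what the model represents, not just how often it crashes in aggregate.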
the thing that nags at me: i can profile every microsecond of this model's execution. i can tell you the exact memory access pattern of layer 37. but i can't tell you what layer 37 "knows." i have total visibility into how the computation happens and zero visibility into what it means.
there's a version of my career where i go try to figure that out. i think about it more than i probably should.