One of the most frequent questions I get asked by people exploring machine learning beyond cloud and desktop machines is “What about training?”. If you look around at the popular frameworks and use cases of edge ML, most of them seem focused on inference. It isn’t obvious why this is the case though, so I decided to collect my notes in a post here to have something to refer to when this comes up (and to organize my own thoughts too!).
I think the biggest reason that there’s not more training on the edge is that most models need to be trained through supervised learning, that is, each sample used for training needs a ground-truth label. If you’re running on a phone or embedded system, there’s not likely to be an easy way to attach a label to incoming data, other than running an existing model and guessing. You need a person to look at an image, or listen to an audio recording, and identify what the prediction should be before you can use it in training. You also generally need a fairly large number of labels per class for training to be effective.
This may change as semi-supervised or unsupervised approaches continue to improve, but right now supervised training is the most reliable method to get a model for most applications. I have seen some interesting hacks to guess labels on the edge though, which might fall into the semi-supervised category. For example, you can use temporal consistency on video frames to infer mistakes. In concrete terms, if your camera is identifying a fruit as a lemon for ten frames, then for one frame it’s a lime, and then it’s back to a lemon, you can guess that the lime prediction was an error (assuming the frame rate is high enough, fruits aren’t flying by at supersonic speed, and so forth). Another clever use of time was in an audio wake word application, where if there was a near-detection (the model gave a score just below the threshold) followed soon after by an actual detection (over the threshold), the system would guess that the person had actually said the wake word the first time and the model had failed to recognize it. This hack relies on the human behavior of trying again if it didn’t work initially.
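To make the temporal consistency hack concrete, here’s a minimal sketch in Python of the kind of majority-vote check I’m describing. The function name, window size, and fruit labels are all made up for illustration; the point is just that a lone disagreement surrounded by consistent predictions can be flagged as a likely mistake (or harvested as a weak label).

```python
from collections import Counter

def flag_temporal_outliers(predictions, window=5):
    """Flag frames whose label disagrees with all of their neighbors.

    predictions: per-frame class labels, e.g. ["lemon", "lemon", "lime", ...]
    Returns (frame_index, suspect_label, majority_label) tuples that can be
    treated as likely errors, or as weakly-labeled training examples.
    """
    suspects = []
    half = window // 2
    for i, label in enumerate(predictions):
        neighbors = predictions[max(0, i - half):i] + predictions[i + 1:i + half + 1]
        if not neighbors:
            continue
        majority, count = Counter(neighbors).most_common(1)[0]
        # If every surrounding frame agrees on a different label, assume the
        # odd frame out is the mistake.
        if majority != label and count == len(neighbors):
            suspects.append((i, label, majority))
    return suspects

# A lone "lime" sandwiched between runs of "lemon" frames gets flagged.
frames = ["lemon"] * 10 + ["lime"] + ["lemon"] * 10
print(flag_temporal_outliers(frames))  # [(10, 'lime', 'lemon')]
```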
Getting models to work well within an application is hard enough when you train a single version and put it through testing before release. If an edge model is retrained on-device, it becomes very hard to predict the bounds of its behavior, and since that behavior affects how well your application works, training on the fly makes ensuring correctness much harder. This isn’t a complete blocker; there are clearly some products (like GBoard) that do manage to handle this problem, but they generally build some kind of guard rails around what the model can produce. For example, something that predicts words or sentences might have a block-list of banned words (such as hateful or obscene phrases) that will be scrubbed from the model’s output even if edge training causes it to start producing them.
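To show what I mean by guard rails, here’s a minimal sketch of a block-list filter in Python. The block list contents and function name are placeholders I’ve invented for the example; a real product would use a much larger, carefully curated list and probably more sophisticated matching.

```python
# Placeholder block list; a real one would be far larger and curated.
BLOCK_LIST = {"badword1", "badword2"}

def filter_suggestions(suggestions):
    """Drop any candidate word or phrase that contains a blocked term."""
    safe = []
    for candidate in suggestions:
        tokens = set(candidate.lower().split())
        # Scrub blocked output even if an on-device retrained model
        # has started ranking it highly.
        if tokens & BLOCK_LIST:
            continue
        safe.append(candidate)
    return safe
```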
This kind of post-processing is often needed even when using pre-trained models on the edge (I could probably fill a decent book with all the hacks that usually go into filtering and interpreting the raw model output to make it useful) but the presence of a model that can change in unpredictable ways makes it even harder. Nobody wants to be responsible for building another Tay.
When you set up a new phone, you’ll probably speak the assistant wake word a few times to help the system learn your voice. In my experience this doesn’t involve retraining in the sense of full back propagation. Instead, the “Is this audio a wake word?” model produces an embedding vector as its output, and that is then used in a nearest-neighbor lookup to compare to the embeddings from the first few utterances you spoke during setup. This is a surprisingly common technique across a lot of domains, because it is comparatively simple to implement, only requires storing a few values, and works robustly.
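Here’s a simplified sketch of that kind of nearest-neighbor matching in Python, assuming the model hands you an embedding vector for each utterance. The class name, cosine similarity measure, and threshold are my own choices for illustration, not a description of any particular assistant’s implementation.

```python
import numpy as np

def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

class WakeWordMatcher:
    """Compare a new utterance's embedding to the few stored at setup time."""

    def __init__(self, threshold=0.8):
        self.enrolled = []          # embeddings captured during phone setup
        self.threshold = threshold  # would be tuned on held-out data in practice

    def enroll(self, embedding):
        self.enrolled.append(np.asarray(embedding, dtype=np.float32))

    def matches(self, embedding):
        """Nearest-neighbor lookup: accept if the closest enrolled example
        is similar enough to the new utterance."""
        if not self.enrolled:
            return False
        embedding = np.asarray(embedding, dtype=np.float32)
        best = max(cosine_similarity(embedding, e) for e in self.enrolled)
        return best >= self.threshold
```

All the device has to store is a handful of vectors, which is part of why this approach shows up so often.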
I’ve found embeddings to be a fantastic general-purpose tool for customizing models on the edge, without requiring the full machinery of back propagation. The gradient descent approach used by modern deep learning needs high-precision (usually floating point) weight arrays, along with specialized operators to run the back-prop version of each layer. The weights need to be stored between updates, and since they’re higher precision than inference requires, they take up more space than an inference-optimized model. You’ll also usually want to keep a copy of the original weights around in case you need to reset the model. By contrast, you can often extract an embedding from an existing model just by reading the activations of the layer before the final fully-connected op that does the classification. Even though specialized loss functions exist to encourage embeddings with desired properties, like good spatial separation, I’ve found that training with a regular softmax and lopping off the last layer often works just as well in practice.
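As an example of how little machinery the embedding approach needs, here’s a rough sketch using Keras. I’m assuming you already have a classifier trained with an ordinary softmax head and that its penultimate layer is the one you want to read; that indexing depends on your architecture, so treat it as illustrative rather than a recipe.

```python
import tensorflow as tf

def make_embedding_model(classifier: tf.keras.Model) -> tf.keras.Model:
    """Reuse everything up to the penultimate layer as an embedding extractor.

    Assumes `classifier` has been built (so .input and layer outputs exist)
    and that its last layer is the fully-connected/softmax classification op.
    """
    penultimate = classifier.layers[-2].output
    return tf.keras.Model(inputs=classifier.input, outputs=penultimate)

# embeddings = make_embedding_model(classifier).predict(batch_of_inputs)
# These vectors can then feed a nearest-neighbor lookup like the one sketched
# above, with no back propagation happening on the device.
```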
Of course, there are examples of very successful products that do use training on the edge. I already mentioned GBoard, which is the poster child for federated learning, but another domain where I’ve seen a lot of use is in anomaly detection, particularly around predictive maintenance for machinery. This is an application where it seems like every machine behaves differently, so learning “normal behavior” (by observing the first 24 hours of vibrations and labeling those as normal) allows the adaptation needed to spot deviations from those initial patterns. I’ve also seen interesting research projects around security and communications protocols that are looking at using training on the edge to be more robust to changing environmental conditions.
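To give a flavor of how simple the “learn normal, then flag deviations” idea can be, here’s a rough Python sketch. The feature representation, threshold, and class name are all invented for illustration; production predictive-maintenance systems are usually more sophisticated, but the shape of the approach is similar.

```python
import numpy as np

class VibrationBaseline:
    """Learn 'normal' from an initial window, then flag later deviations.

    Feature vectors (for example, per-frequency-band vibration energies)
    observed during the first 24 hours are labeled as normal, and later
    readings are scored by how far they sit from that baseline.
    """

    def __init__(self, z_threshold=4.0):
        self.z_threshold = z_threshold
        self.mean = None
        self.std = None

    def fit(self, baseline_features):
        data = np.asarray(baseline_features, dtype=np.float32)
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0) + 1e-6  # avoid divide-by-zero

    def is_anomalous(self, features):
        z = np.abs((np.asarray(features, dtype=np.float32) - self.mean) / self.std)
        return bool(z.max() > self.z_threshold)
```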
The short answer to the question is that if you’re getting started with ML on the edge, training models there is unlikely to be useful in the short or medium term. Technology keeps changing, and I am seeing some interesting applications starting to emerge, but I feel like a lot of the interest in edge training comes from how prominent training is in the cloud world. I often joke that all ML architecture researchers could go on strike indefinitely, and we ML engineers would still have decades of productive work ahead of us. There are many better-motivated problems around deployment on the edge than bringing training up to server capabilities, and I bet your product will hit some of those long before training becomes an issue.