Machine learning API using Scikit-Learn, Flask and Docker


In this tutorial, we are going to create a containerized machine learning application. It is a simple application, but it can serve as a template for building a more serious one. By the end of this post, I hope you will have a basic idea of the following cool topics and technologies:

  • Machine learning (without the math, of course) using the scikit-learn Python library
  • Exposing a machine learning model through a REST API using the Flask micro web framework
  • Building a Docker image to easily share and deploy the demo app

What is machine learning?

In software development, there are real-world problems that may not have a well-defined solution. For example, there is no specific algorithm that classifies an email as valid or spam. Regardless of the programming language in use, there is no practical way to solve that problem with hand-written rules; however, machine learning techniques have proved effective at tackling such problems. Machine learning is a complex topic beyond the scope of this short post, so we will only give a taste of what it can do without discussing any mathematics. In short, machine learning uses historical data to build mathematical models that can predict properties of unseen data.

How does machine learning work?

Take the spam filtering problem we just mentioned. Given a large set of sample emails that have already been labeled (spam or not), it is possible to build a machine learning model that detects spammy emails. A typical machine learning flow goes as follows:

  • Start with labeled data samples (e.g. emails marked as valid or spam)
  • Construct a two-dimensional data matrix
  • Features are the columns (e.g. subject line, word frequency, etc.)
  • Data points (the X matrix) are the rows (i.e. each row represents a sample email)
  • The result column (the Y vector) can be continuous (for regression) or discrete (for classification, as in the spam filtering problem)
  • Machine learning can be supervised, which means the data has to be labeled (spam filtering is an example of this)
  • Or it can be unsupervised, meaning there are no target values and the goal is to discover groups or clusters of similar samples (spam filtering is not a clustering problem; sorting articles by topic is an example of one)
  • Split the sample data into a training set and a testing set to evaluate the model's prediction performance
  • Use cross validation to estimate how well the model generalizes, and to pick models that generalize better
  • Try different estimators or models (scikit-learn supports several machine learning models, for example SVM, KNN, etc.)
  • Optimize the model parameters
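The flow above can be sketched with scikit-learn as follows. This is a minimal illustration, not the tutorial's own model: it assumes synthetic data from make_classification in place of real labeled emails, and the choice of SVM and KNN, the parameter grid and the 75/25 split are arbitrary assumptions.

```python
# Sketch of the supervised workflow: data matrix, train/test split,
# cross validation, comparing estimators and tuning parameters.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# X: rows are samples, columns are features.
# y: discrete class labels (synthetic stand-in for spam / valid).
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

# Hold out a test set to evaluate the final model on unseen data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

# Compare two estimators with 5-fold cross validation on the training set.
candidates = {
    "svm": make_pipeline(StandardScaler(), SVC()),
    "knn": make_pipeline(StandardScaler(), KNeighborsClassifier()),
}
cv_scores = {
    name: cross_val_score(model, X_train, y_train, cv=5).mean()
    for name, model in candidates.items()
}
print(cv_scores)

# Tune the SVM's parameters with a cross-validated grid search.
# make_pipeline names the SVC step "svc", hence the "svc__" prefix.
search = GridSearchCV(
    make_pipeline(StandardScaler(), SVC()),
    param_grid={"svc__C": [0.1, 1, 10], "svc__gamma": ["scale", 0.1]},
    cv=5,
)
search.fit(X_train, y_train)
print(search.best_params_)

# Final check on the held-out test set.
test_accuracy = search.score(X_test, y_test)
print(f"test accuracy: {test_accuracy:.2f}")
```

The test set is touched only once, at the end, so its score is not biased by the model selection done with cross validation.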

From a software engineering perspective, this is not our specialty. We leave it to data scientists to figure out the best features and the best model, and to optimize it for us. Instead, we are going to consume the resulting model and use it for prediction out of the box.

What is Scikit-Learn?

Scikit-learn is an open source machine learning library for Python. It provides access to some of the most popular machine learning algorithms. For more information about it, you can refer to the following link. Here is a short list of machine learning topics it supports:

  • Data standardization (i.e. rescaling attributes so that the mean is 0 and the standard deviation is 1) and normalization (i.e. rescaling attributes to the range 0 to 1). These techniques can improve prediction performance
  • Supervised learning estimators (e.g. linear regression, support vector machine, naive Bayes, KNN)
  • Unsupervised learning (i.e. clustering)
  • Cross validation (to estimate and improve model generalization)
  • Feature extraction (e.g. from text files and images) and selection (i.e. finding relevant features)
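To make the first and third bullets concrete, here is a small sketch using scikit-learn's StandardScaler, MinMaxScaler and KMeans. The toy arrays are made up for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# One feature column with three samples.
X = np.array([[1.0], [2.0], [3.0]])

# Standardization: rescale to mean 0 and standard deviation 1.
standardized = StandardScaler().fit_transform(X)
# Normalization (min-max scaling): rescale into the range [0, 1].
normalized = MinMaxScaler().fit_transform(X)

print(standardized.ravel())  # approx. -1.2247, 0.0, 1.2247
print(normalized.ravel())    # 0.0, 0.5, 1.0

# Unsupervised learning: no labels; KMeans groups nearby points.
points = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]])
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(points)
# Cluster ids are arbitrary: the first two points share one id,
# the last two share the other.
print(labels)
```

StandardScaler divides by the population standard deviation, which is why the values 1, 2, 3 map to about plus or minus 1.22 rather than plus or minus 1.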

What is Flask?

Flask is a micro web framework written in Python. It does not require any particular tools or libraries. We chose Flask because it is light, flexible and easy to use: only a few lines of code are needed to launch a fully working RESTful web service. We are going to use Flask to expose our machine learning model via a REST web service endpoint.
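As a rough illustration of "only a few lines of code", here is a minimal Flask service. The /hello route and its name parameter are hypothetical and are not part of the demo app.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

# Hypothetical endpoint: GET /hello?name=Ada returns a small JSON document.
@app.route("/hello", methods=["GET"])
def hello():
    # Query-string parameter with a default value when it is absent.
    name = request.args.get("name", "world")
    return jsonify(message=f"Hello, {name}!")
```

Calling app.run() starts Flask's development server, which listens on http://127.0.0.1:5000 by default. Without starting a server, app.test_client().get("/hello?name=Ada") returns a response whose JSON body is {"message": "Hello, Ada!"}.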

What is Docker?

Docker is a container engine that uses Linux kernel features such as namespaces and control groups. This lets us automate application development, testing and deployment. You can think of Docker images as lightweight virtual environments. Unlike virtual machines, Docker containers are fast to start and can run on any host with a compatible Linux kernel. This tutorial requires Docker to be installed on your computer. I am using a Mac; you can download Docker for Mac from here.

Python requirements

There are some Python libraries we need. Follow the steps below:

  • Create a text file and name it requirements.txt
  • Copy the text below into the file
  • This file is used to download and install all the needed Python libraries
  • For a quick explanation, see the comments below:
# Multi-dimensional arrays and more
numpy
# Scientific computing and more
scipy
# Data manipulation and analysis
pandas
# Machine learning library
scikit-learn
# Micro web framework
Flask

Flask Server

  • The code below uses Flask to create an API server
  • We are going to expose one endpoint for prediction
  • The code is commented, so just skim through it
  • Create a Python file and name it
# Web server
from flask import Flask
# Get request parameters
from flask import request
# This is needed for logistic regression
from sklearn import linear_model
# Save and load models to/from disk
import pickle

# Logistic regression models the probability that a given
# input (e.g. an email) belongs to a certain class
# (e.g. spam or not); predict() returns the most likely class
logReg = linear_model.LogisticRegression()

# Samples (your features; they should be normalized
# or standardized). Normalization scales the values
# into the range [0, 1]. Standardization scales data
# to have a mean of 0 and a standard deviation of 1
# Note that we are using fake data here just to
# demonstrate the concept
X = [[1.0, 1.0, 2.1], [2.0, 2.2, 3.3], [3.0, 1.1, 3.0]]

# Labeled data (spam or not). In this demo 0 means spam
# and 1 means valid, matching the messages returned below
Y = [1, 0, 1]

# Build the model
logReg.fit(X, Y)

# Save it to disk (the with block closes the file)
with open('logReg.pkl', 'wb') as f:
    pickle.dump(logReg, f)

# API server
app = Flask(__name__)

# Define endpoint
@app.route('/demo/api/v1.0/predict', methods=['GET'])
def get_prediction():

    # We are using 3 features. For example:
    # subject line, word frequency, etc.
    param1 = float(request.args.get('p1'))
    param2 = float(request.args.get('p2'))
    param3 = float(request.args.get('p3'))

    # Load model from disk
    with open('logReg.pkl', 'rb') as f:
        logReg = pickle.load(f)

    # Predict the class of a single sample
    pred = logReg.predict([[param1, param2, param3]])[0]
    if pred == 0:
        return "Email is spam"
    return "Email is valid"

# Main app
if __name__ == '__main__':
    # Listen on all interfaces so the server is reachable from
    # outside the Docker container, on the port the container exposes
    app.run(port=7777, host='0.0.0.0')
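The endpoint can be exercised without running a server or building an image by using Flask's built-in test client. The sketch below is a variant of the server above, not a drop-in replacement: it assumes the same fake training data and the original code's convention that a prediction of 0 means spam, but it loads the model once at startup (via an in-memory pickle round trip instead of a file) and returns HTTP 400 when p1, p2 or p3 is missing or not a number.

```python
import pickle

from flask import Flask, abort, request
from sklearn.linear_model import LogisticRegression

# Same fake data as the server above.
X = [[1.0, 1.0, 2.1], [2.0, 2.2, 3.3], [3.0, 1.1, 3.0]]
Y = [1, 0, 1]

# Train, then pickle and unpickle in memory, once at startup.
model = pickle.loads(pickle.dumps(LogisticRegression().fit(X, Y)))

app = Flask(__name__)

@app.route("/demo/api/v1.0/predict", methods=["GET"])
def get_prediction():
    # type=float converts the value; a missing or non-numeric value gives None.
    params = [request.args.get(k, type=float) for k in ("p1", "p2", "p3")]
    if any(p is None for p in params):
        abort(400, description="p1, p2 and p3 must be numbers")
    pred = model.predict([params])[0]
    return "Email is spam" if pred == 0 else "Email is valid"

# Flask's test client sends requests directly to the app object.
client = app.test_client()
ok = client.get("/demo/api/v1.0/predict?p1=1.0&p2=1.0&p3=2.1")
bad = client.get("/demo/api/v1.0/predict?p1=1.0")
print(ok.status_code, ok.get_data(as_text=True))
print(bad.status_code)  # 400
```

Loading the model once avoids reading the pickle file on every request; reloading per request only pays off if the model file is replaced while the server is running.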

Docker File

  • Create a text file and name it Dockerfile
  • Copy the text below into the Dockerfile
  • You may want to check the Docker documentation to understand the Dockerfile syntax
  • Basically, we create a Docker image based on the standard Python Docker image, add our application and its dependencies, and declare the port the server listens on. Then we build the image.
# Python base image
FROM python:3.6.3

# Where the API server lives
WORKDIR /app/

# Install required dependencies
COPY requirements.txt /app/
RUN pip install -r ./requirements.txt

# API server
COPY /app/
# Container port on which the server will be listening
EXPOSE 7777
# Launch server app
ENTRYPOINT python ./


To build the Docker image and launch the container, follow the steps below:

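# Build the Docker image and tag it as demo
docker build . -t demo

# Launch the container in the background
# Map container port 7777 to host port 4444
docker run -p 4444:7777 -d demo

# Open a browser and hit the endpoint at
# http://localhost:4444/demo/api/v1.0/predict
# supplying the p1, p2 and p3 query parameters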
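Instead of a browser, the running container can also be called from Python. This is a sketch using only the standard library; it assumes the container was started with the port mapping above (host port 4444), and the parameter values in the example are arbitrary.

```python
import urllib.request
from urllib.parse import urlencode

# Host port 4444 is mapped to container port 7777 by `docker run -p 4444:7777`.
BASE_URL = "http://localhost:4444/demo/api/v1.0/predict"

def build_url(p1: float, p2: float, p3: float) -> str:
    """Build the GET URL for the prediction endpoint."""
    return BASE_URL + "?" + urlencode({"p1": p1, "p2": p2, "p3": p3})

def predict(p1: float, p2: float, p3: float, timeout: float = 5.0) -> str:
    """Call the running container and return its plain-text answer."""
    with urllib.request.urlopen(build_url(p1, p2, p3), timeout=timeout) as resp:
        return resp.read().decode("utf-8")

# Requires the container to be running:
# print(predict(1.0, 2.0, 3.0))
```

build_url(1.0, 2.0, 3.0) produces http://localhost:4444/demo/api/v1.0/predict?p1=1.0&p2=2.0&p3=3.0, which can also be pasted into the browser.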

If you wish, you can get the code by cloning the following GitHub repo, then apply the commands above to build and run the container:

git clone

Thanks for reading. For questions and feedback please use the comments section below.
