1. Introduction to Deep Learning¶
Deep learning is a category of machine learning. Machine learning is a category of artificial intelligence. These notes are mostly about deep learning, thus the name of the book. Deep learning is the use of neural networks to classify and regress data (this is too narrow, but a good starting place). I am a chemical engineering professor though; writing an introduction to deep learning is a hopeless task for me. I found the introduction from Ian Goodfellow’s book to be a good starting point. If you’re more visually oriented, Grant Sanderson has made a short video series specifically about neural networks that gives an applied introduction to the topic. DeepMind has a high-level video showing what can be accomplished with deep learning & AI. When people write “deep learning is really cool” in their research papers, they typically cite this Nature paper by Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Zhang, Lipton, Li, and Smola have written a practical and example-driven online book that gives each example in TensorFlow, PyTorch, and MXNet.
The main advice I would give to beginners in deep learning is to focus less on the neurologically inspired language (i.e., connections between neurons) and instead view deep learning as a series of linear algebra operations in which many of the matrices are filled with adjustable parameters. There are, of course, a few non-linear functions (activations) here and there, but deep learning is essentially linear algebra operations specified via a graph (in the network sense) that vaguely resembles neurons in a brain.
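To make that concrete, here is a minimal sketch (with made-up sizes and random, untrained parameters) of a small network written as nothing but NumPy linear algebra:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4,))      # input features, D = 4
W1 = rng.normal(size=(8, 4))   # adjustable parameters of layer 1
b1 = np.zeros(8)
W2 = rng.normal(size=(1, 8))   # adjustable parameters of layer 2
b2 = np.zeros(1)

h = np.tanh(W1 @ x + b1)       # linear algebra plus a non-linear activation
y = W2 @ h + b2                # a final linear model on the new features
print(y.shape)                 # (1,)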
1.1. Neural Networks¶
The deep in deep learning means we have many layers in our neural networks. What is a neural network? Without loss of generality, we can view neural networks as two components: (1) a non-linear function \(g(\cdot)\), which operates on our input features \(\mathbf{X}\) and outputs a new set of features \(\mathbf{H} = g(\mathbf{X})\), and (2) a linear model like we saw in our Machine Learning chapter. Our model equation for deep learning regression is:

\[
\hat{y} = \vec{w}g(\vec{x}) + b
\]
One of the main discussion points in our ML chapters was how arcane and difficult it is to choose features. Here, we have replaced our features with a set of trainable features \(g(\vec{x})\) and then use the same linear model as before. So how do we design \(g(\vec{x})\)? That is the deep learning part. \(g(\vec{x})\) is a differentiable function we design composed of layers, which are themselves differentiable functions each with trainable weights (variables). Deep learning is a mature field and there is a set of standard layers, each with a different purpose. For example, convolution layers look at a fixed neighborhood around each element of an input tensor. Dropout layers randomly inactivate inputs as a form of regularization. The most commonly used and basic layer is the dense or fully-connected layer.
A dense layer is defined by two things: the desired output feature shape and the activation function. The equation is:

\[
\vec{h} = \sigma\left(\mathbf{W}\vec{x} + \vec{b}\right)
\]
where \(\mathbf{W}\) is a trainable \(D \times F\) matrix, where \(D\) is the input vector (\(\vec{x}\)) dimension and \(F\) is the output vector (\(\vec{h}\)) dimension, \(\vec{b}\) is a trainable \(F\)-dimensional vector, and \(\sigma(\cdot)\) is the activation function. \(F\) is an example of a hyperparameter: it is not trainable but is a problem-dependent choice. \(\sigma(\cdot)\) is another hyperparameter. In principle, any differentiable function with a domain of \((-\infty, \infty)\) can be used for activation. However, just a few activations have been designed empirically to balance computational cost and effectiveness. One example we’ve seen before is the sigmoid. Another is the hyperbolic tangent, which behaves similarly (in domain and range) to the sigmoid. The most commonly used activation is the rectified linear unit (ReLU), which is

\[
\sigma(x) = \max\left(0, x\right)
\]
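As a sketch of the dense layer equation above (with made-up dimensions, and with \(\mathbf{W}\) stored as \(F \times D\) so that the matrix-vector product works out), a dense layer is just a matrix multiply, a shift, and an element-wise activation:

import numpy as np

def relu(x):
    # rectified linear unit: keep positive values, zero out the rest
    return np.maximum(0, x)

D, F = 4, 3                   # input and output dimensions (F is a hyperparameter)
rng = np.random.default_rng(0)
W = rng.normal(size=(F, D))   # trainable weights
b = np.zeros(F)               # trainable bias
x = rng.normal(size=(D,))     # one input feature vector
h = relu(W @ x + b)           # dense layer output, shape (F,)
print(h)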
1.1.1. Universal Approximation Theorem¶
One of the reasons that neural networks are a good choice for approximating unknown functions (\(f(\vec{x})\)) is that a neural network can approximate any function given a large enough depth (number of layers) or width (size of the hidden layers). To be more specific, any one-dimensional function can be approximated by a depth-5 neural network with ReLU activation functions. The universal approximation theorem shows that neural networks are, in the limit of large depth or width, expressive enough to fit any function.
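As an illustration, here is a minimal sketch that jumps ahead to the Keras framework introduced below; the target function, hidden-layer width, and training settings are arbitrary choices. A single wide ReLU hidden layer can fit a wiggly one-dimensional function:

import numpy as np
import tensorflow as tf

x = np.linspace(-3, 3, 256).reshape(-1, 1).astype(np.float32)
y = np.sin(2 * x) + 0.5 * x   # the "unknown" 1D function we want to approximate

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),  # one wide hidden layer
    tf.keras.layers.Dense(1),                       # linear output
])
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(x, y, epochs=500, verbose=0)
print(model.evaluate(x, y, verbose=0))              # mean squared error should be small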
1.2. Frameworks¶
Deep learning has lots of “gotchas”: mistakes that are easy to make and that make it difficult to implement things yourself. This is especially true with numerical stability, which only reveals itself when your model fails to learn. We will therefore move to a somewhat more abstract software framework than JAX. I may update this, but for now we’ll use Keras.
1.3. Discussion¶
When it comes to introducing deep learning, I will be as terse as possible. There are good learning resources out there. You should use some of the reading above and tutorials put out by Keras (or PyTorch) to get familiar with the concepts of neural networks and learning.
1.4. Revisiting the Solubility Model¶
We’ll see our first example of deep learning by revisiting the solubility dataset with a two-layer dense neural network.
1.5. Running This Notebook¶
Click the above to launch this page as an interactive Google Colab. See details below on installing packages, either in your own environment or on Google Colab.
Tip
To install packages, execute this code in a new cell
!pip install matplotlib numpy pandas seaborn tensorflow
The hidden cells below set up our imports and/or install necessary packages.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import tensorflow as tf
import numpy as np
np.random.seed(0)
import warnings
warnings.filterwarnings('ignore')
sns.set_context('notebook')
sns.set_style('dark', {'xtick.bottom':True, 'ytick.left':True, 'xtick.color': '#666666', 'ytick.color': '#666666',
'axes.edgecolor': '#666666', 'axes.linewidth': 0.8 })
color_cycle = ['#1BBC9B', '#F06060', '#5C4B51', '#F3B562', '#6e5687']
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle)
1.5.1. Load Data¶
We download the data and load it into a Pandas data frame and then standardize our features as before.
soldata = pd.read_csv('https://dataverse.harvard.edu/api/access/datafile/3407241?format=original&gbrecs=true')
features_start_at = list(soldata.columns).index('MolWt')
feature_names = soldata.columns[features_start_at:]
# standardize the features
soldata[feature_names] -= soldata[feature_names].mean()
soldata[feature_names] /= soldata[feature_names].std()
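As a quick sanity check (an extra snippet, not strictly necessary), the standardized features should now have means near 0 and standard deviations near 1:

# largest absolute feature mean and average feature standard deviation
print(soldata[feature_names].mean().abs().max())
print(soldata[feature_names].std().mean())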
1.6. Prepare Data for Keras¶
The deep learning libraries simplify many common tasks, like splitting data and building layers. The code below builds our dataset from NumPy arrays.
full_data = tf.data.Dataset.from_tensor_slices((soldata[feature_names].values, soldata['Solubility'].values))
N = len(soldata)
test_N = int(0.1 * N)
test_data = full_data.take(test_N).batch(16)
train_data = full_data.skip(test_N).batch(16)
Notice that we used skip and take to split our dataset into two pieces and then created batches of data.
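If take and skip are unfamiliar, here is a toy sketch (with made-up numbers) of what they do:

# a toy dataset of the integers 0-9
toy = tf.data.Dataset.range(10)
first = toy.take(3)            # keeps the first 3 elements: 0, 1, 2
rest = toy.skip(3).batch(4)    # drops the first 3, then groups the remainder into batches of 4
print(list(first.as_numpy_iterator()))   # [0, 1, 2]
print(list(rest.as_numpy_iterator()))    # [array([3, 4, 5, 6]), array([7, 8, 9])]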
1.7. Neural Network¶
Now we build our neural network model. In this case, our \(g(\vec{x}) = \sigma\left(\mathbf{W^0}\vec{x} + \vec{b}\right)\). We will call the function \(g(\vec{x})\) a hidden layer, because we do not observe its output. Remember, the solubility will be \(y = \vec{w}g(\vec{x}) + b\). We’ll choose our activation, \(\sigma(\cdot)\), to be tanh and the output dimension of the hidden layer to be 32. You can read more about this API here; however, you should be able to understand the process from the function names and comments.
# our hidden layer
# We only need to define the output dimension - 32.
hidden_layer = tf.keras.layers.Dense(32, activation='tanh')
# Last layer - which we want to output one number
# the predicted solubility.
output_layer = tf.keras.layers.Dense(1)
# Now we put the layers into a sequential model
model = tf.keras.Sequential()
model.add(hidden_layer)
model.add(output_layer)
# our model is complete
# Try out our model on first few datapoints
model(soldata[feature_names].values[:3])
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.27844822],
[-0.30213237],
[-0.5311944 ]], dtype=float32)>
We can see our model predicting the solubility for 3 molecules above. There is a warning about how our Pandas data is using float64 (double precision floating point numbers) but our model is using float32 (single precision), which doesn’t matter that much. It warns us because we are technically throwing out a little bit of precision, but our solubility has much more variance than the difference between 32 and 64 bit precision floating point numbers.
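If you would rather avoid the warning, one option (a small sketch, not required) is to cast the inputs to single precision before calling the model:

# cast the float64 Pandas values to float32 to match the model's precision
model(soldata[feature_names].values[:3].astype(np.float32))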
At this point, we’ve defined how our model structure should work and it can be called on data. Now we need to train it! We prepare the model for training by calling compile, which is where we define our optimization method (a flavor of stochastic gradient descent) and the loss.
model.compile(optimizer='SGD', loss='mean_squared_error')
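The string arguments are shortcuts; an equivalent call with explicit objects looks like this (a sketch, with the learning rate written out as the Keras default):

# equivalent to the string shortcuts above
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),  # stochastic gradient descent
    loss=tf.keras.losses.MeanSquaredError(),
)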
Look back at how much work it previously took to set up the loss and the optimization process! Now we can train our model.
model.fit(train_data, epochs=50)
Epoch 1/50 - 562/562 - 1s 756us/step - loss: 3.4723
Epoch 2/50 - 562/562 - 0s 769us/step - loss: 2.0755
Epoch 3/50 - 562/562 - 0s 783us/step - loss: 1.9498
Epoch 4/50 - 562/562 - 0s 749us/step - loss: 1.8679
Epoch 5/50 - 562/562 - 0s 756us/step - loss: 1.8088
Epoch 6/50 - 562/562 - 0s 760us/step - loss: 1.7751
Epoch 7/50 - 562/562 - 0s 781us/step - loss: 1.7525
Epoch 8/50 - 562/562 - 0s 769us/step - loss: 1.7350
Epoch 9/50 - 562/562 - 0s 766us/step - loss: 1.7217
Epoch 10/50 - 562/562 - 0s 758us/step - loss: 1.7108
Epoch 11/50 - 562/562 - 0s 779us/step - loss: 1.7030
Epoch 12/50 - 562/562 - 0s 772us/step - loss: 1.6967
Epoch 13/50 - 562/562 - 0s 692us/step - loss: 1.6908
Epoch 14/50 - 562/562 - 0s 716us/step - loss: 1.6845
Epoch 15/50 - 562/562 - 0s 750us/step - loss: 1.6786
Epoch 16/50 - 562/562 - 0s 743us/step - loss: 1.6724
Epoch 17/50 - 562/562 - 0s 770us/step - loss: 1.6663
Epoch 18/50 - 562/562 - 0s 760us/step - loss: 1.6607
Epoch 19/50 - 562/562 - 0s 750us/step - loss: 1.6555
Epoch 20/50 - 562/562 - 0s 785us/step - loss: 1.6504
Epoch 21/50 - 562/562 - 0s 771us/step - loss: 1.6450
Epoch 22/50 - 562/562 - 0s 777us/step - loss: 1.6396
Epoch 23/50 - 562/562 - 0s 781us/step - loss: 1.6340
Epoch 24/50 - 562/562 - 0s 821us/step - loss: 1.6283
Epoch 25/50 - 562/562 - 0s 785us/step - loss: 1.6226
Epoch 26/50 - 562/562 - 0s 782us/step - loss: 1.6170
Epoch 27/50 - 562/562 - 0s 767us/step - loss: 1.6115
Epoch 28/50 - 562/562 - 0s 779us/step - loss: 1.6060
Epoch 29/50 - 562/562 - 0s 744us/step - loss: 1.6006
Epoch 30/50 - 562/562 - 0s 752us/step - loss: 1.5952
Epoch 31/50 - 562/562 - 0s 777us/step - loss: 1.5900
Epoch 32/50 - 562/562 - 0s 788us/step - loss: 1.5851
Epoch 33/50 - 562/562 - 0s 758us/step - loss: 1.5803
Epoch 34/50 - 562/562 - 0s 775us/step - loss: 1.5760
Epoch 35/50 - 562/562 - 0s 784us/step - loss: 1.5720
Epoch 36/50 - 562/562 - 0s 747us/step - loss: 1.5682
Epoch 37/50 - 562/562 - 0s 747us/step - loss: 1.5647
Epoch 38/50 - 562/562 - 0s 748us/step - loss: 1.5613
Epoch 39/50 - 562/562 - 0s 761us/step - loss: 1.5581
Epoch 40/50 - 562/562 - 0s 766us/step - loss: 1.5549
Epoch 41/50 - 562/562 - 0s 774us/step - loss: 1.5517
Epoch 42/50 - 562/562 - 0s 783us/step - loss: 1.5487
Epoch 43/50 - 562/562 - 0s 777us/step - loss: 1.5458
Epoch 44/50 - 562/562 - 0s 784us/step - loss: 1.5429
Epoch 45/50 - 562/562 - 0s 756us/step - loss: 1.5400
Epoch 46/50 - 562/562 - 0s 779us/step - loss: 1.5373
Epoch 47/50 - 562/562 - 0s 743us/step - loss: 1.5346
Epoch 48/50 - 562/562 - 0s 761us/step - loss: 1.5319
Epoch 49/50 - 562/562 - 0s 759us/step - loss: 1.5293
Epoch 50/50 - 562/562 - 0s 775us/step - loss: 1.5267
<tensorflow.python.keras.callbacks.History at 0x7f6412441e80>
That was quite simple!
For reference, we got a loss of only about 3 at best in our previous work. Training was also much faster, thanks to the framework’s optimizations. Now let’s see how our model did on the test data.
# get model predictions on test data and get labels
# squeeze to remove extra dimensions
yhat = np.squeeze(model.predict(test_data))
test_y = soldata['Solubility'].values[:test_N]
plt.plot(test_y, yhat, '.')
plt.plot(test_y, test_y, '-')
plt.xlabel('Measured Solubility $y$')
plt.ylabel(r'Predicted Solubility $\hat{y}$')
plt.text(min(test_y) + 1, max(test_y) - 2, f'correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}')
plt.text(min(test_y) + 1, max(test_y) - 3, f'loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}')
plt.show()

This performance is better than our simple linear model.
1.8. Chapter Summary¶
Deep learning is a category of machine learning that utilizes neural networks for classification and regression of data.
Neural networks are a series of operations with matrices of adjustable parameters.
A neural network transforms input features into a new set of features that can be subsequently used for regression or classification.
The most common layer is the dense layer. Each input element affects each output element. It is defined by the desired output feature shape and the activation function.
With enough layers or wide enough hidden layers, neural networks can approximate unknown functions.
Hidden layers are called that because we do not observe their output.
Libraries such as TensorFlow make it easy both to split data into training and testing sets and to build the layers of a neural network.
Building a neural network allows us to predict various properties of molecules, such as solubility.