Check out a free preview of the full Machine Learning in JavaScript with TensorFlow.js course

The "Using the Model in the Browser" Lesson is part of the full, Machine Learning in JavaScript with TensorFlow.js course featured in this preview video. Here's what you'd learn in this lesson:

Charlie implements the frontend code to run the prediction code when the predict button is clicked. The label of the predicted shape is displayed on the screen. Some additional tips for improving accuracy are also shared.


Transcript from the "Using the Model in the Browser" Lesson

[00:00:00]
>> Now, if I go back to the code: we have our file to get the data, our file to create the model, and our file to do the training steps, and we are saving the model into what is, to me, effectively a public folder.

[00:00:16]
And I created two: if you remember, in my model folder I had one that suddenly had very good accuracy, and then I tweaked some parameters and saved the result into a new folder so I could test things without ruining my original model. But now that we have created a custom machine learning model, we need to use it in the frontend and see if it works well.

[00:00:37]
So, if we go to our index.js file, at the moment you should only have the clear button event handler that clears what's on the canvas. And what we need here in the frontend is to import our model, similar to the way we did it yesterday.

[00:00:55]
So, first of all, we are just going to declare a variable called model. And underneath we're going to specify the URL, the path to our model file, so we can do modelPath, and for me that is in the model folder. If you've made changes and one of your models is more accurate, then maybe use the one with the better accuracy for now.

[00:01:19]
And we are going to specifically use the model.json file. And then, what we can do is have a function to load the model, so const loadModel = async function, and we're gonna pass it, let's say, that path. We're gonna call it later, but for now, we're just declaring it.

[00:01:46]
And we're going to await tf.loadLayersModel(path). I can call tf here because in the index.html, again, I require TensorFlow.js just as a script tag, but we could also import it. The way you decide to do this doesn't really matter, but here it means that we have access to the tf object.
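
As a rough sketch, the setup described so far could look like the snippet below. The modelPath value is an assumption; point it at wherever your model.json was saved. Note the assignment of the result to model, which comes up again later in the lesson.

    // tf is available globally because TensorFlow.js is loaded via a <script> tag in index.html.
    let model;
    const modelPath = "model/model.json"; // assumed path; adjust to your own saved model folder

    const loadModel = async (path) => {
      // Load the Layers model that was saved during the training step.
      model = await tf.loadLayersModel(path);
    };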

[00:02:19]
So, at this point, it's just a simple function that loads the model. You don't have to put that in a function, but if you end up doing more things related to loading the model, I just like to have smaller functions. Okay, so we have loaded our model, and now what we want to be able to do is predict. So if I just rerun my app, you'll remember we have a predict button that for now does nothing.

[00:02:51]
So, we need to add an event handler on this side as well. So, below where you declare your clear button, we can have const predictButton = document.getElementById(...). I think in the HTML I called it check-button, so let's just use that. It probably would have been better to call it predict, but I guess we're checking the images.

[00:03:20]
Okay, so we have predictButton and now we need the onclick handler. So predictButton.onclick = () => {...}. In our onclick function on the predictButton, the first thing we're gonna do is get the canvas, so we're just going to get the canvas element.

[00:03:42]
The getCanvas method comes from the utils file, so you can just import it here, and it returns the canvas element. But what needs to be fed to TensorFlow.js might be an image element rather than what's on the canvas itself, so the way I've done it, and that's maybe not the only way to do it, is to create a variable, const drawing = canvas.toDataURL().

[00:04:15]
And then we are going to use that as the source attribute on a new image. So if we look at our index.html file, this is what I had here, where I have an empty image with a class like imageToCheck or something. We're going to get that image that is currently empty.

[00:04:33]
I'm gonna call it const newImg = document.getElementsByClassName("imageToCheck")[0]. Terrible name, feel free to refactor it afterwards, and we're getting the first one. Again, I could have used an ID, but a class works fine, and we are going to set the source to our actual drawing. So, newImg.src = drawing.

[00:05:05]
Okay, and now when this image is loaded, we're calling predict on it. So we have newImg.onload: once we're sure that the data from the canvas has really been used as the source for the image, the onload function makes sure that we're not calling predict while the data on the image is not ready.

[00:05:24]
So, here we're calling a predict function that we have not written yet, but we will. We are going to just call predict(newImg). And then, because we're predicting, we can also call resetCanvas() after, so we can draw a new shape and then predict again, and so on.

[00:05:47]
So at this point, we have our event handler: when we click on the predict button, we're getting the canvas element, extracting the data, and attaching it to a new image element. And when that image element has loaded, we're feeding it to a predict function that we are going to write right now.
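
Put together, the click handler described above might look roughly like the sketch below. The getCanvas and resetCanvas helpers and the check-button / imageToCheck names come from the course's utils file and HTML; the "./utils" import path is an assumption, so adjust it to your own project.

    import { getCanvas, resetCanvas } from "./utils"; // assumed path to the course utils file

    const predictButton = document.getElementById("check-button");

    predictButton.onclick = () => {
      const canvas = getCanvas();              // the canvas the user drew on
      const drawing = canvas.toDataURL();      // snapshot of the drawing as a data URL

      const newImg = document.getElementsByClassName("imageToCheck")[0];
      newImg.onload = () => {                  // only predict once the image data is ready
        predict(newImg);
        resetCanvas();                         // clear the canvas for the next drawing
      };
      newImg.src = drawing;                    // setting src triggers the onload above
    };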

[00:06:03]
So here we're declaring const predict = async (img), and it expects an image as a parameter. And then, just to be sure, we need to set the width and height to 200, because that's the original size of our canvas. So we do img.width = 200, and then the height is going to be the same, img.height = 200.

[00:06:37]
Okay, and what we're gonna do next is transform that image before we give it to the model. So I'm going to declare a variable, const processedImg = await tf.browser.fromPixelsAsync(img, 4);, where we're passing the image and the number of channels. And here, RGBA, that's four channels.
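
As a sketch, the start of that predict function could look like this, assuming the 200x200 canvas size and the 4-channel RGBA input described in the lesson:

    const predict = async (img) => {
      // Match the original canvas dimensions.
      img.width = 200;
      img.height = 200;

      // Turn the image element into a tensor with 4 channels (RGBA).
      const processedImg = await tf.browser.fromPixelsAsync(img, 4);
      // ...resizing, casting, and the actual prediction follow below.
    };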

[00:07:11]
And this is the same number of channels that we used when we created our model, because the input shape of the first layer is expecting 28, 28, 4. So we did the first step of transforming the data into a tensor with four channels, but we have not resized our image yet.

[00:07:31]
So we need to do that now before we actually feed it to the model. So, const resizedImg = tf.image.resizeNearestNeighbour...
>> Is it spelled with "our", I think?
>> Or is it just "or"?
>> Just "or".
>> Okay, just "or", so resizeNearestNeighbor. And we need to pass it the processed image and then the shape that we actually trained our model with, or that we created the model with, so 28 and 28.
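
In code, that resizing step might look like this; note the American spelling resizeNearestNeighbor in the TensorFlow.js API, and the 28x28 target size matching the model's input shape from training:

    // Resize the 200x200 RGBA tensor down to the 28x28 size the model was trained on.
    const resizedImg = tf.image.resizeNearestNeighbor(processedImg, [28, 28]);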

[00:08:16]
Okay, so another thing that I had as well: because we had an issue before with some one-hot tensors that were expecting data in a certain way, we changed the type to int32 there, but when it comes to the image, I found that what works is to cast the type to float32.

[00:08:39]
So we can declare a variable, I don't know what to call it, const updatedImg = tf.cast(resizedImg, "float32");. I'm just trying to make sure that things work as I expect, but we can test that later and see if it's actually required, or if the default is already float32, in which case we might not need it.
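
That cast is a one-liner; whether it's strictly needed depends on the dtype coming out of fromPixelsAsync (int32 pixel values by default), so treat it as a defensive step:

    // fromPixels produces int32 values (0-255); cast to float32 so the dtype
    // matches what the model expects.
    const updatedImg = tf.cast(resizedImg, "float32");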

[00:09:12]
But for now, let's leave it like this. Then we have a shape here that we're not giving a value to yet, but then we are getting our predictions, so const predictions = await model.predict(...). And here, remember that when we trained our model, we had a shape with four different values.

[00:09:44]
When I expanded the dimensions, we can see that later. But what we're passing into it is a reshape: tf.reshape(updatedImg, [1, 28, 28, 4]). So, the 28 and 28 is the size of our image, the 4 is the number of channels, and the 1 here might not be needed if we remove the expandDims function that we called when we used expandDims on our training data.

[00:10:32]
And that was adding a dimension to our data. So here, I think if I don't have that, it will complain that your input data has only three dimensions and we're expecting four. So we can try it like this, and then we might be able to test that.

[00:10:45]
Well, if we remove the 1 and we remove expandDims, would that work as well? But for now, let's leave it like this, because at least I know it works this way. And we need to call .data() on our predictions as well; that's gonna extract the values that we can actually read our predictions from.
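
Continuing the sketch, the prediction step itself could look like this; the leading 1 is the batch dimension matching the expandDims call used on the training data:

    // Add a batch dimension so the input matches the model's expected shape [1, 28, 28, 4].
    const batched = tf.reshape(updatedImg, [1, 28, 28, 4]);

    // Run the model and pull the raw probability values out of the resulting tensor.
    const predictions = await model.predict(batched).data();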

[00:11:03]
Okay, so it's returning the predictions, but now we need to extract the label. So const label = predictions.indexOf(Math.max(...predictions));. And we can then give that label as a parameter to a util function that I called displayPrediction(label), and we can import that as well from the utils file.

[00:11:37]
So here, again, we have displayPrediction, and don't forget to import it from the utils file. Here we could have just console logged the label, but I think it's a bit nicer if you see it on the screen. And what this function does is really just look at our predictions element and replace its text content.
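
The last two lines of the predict function then pick the most likely class and hand it off for display. displayPrediction comes from the course's utils file, where the index is mapped to a labels array:

    // The index of the highest probability is the predicted class.
    const label = predictions.indexOf(Math.max(...predictions));

    // Utility from the utils file that maps the index to a label name and shows it on screen.
    displayPrediction(label);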

[00:11:58]
So really not much. Okay, so at this point, we have not yet called that loadModel function. So what we can do at the bottom, at the end, is loadModel(modelPath);. I think this should work, but let's try it out. Okay, so I'm in the UI here, and I have my browser tools open, just in case, and I draw a triangle.

[00:12:33]
>> That line needs to be model = await tf.loadLayersModel.
>> Yes, you're right, in my solutions folder I actually had it differently. Yeah, I didn't give,
>> In your notes you have it as a condition, like an if.
>> Yes, I do have a condition there. I'm not sure it's needed, but okay, let's add it just to make sure. It shouldn't necessarily be needed, but yes, you're right.

[00:13:01]
I didn't actually assign a value to the model variable, so we couldn't call predict on it. Okay, so hopefully that should now work. So if I save and refresh, I'm gonna draw my triangle. Triangle, [LAUGH] okay. Let's try with a circle. So I'm clearing it, and then if I do, okay, that's a terrible circle, but it says square.

[00:13:29]
Wait, I know why, I know why. It says square because originally, in my solutions folder, when I did it, I called it square, so it's only because somewhere in the utils I probably still have square. Yes, okay, so in my labels here, I have square, but it should actually be circle.

[00:13:42]
So technically, it still worked, so let's refresh. If I do a circle again: circle, okay. So here, what's interesting is that, obviously, I saved the model that had 95% accuracy, so it's pretty good. So if we try maybe another triangle, I don't know if that's going to, well, let's forget that happened.

[00:14:03]
Okay, [LAUGH] so the thing is, you know, 95% is still not a hundred. Even with 95% accuracy, you can see that sometimes when I draw a shape, it will actually get another label, because again, we only drew 40 shapes for each of them.

[00:14:22]
And if I was using the less accurate model that I have in my model 1 folder, it would be even worse; it would only be accurate about 50% of the time, which is not great. But hopefully, if it works for you, we now have a custom model that we created with our own data, we have used it in the browser, and we can recognize images that we drew.

[00:14:45]
But then, if you want to be a bit cheeky and you're like, well, what if I draw a candle or something? Obviously it doesn't know what a candle is, so it will just ask, what's the closest thing that I know, and it thinks it's a circle, okay?

[00:15:00]
And the thing is, if you want more drawings that you want to predict, then you do have to give that data to the model when you train it, otherwise it will not know. It will really just base its answer on the data we fed it and the way we're resizing things; it just knows circles and triangles, it doesn't really know anything else.

[00:15:19]
But yes, for a very small machine learning model that can predict drawings, it's working okay. Depending on the accuracy that you got, maybe right now when you're trying it, it's not as accurate. So you can try to change a few different parameters in the way that we created our model.

[00:15:39]
Again, you can have more images or add more layers, or experiment with the epochs, batch size, and things like that. There's not really a set rule for how to do it. I wish I could tell you, hey, if you just use this, this, and this, then it will be perfect, but that's not really how it works.
