Machine learning based on TensorFlow.js

What is machine learning? In my view, it is largely statistics with extra bells and whistles. TensorFlow.js is a machine learning framework for JavaScript, which its official site describes as follows:

> Develop ML with JavaScript: use flexible and intuitive APIs to build and train models from scratch, using either the low-level JavaScript linear algebra library or the high-level Layers API.

Below we combine TensorFlow.js with Baidu's ECharts [1] to build a classic example: fitting a straight line to noisy data by least-squares linear regression.

<!DOCTYPE html>
<html style="height: 100%">
   <head><meta charset="utf-8"></head>
   <body style="height: 100%; margin: 0">
       <div id="container" style="height: 100%"></div>
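       <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/echarts@4/dist/echarts.min.js"></script>
       <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1/dist/tf.min.js"></script>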
       <script type="text/javascript">
/**
 * @license
 * Modified from the official TensorFlow.js example; the original notice follows.
 *
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Generate synthetic samples of the line y = k * x + b with Gaussian noise.
 *
 * @param {number} numPoints - Number of points to generate.
 * @param {{k: number, b: number}} coeff - True slope `k` and intercept `b`.
 * @param {number} [sigma=0.04] - Standard deviation of the additive noise.
 * @returns {{xs: tf.Tensor1D, ys: tf.Tensor1D}} x coordinates drawn uniformly
 *   from [-1, 1) and the matching noisy y coordinates.
 */
function generateData(numPoints, coeff, sigma = 0.04) {
  // tf.tidy disposes every intermediate tensor created in the callback
  // (k, b, k * x, the noise) and keeps only the tensors that are returned.
  return tf.tidy(() => {
    const k = tf.scalar(coeff.k);
    const b = tf.scalar(coeff.b);

    const xs = tf.randomUniform([numPoints], -1, 1); // x coordinates
    const noise = tf.randomNormal([numPoints], 0, sigma);
    const ys = k.mul(xs).add(b).add(noise); // y = k * x + b + noise

    return { xs, ys };
  });
}

// Step 1. Trainable parameters of the model y = k * x + b. Each starts from a
// Math.random() value; k draws the first random number, b the second.
const [k, b] = [Math.random(), Math.random()].map((init) => tf.variable(tf.scalar(init)));

// Step 2. Training hyper-parameters and the optimizer (plain stochastic
// gradient descent).
const numIterations = 75; // number of optimisation steps
const learningRate = 0.5; // SGD step size
const optimizer = tf.train.sgd(learningRate);
/**
 * Step 3. Model: evaluate the line y = k * x + b using the current values of
 * the trainable variables `k` and `b`.
 *
 * @param {tf.Tensor} x - Input x coordinates.
 * @returns {tf.Tensor} Predicted y values (k is a scalar, so the result is
 *   broadcast to the shape of `x`).
 */
function predict(x) {
  // tf.tidy frees the intermediate k * x tensor; only the sum is kept.
  return tf.tidy(() => k.mul(x).add(b));
}
/**
 * Step 4. Loss: the mean squared error between predictions and labels.
 * The smaller the value, the closer the predictions are to the data.
 *
 * @param {tf.Tensor} prediction - Predicted y values.
 * @param {tf.Tensor} labels - Observed y values; expected to have the same
 *   shape as `prediction`.
 * @returns {tf.Scalar} Mean of the squared element-wise differences.
 */
function loss(prediction, labels) {
  return prediction.sub(labels).square().mean();
}
/**
 * Step 5. Fit `k` and `b` to the training data using the module-level optimizer.
 *
 * @param {tf.Tensor} xs - Training inputs.
 * @param {tf.Tensor} ys - Training targets.
 * @param {number} numIterations - Number of optimisation steps to run.
 * @returns {Promise<void>} Resolves once every step has run.
 */
async function train(xs, ys, numIterations) {
  for (let iter = 0; iter < numIterations; iter++) {
    // One optimisation step: minimize() differentiates the returned loss with
    // respect to the trainable variables (here k and b) and updates them.
    optimizer.minimize(() => {
      const pred = predict(xs); // model output for the training inputs
      return loss(pred, ys); // mean squared error against the targets
    });

    // Yield to the browser between steps so the page stays responsive.
    await tf.nextFrame();
  }
}
/**
 * Entry point: generate 100 noisy samples of y = 0.6 * x + 0.8, fit `k` and
 * `b` to them, log the learned coefficients and plot the result.
 *
 * @returns {Promise<void>} Resolves after training finishes and the chart is drawn.
 */
async function learnCoefficients() {
  const trueCoefficients = { k: 0.6, b: 0.8 }; // ground-truth line
  const trainingData = generateData(100, trueCoefficients); // data used for training
  await train(trainingData.xs, trainingData.ys, numIterations);

  // Copy the samples back from the backend for plotting, then free the
  // tensors; they are no longer needed once the values are on the JS side.
  const [xvals, yvals] = await Promise.all([trainingData.xs.data(), trainingData.ys.data()]);
  trainingData.xs.dispose();
  trainingData.ys.dispose();
  const sDatas = Array.from(yvals, (y, i) => [xvals[i], y]); // [[x, y], ...] for the scatter series

  const learnedK = k.dataSync()[0];
  const learnedB = b.dataSync()[0];
  console.log("k&b:", learnedK, learnedB); // coefficients after training
  showResult(sDatas, learnedK, learnedB, trueCoefficients);
}
/**
 * Draw the training samples, the ideal line and the fitted line with ECharts
 * into the element whose id is "container".
 *
 * @param {Array<Array<number>>} scatterData - Training samples as [x, y] pairs.
 * @param {number} k - Learned slope.
 * @param {number} b - Learned intercept.
 * @param {{k: number, b: number}} [trueCoeff={k: 0.6, b: 0.8}] - Coefficients of
 *   the ideal line; the default matches the values used by learnCoefficients.
 */
function showResult(scatterData, k, b, trueCoeff = { k: 0.6, b: 0.8 }) {
  const dom = document.getElementById("container");
  const myChart = echarts.init(dom);
  const realFun = (x) => trueCoeff.k * x + trueCoeff.b; // ideal line
  const factFun = (x) => k * x + b; // line after regression
  // Both are straight lines, so their end points over x in [-1, 1] suffice.
  const realData = [[-1, realFun(-1)], [1, realFun(1)]];
  const factData = [[-1, factFun(-1)], [1, factFun(1)]];

  // ECharts matches legend entries to series by exact (case-sensitive) name,
  // so both lists are built from the same constants.
  const names = { scatter: 'Discrete point', ideal: 'Ideal curve', fitted: 'Regression curve' };

  const option = {
    title: { text: 'Linear regression of data through machine learning', left: 'left' },
    tooltip: { trigger: 'axis', axisPointer: { type: 'cross' } },
    xAxis: { type: 'value', splitLine: { lineStyle: { type: 'dashed' } } },
    yAxis: { type: 'value', splitLine: { lineStyle: { type: 'dashed' } } },
    series: [
      {
        name: names.scatter,
        type: 'scatter',
        label: {
          emphasis: {
            show: true,
            textStyle: { fontSize: 16 },
          },
        },
        data: scatterData,
      },
      { name: names.ideal, type: 'line', showSymbol: false, data: realData },
      { name: names.fitted, type: 'line', showSymbol: false, data: factData },
    ],
    legend: { data: [names.scatter, names.ideal, names.fitted] }, // legend text
  };

  // notMerge = true: replace any previous option instead of merging into it.
  myChart.setOption(option, true);
}


learnCoefficients();
       </script>
   </body>
</html>

The result is shown below: the fitted regression line lies very close to the ideal line.


[1] Li D, Mei H, Shen Y, et al. ECharts: A declarative framework for rapid construction of web-based visualization[J]. Visual Informatics, 2018, 2(2): 136–146.

Source: "Machine learning based on tensorflow.js", Tencent Cloud + Community.