
Random forest

This example shows how to train a random forest and use it to predict whether our virtual passenger would have survived the Titanic disaster.

code

import {
  growRandomForest,
  statistics,
  impurity,
  getRandomForestPrediction,
  buildAlgorithmConfiguration,
  sampleDataSets
} from 'tree-garden';

// the Titanic data set is bundled with tree-garden
const { titanicSet } = sampleDataSets;


// let's tweak the configuration a bit
const config = buildAlgorithmConfiguration(titanicSet, {

  // I assume these attributes have no real impact on the final outcome, so they are
  // excluded to reduce computational complexity
  excludedAttributes: ['name', 'ticket', 'embarked', 'cabin'],
  attributes: {
    pclass: {
      dataType: 'discrete' // treat passenger class as a discrete category, not as a number
    }
  },

  // a few hundred trees is usually a good choice; 100 keeps this example quick
  numberOfTrees: 100,

  // [impurity scoring function]
  getScoreForSplit: impurity.getInformationGainForSplit

});

// check the resulting config
console.log('config:\n', config);


// our favorite Titanic passenger
const KateWinslet = {
  fare: 30.0458,
  name: 'Kate Winslet',
  embarked: 'C',
  age: 21,
  parch: 0,
  pclass: 3, // this time Kate travels third class, the low-cost option
  sex: 'female',
  ticket: '2687',
  sibsp: 1 // and she has one sibling or spouse aboard; either way, Leonardo will have a hard time...
};

// let's start with training...
const { trees, oobError } = growRandomForest(config, titanicSet);


// let's check some metrics of the trained forest:
// [out of the bag error]
console.log(`Out-of-bag error of our trained forest: ${oobError}`);

// How deep is your... forest?
const depths = trees.map((tree) => statistics.getTreeDepth(tree));
console.log(`Tree depths:\n\taverage: ${statistics.getArithmeticAverage(depths)}\n\tmedian: ${statistics.getMedian(depths)}`);

// and finally, what about Kate?
// [random forest prediction outcome]
const wouldSheSurvive = getRandomForestPrediction(KateWinslet, trees, config);
console.log('Would Kate survive the Titanic?', wouldSheSurvive);

comments

[impurity scoring function]
Information gain is cheaper to compute than information gain ratio, and we can use it here because we do not include attributes with a huge number of distinct values, such as name, ticket, or cabin (they are listed in excludedAttributes). Such attributes split the data into many small, nearly pure subsets, so their information gain is always high; information gain ratio penalizes them for exactly this reason.
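To make the difference concrete, here is a minimal sketch in plain JavaScript (not tree-garden's implementation) that computes both metrics for a multi-way split on a tiny, made-up set of six passengers. The helpers `entropy`, `splitLabels`, `informationGain` and `gainRatio`, as well as the data, are hypothetical. The ID-like `ticket` attribute gets the higher information gain, but `sex` gets the higher gain ratio, because gain ratio divides by the split information (the entropy of the partition sizes).

```javascript
// Shannon entropy (in bits) of an array of values.
function entropy(values) {
  const counts = new Map();
  for (const v of values) counts.set(v, (counts.get(v) ?? 0) + 1);
  let h = 0;
  for (const count of counts.values()) {
    const p = count / values.length;
    h -= p * Math.log2(p);
  }
  return h;
}

// Group target labels by the value of one attribute (a multi-way split).
function splitLabels(samples, attribute, target) {
  const groups = new Map();
  for (const s of samples) {
    if (!groups.has(s[attribute])) groups.set(s[attribute], []);
    groups.get(s[attribute]).push(s[target]);
  }
  return [...groups.values()];
}

// Information gain: parent entropy minus the size-weighted entropy of the children.
function informationGain(samples, attribute, target) {
  const parent = entropy(samples.map((s) => s[target]));
  const children = splitLabels(samples, attribute, target)
    .reduce((sum, group) => sum + (group.length / samples.length) * entropy(group), 0);
  return parent - children;
}

// Gain ratio: information gain divided by split information (entropy of the attribute values).
function gainRatio(samples, attribute, target) {
  const splitInfo = entropy(samples.map((s) => s[attribute]));
  return splitInfo === 0 ? 0 : informationGain(samples, attribute, target) / splitInfo;
}

// Hypothetical toy data: every ticket is unique, like an ID.
const passengers = [
  { ticket: 'A1', sex: 'female', survived: 1 },
  { ticket: 'A2', sex: 'female', survived: 1 },
  { ticket: 'A3', sex: 'female', survived: 0 },
  { ticket: 'A4', sex: 'male', survived: 0 },
  { ticket: 'A5', sex: 'male', survived: 0 },
  { ticket: 'A6', sex: 'male', survived: 0 },
];

for (const attr of ['ticket', 'sex']) {
  console.log(attr,
    'gain:', informationGain(passengers, attr, 'survived').toFixed(3),
    'ratio:', gainRatio(passengers, attr, 'survived').toFixed(3));
}
```

A unique-per-row attribute like `ticket` makes every child subset pure, so plain information gain rewards it even though it cannot generalize to new passengers.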

[out of the bag error]
Out-of-bag error is an error metric calculated during training: each tree is evaluated only on the samples that were not used to train it.

Out-of-bag error is computationally cheaper than external cross-validation and can be used for tuning parameters, measuring the effectiveness of data preprocessing, etc.
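The mechanics can be sketched as follows. This is a simplified illustration, not tree-garden's code, and it assumes bootstrap sampling (drawing with replacement) as in the classic random forest algorithm: each model is trained on its bootstrap sample, it predicts only the samples it never saw, and each sample's out-of-bag predictions are combined by majority vote. `trainMajorityClass` is a hypothetical stand-in for growing a real tree (it just predicts the most common label of its bag), and the returned value is the fraction of misclassified samples.

```javascript
// Most frequent value in an array (ties go to the value that reached the top count first).
function mostCommon(values) {
  const counts = new Map();
  let best;
  let bestCount = 0;
  for (const v of values) {
    const c = (counts.get(v) ?? 0) + 1;
    counts.set(v, c);
    if (c > bestCount) { best = v; bestCount = c; }
  }
  return best;
}

// Hypothetical stand-in for training one tree: always predicts the majority label of its bag.
function trainMajorityClass(samples, target) {
  const label = mostCommon(samples.map((s) => s[target]));
  return () => label;
}

// Train `numberOfTrees` models on bootstrap samples and return the out-of-bag error rate.
function outOfBagError(samples, target, numberOfTrees, train, random = Math.random) {
  const n = samples.length;
  const votes = samples.map(() => []); // OOB predictions collected per sample
  for (let t = 0; t < numberOfTrees; t++) {
    // Bootstrap: draw n indices with replacement.
    const inBag = new Set();
    const bag = [];
    for (let i = 0; i < n; i++) {
      const idx = Math.floor(random() * n);
      inBag.add(idx);
      bag.push(samples[idx]);
    }
    const predict = train(bag, target);
    // Evaluate this model only on the samples it never saw.
    for (let i = 0; i < n; i++) {
      if (!inBag.has(i)) votes[i].push(predict(samples[i]));
    }
  }
  // Majority-vote each sample's OOB predictions; skip samples that were never out of bag.
  let evaluated = 0;
  let wrong = 0;
  for (let i = 0; i < n; i++) {
    if (votes[i].length === 0) continue;
    evaluated++;
    if (mostCommon(votes[i]) !== samples[i][target]) wrong++;
  }
  return evaluated === 0 ? NaN : wrong / evaluated;
}

// Hypothetical data: every third sample is labelled 'yes'.
const data = Array.from({ length: 50 }, (_, i) => ({ x: i, label: i % 3 === 0 ? 'yes' : 'no' }));
console.log('OOB error:', outOfBagError(data, 'label', 25, trainMajorityClass));
```

No separate test set is held out here: every sample is used for training some models and for evaluating the others, which is why this is cheaper than running external cross-validation.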

[random forest prediction outcome]
We pass the array of tree roots into getRandomForestPrediction. Under the hood, the sample is passed to each tree, and a majority-voting function is applied to combine the individual results into the final one. See the majorityVoting, mergeClassificationResults, and mergeRegressionResults fields of the algorithm configuration.
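A minimal sketch of that merging step, assuming each tree can be treated as a plain function from a sample to a prediction. `mergeByMajorityVote`, `mergeByMean` and `predictWithForest` are hypothetical helpers for illustration only; they are not tree-garden's majorityVoting, mergeClassificationResults or mergeRegressionResults, whose exact signatures are not shown here.

```javascript
// Classification: each tree votes for a class and the most frequent class wins.
function mergeByMajorityVote(predictions) {
  const counts = new Map();
  for (const p of predictions) counts.set(p, (counts.get(p) ?? 0) + 1);
  let winner;
  let best = -1;
  for (const [label, count] of counts) {
    if (count > best) { winner = label; best = count; }
  }
  return winner;
}

// Regression: average the numeric outputs of all trees.
function mergeByMean(predictions) {
  return predictions.reduce((sum, p) => sum + p, 0) / predictions.length;
}

// Pass the sample to every tree, then merge the individual results.
function predictWithForest(sample, trees, merge) {
  return merge(trees.map((tree) => tree(sample)));
}

// Three hypothetical one-rule "trees".
const classifierTrees = [
  (s) => (s.sex === 'female' ? 'survived' : 'died'),
  (s) => (s.pclass === 3 ? 'died' : 'survived'),
  (s) => (s.age < 30 ? 'survived' : 'died'),
];
const kate = { sex: 'female', pclass: 3, age: 21 };

console.log(predictWithForest(kate, classifierTrees, mergeByMajorityVote)); // survived (2 votes to 1)
console.log(mergeByMean([10, 20, 30])); // 20
```

Keeping the merge function as a parameter mirrors the idea of configurable merging: the same forest code can serve classification or regression by swapping the function that combines tree outputs.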