diff --git a/data_science/ensemble/randomforest.ipynb b/data_science/ensemble/randomforest.ipynb new file mode 100755 index 0000000..662d955 --- /dev/null +++ b/data_science/ensemble/randomforest.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load MNIST dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "mnist = datasets.load_digits()\n", + "features, labels = mnist.data, mnist.target" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cross Validation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def cross_validation(classifier,features, labels):\n", + " cv_scores = []\n", + "\n", + " for i in range(10):\n", + " scores = cross_val_score(classifier, features, labels, cv=10, scoring='accuracy')\n", + " cv_scores.append(scores.mean())\n", + " \n", + " return cv_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "dt_cv_scores = cross_validation(tree.DecisionTreeClassifier(), features, labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "rf_cv_scores = cross_validation(RandomForestClassifier(), features, labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random Forest VS Decision Tree visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + 
"source": [ + "cv_list = [ \n", + " ['random_forest',rf_cv_scores],\n", + " ['decision_tree',dt_cv_scores],\n", + " ]\n", + "df = pd.DataFrame.from_items(cv_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xd8VFX+//HXSSchkEACKCEkNOmEEIp0RaXo2lnERQULi2VFf3Z3VWTX9l3WFVcW1wJ2sWBBBQuIIoJIIKG3EEISSholjZD2+f1xJ8kklISQMJPcz/PxyCMz956Z+cwQ3vfec889Y0QEpZRS9uDh6gKUUkqdOxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillIxr6SillI16uLqCqkJAQiYiIcHUZSinVoKxbty5TREKra+d2oR8REUFsbKyry1BKqQbFGLO3Ju20e0cppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWxEQ18ppWzE7cbpq7MjIqQePsb65MPkF5ZwXXQYPl66bVdKWTT0G7j8wmI2ph5lffJh4pKPEJd8hMzc4+XrP1iTzMsT+xIZEuDCKpVS7kJDvwEREZKy8lm/9zBxKVbIbz+YQ0mp9eX2HUICGNEllL7hQfQNDyLlUD6PLNzEFS//wt+v7sm10WEufgdKKVfT0HdjOQVF1l783sPEpRwhLvkwh/OLAGjq60VUuyDuHtmRvuHBRLULIjjAp9Lje5zfnN5hQdz3UTz/7+MN/LIrk5lX9SDQz9sVb0cp5QYaTeiXlgq7M3IJDfSleRNvjDGuLumMlJYKiZm5rN97hLiUw6zfe4Sd6TmItRNP51ZNuax7G8defDCdWjXF06P693h+UBM+vGMQc5Yn8NLSnaxPPszLN/SlT7ugen5HSil3ZKQsVdxETEyM1GbCtUN5hUT//QcAvD0NLQN8CQ30JaSpj+O39VN2OzTQl9CmvjRr4uWSDcTR/KLyLpq4lCPEJx8mu6AYgGZ+XvQNDyY6PJi+4UH0aRdE8yZnv3cem3SI6QviScsu4MHRFzB1WAc8arDhUEq5P2PMOhGJqbZdYwn9/MJiftiaRmZuIRk5x8nMtX4qbheW93078/H0IKSpDyGOjUBIU19CAn2s22XLHBuKZn6120CUlAq70nOsvfjkw6xPPszujDwAPAx0aR1IdPtg+rYLIrp9MJEtA+otjI/mF/HY5xtZvOkgQzuF8OIf+9CqmV+9vJZS6tyxXehXp7RUOHKsqNKGICPnOBm5x8nMKXT8tpZn5Z1iA+Hl4dgw+FQ6Yqh8FOGDv48XW/YfJS75COuTD7Mh5Qh5hSUAtAjwKQ/3vu2C6N0uiKa+57aXTUT4aG0KM77aQoCPF7PG9+Girq3OaQ1KqbqloX8WSkuFw/mF5RuEShuK8tvWEcWhvOOcZPsAgKeHodt5geXdNNHhwYS38Heb8w0J6Tnc80Ec2w/mcNvQSB4ecwG+Xp6uLkspVQsa+udISdkGwqlLKaegmAtaB9I7LIgmPu4dogVFJTy/
ZDtvrUqix/nNeHliXzqGNnV1WUqpM6Shr87I0q1pPPTpBgqKSnn6qh6M7xfmNkckSqnq1TT09fp8BcAl3VuzZPpwotoF8fCnG7l3QTzZBUWuLkspVcc09FW5Ns39eO/2gTw0+gIWbzrAuNm/sD75sKvLUkrVIQ19VYmnh+HuizrxybQLARj/6mrmLE846WgmdWYSM3KZsWgLK3dluroUZWPap69OKbugiMc/28TXGw8wuGNL/j0hitY6pv+MJaTn8sqPu1i0YT+lAl4ehlnj+3B137auLk01Itqnr85aMz9v/jOxL/93fW/iko8w5qUVLNuW5uqyGoyE9BymL4jj0n//zHdb0rh9WAeWPziS/hEtuO+jeN5cucfVJSob0j19VSO7M3L5ywdxbD2QzeTBETw6tit+3u49HNVVdqXl8PKPCXy9cT9NvD256cL23DGsAyFNfQFrmOz9H8WzZPNBpo3oyCNjLtCRUuqs6ZBNVeeOF5fwwpIdzPt1D13bBPLKjX3p1CrQ1WW5jR0Hc3j5x10s3nQAf29Pbh4cwR3DOtCiyuynYF3f8eSXm3l/TTLj+4Xx3LW98PLUA29Vexr6qt4s357Og59sIK+wmBl/6MGE/u1svae6/WA2Ly/bxeJNB2nq68Utg9tz+9AOJ0x1XZWIMHvZLl5auotRXVvxyo3Rbn8xn3JfGvqqXqVnF/D/Pt7AyoRMLu91Hs9e26tOZgJtSLbut8L+2y1W2E8ZEsFtQyMJ8j992Ff13m97eeLLzfQLD+aNW2LO+PFKgYa+OgdKS4XXfklk1nc7aN3Mj5cnRtGvfQtXl1XvNu87ysvLdvH91jQCfb2YMjSS24ZE0ty/9hu9xZsOcN+CeCJC/Hn71gGc17xJHVas7EBDX50z8SlHuPfDOPYdOcb0UZ25+6JONfqCl4Zm876jvLR0F0u3pRHo58WtQyK59SzD3tmq3ZlMfWcdzZt48/atA+jUqnHPgSQiHM4vOuk5D3XmNPTVOZVTUMQTX2zmi/j9DIxswUs3RDWavdWNqUeYvXQXy7an08zPi9uGdmDykIh66c7avO8ok+evpaS0lHmT+9M3PLjOX8MdHDh6jMc/28TyHRlc07ctD4+5oNH8vbhKnYa+MWYMMBvwBN4QkeerrG8PzANCgUPAJBFJNcZEAXOBZkAJ8IyIfHS619LQb9g+W5/K377YjI+XBy9c15vRPdq4uqRa25ByhNnLdvHj9nSaN/Hm9qGR3DIkgmb1/B3De7PyuOnN38nIOc7cSdGMvKDxfNeBiLBgbQrPfrON4lJhbM82fL3pAB4Gpg7vyLQRHfD3aTTf4npO1VnoG2M8gZ3ApUAqsBaYKCJbndp8AnwtIm8bYy4GpojITcaYLoCIyC5jzPnAOqCbiBw51etp6Dd8ezLzuPfDODbtO8pNg9rz18u7Nagx/XHJh5m9bBc/7cggyN+bO4Z14OYL25/TL5TPyDnO5Pm/s+NgTqO5ejflUD6PfraRXxOyuLBDS164rjfhLf1JOZTPC99u5+uNB2jdzJeHR3flmr5tbfdVnkePFZGWXUCX1rUbBl2XoX8hMENERjvuPwYgIs85tdkCjHbs3RvgqIg0O8lzbQCuF5Fdp3o9Df3GobC4lFnf7+C1FYkE+XvTpVUgkSEBdAgNcPxuSngLf3y83Gds+rq9Vtiv2JlBsL83dwzvwM0XRpzzbzYrk1NQxNR31rE6MYu/Xd6N24d1cEkdZ6u0VHhndRIvfLsDTw/D4+O6MXHAicN81+09xMyvt7Eh5Qi9w5rzxBXd6R/R+AcGZOUeZ96ve3hn1V7aBjdhyfRhtRoCXZehfz0wRkRud9y/CRgoIvc4tfkAWCMis40x1wILgRARyXJqMwB4G+ghIqWnej0N/cbl14RMvozfx57MPPZk5pGZW1i+zsNAuxb+dAgJIDKkKZGhAXQMCSAyNIA2zfzO2dj/2KRDzF62i192ZdIiwIepwztw06D2BLgo7J0VFJXw
/z6OZ/Gmg/x5RAceHdO1QV0TkZiRyyMLN7I26TAjuoTy7LW9aBt06r770lLhyw37eGHJDg5mFzCuVxseG9uNdi38z2HV58aBo8d4bUUiH/6ezPHiUsb1PI87R3akZ9vmtXq+moZ+Tf6qT/YXVnVL8SDwijFmMrAC2AcUOxVzHvAucMvJAt8YMxWYChAeHl6DklRDMaRTCEM6hZTfP3qsyLEByCUxI4/EzDz2ZOTxW+IhjhWVlLdr4u1JpGMDULYhiAxpSofQgDrrU/99zyFmL9vJrwlZhDT14fFxXZk0qL1b9Sn7eXvyn4nRtAjYzP9+TiQrt5DnG8DVuyWlwpsrE/nX9zvx9fJg1vg+XBfdttoNloeH4Zq+YYzpcR6vrUjk1Z93s3RrOlOGRnDPRZ3OaRdbfUnKzOPVn3ezcH0qpQJXR7XlzpEdz9lorTrp3qnSvimwXUTCHPebAT8Bz4nIJ9UVpHv69lRaKqTlFLAnI4/djg3BnsxcEjPzSDmUX+l7iEOa+tAhpGn5RqGDo9sovEVAjbqLfkvMYvbSXaxOtML+z8M78qdB4W4V9lU1pKt3d6bl8NCnG9mQcoRLu7fmmat70qqWs7MePFrAP7/bwcL1qbQM8OGByy5gQv92DXJI8I6DOcxZbs3J5OXpwYSYdkwd3qHOjmLqsnvHC+tE7iisPfi1wI0issWpTQhwSERKjTHPACUi8qQxxgdYAnwlIi/VpHANfVVVYXEpyYfySczIZU9mHokZVldRYmYembnHy9uVdRdFhgRYGwWnDUKbZn78lniIl5buZM2eQ4QG+vLn4R3408D2bhueJ1N29W50eDBvutnVu0Ulpfzv5928vCyBpn5ePH1lD67ofV6ddEdtTD3C37/eytqkw3RtE8jfLu/O0M4h1T/QDcSnHGHO8gR+2JpGgI8nkwa157ZhkbQKrNtpyut6yOY44CWsIZvzROQZY8xMIFZEFjn6/Z/D6vZZAdwtIseNMZOA+cAWp6ebLCLxp3otDX11Jo4eKyLJcb4gMcM6Mig7f5BfWNFd5OPlQWFxKa0CfZk2oiM3DgxvUCOKnC3ZdIDpbnb17pb9R3nok41sPZDNFb3P4+kre9DSMatoXRERlmw+yHNLtpFy6Bijurbi8cu70THU/S5iExFWJ2bx3+W7WZmQSfMm3kwZEsHkwRH1tqHWi7OUrYkIadnHSXScO9iTmUdES3/Gx7RrsGHvrOzq3WZ+Xrxz2wCXzXZ6vLiEV35MYO5Puwny9+EfV/dkTM/6vTajoKiEt1Yl8cqPCRQUlTBpUHvuu6SzWxz1iAg/bk9nzvIE1icfITTQlzuGRXLjwPb1PgpMQ1+pRq7s6t3i0lLmu+Dq3fiUIzz86QZ2puVyXXQYT1zR7ZwGb0bOcV78YScfrU0m0M+b+y7pzKRB7fF2wUnuklJhyeYDzFm+m20Hsmkb1IRpIzsyvl/YOdvJ0NBXygb2ZuVx87zfSc8+d1fvFhSV8O8fdvL6L4m0bubHs9f04qKurrtqePvBbP7x9TZWJmTSITSAv47rxsVdW52Toa2FxaV8Eb+PV3/aTWJmHh1CA7hrZCeuijr/nG98NPSVsgnnq3f/Ob431/QNq7fXWpt0iIc/3ciezDwmDgjnsXFd631aipoo61Z55pttJGbmMbRTCH+7ohtd25xwjWidKCgq4aO1Kfzv593sP1pAj/ObcfdFnRjdo43LRhZp6CtlIzkFRfz53XWs2l0/V+/mHS/mn9/t4O3VSbQNasIL1/WudP2FuygqKeW93/by0tJd5BQUMaF/OA9c1qX8qyrPVk5BEe/9lsybKxPJzC0kpn0wd1/ciZFdQl1+0ZyGvlI2c7zY+u7dur5699eETB5ZuJHUw8eYPDiCh0Zf4BZXK5/OkfxCXlq6i/d+24uftyf3XNyJKUMi8PWqXf/64bxC5v+6h7dWJZFdUMywziHcc1EnBkS2cHnYl9HQV8qGSkqFGYu28O5ve7kuOoznr+tV
677l7IIinlu8nQ9/TyYyJIAXruvNgMiGNRfO7oxcnv1mG8u2p9OuRRMeG9uNsT3b1Dio07ILeH1FIh/8nkx+YQmje7TmrpGd6NMuqJ4rP3Ma+krZlIjw8rIE/r10Jxd3bcWcWly9u3xHOo9/tom07ALuGNaB+y/t0qCHuq7clcnfv97KjrQcBkS04IkrutMr7NRz3KQcyufVn3fzSWwqJSJc2ed87hzZsdYzYJ4LGvpK2dz7a/byxBeb6XsGV+8eyS9k5tdb+Wz9Pjq3aso/x/chyg33amujuKSUj2JTePH7nWTlFXJtdFseHt2VNs0rrozdlZbD3J928+WG/Xgaw3X9wrhzREfCW7r/hG8a+kopvt18gHs/jKd9S3/eue30V+9+t+Ugf/tiM4fyCrlrZEfuubhTrfvA3Vl2QRFzlicwf2USnh6GP4/owLDOIby+Yg/fbT2In5cnNw4M545hHSptENydhr5SCoDVu7OY+k4sgae4ejcr9zhPLdrC1xsP0P28Zvzf9b1rPb1vQ5Kclc/z325j8aaDAAT6eTF5cARThkQ2yO/t1dBXSpXbsv8ot8yrfPWuiPDVxgPMWLSFnIIipo/qzJ9HdHTJFa2uFJt0iJ1puVzR5zy3uOagtjT0lVKVJGflc9O8NaRnH+fZa3uyeNNBftiaRp+w5vxzfB+3Pkmpqqehr5Q6QdnVu1v2Z+Pr5cEDl3Xh1iGRbv+lLKp6dfnNWUqpRiI00JcFUwfx9qokxvU6jw5uOC2xql8a+krZTKCfN/dc3NnVZSgX0WM6pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSykRqFvjFmjDFmhzEmwRjz6EnWtzfGLDPGbDTG/GSMCXNad4sxZpfj55a6LF4ppdSZqTb0jTGewBxgLNAdmGiM6V6l2SzgHRHpDcwEnnM8tgXwFDAQGAA8ZYwJrrvylVJKnYma7OkPABJEJFFECoEFwFVV2nQHljluL3daPxr4QUQOichh4AdgzNmXrZRSqjZqEvptgRSn+6mOZc42ANc5bl8DBBpjWtbwsUoppc6RmoS+Ocmyqt+m/iAwwhgTB4wA9gHFNXwsxpipxphYY0xsRkZGDUpSSilVGzUJ/VSgndP9MGC/cwMR2S8i14pIX+CvjmVHa/JYR9vXRCRGRGJCQ0PP8C0opZSqqZqE/lqgszEm0hjjA9wALHJuYIwJMcaUPddjwDzH7e+Ay4wxwY4TuJc5limllHKBakNfRIqBe7DCehvwsYhsMcbMNMZc6Wg2EthhjNkJtAaecTz2EPB3rA3HWmCmY5lSSikXMCIndLG7VExMjMTGxrq6DKWUalCMMetEJKa6dnpFrlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiGvlJK2YiXqwtQSrleUVERqampFBQUuLoUVQ0/Pz/CwsLw9vau1eM19JVSpKamEhgYSEREBMac7FtOlTsQEbKyskhNTSUyMrJWz6HdO0opCgoKaNmypQa+mzPG0LJly7M6ItPQV0oBaOA3EGf776Shr5RSNqKhr5RqdCIiIsjMzKyX5z5+/DiXXHIJUVFRfPTRR/XyGvHx8SxevLhenltP5Cql3IqIICJ4eLjnPmlcXBxFRUXEx8fX+DElJSV4enrWuH18fDyxsbGMGzeuNiWeloa+UqqSp7/awtb92XX6nN3Pb8ZTf+hxyvVJSUmMHTuWiy66iNWrVxMVFcWmTZs4duwY119/PU8//TRg7cHfcsstfPXVVxQVFfHJJ5/QtWtXsrKymDhxIhkZGQwYMAARKX/uF198kXnz5gFw++23c99995GUlMSYMWMYOnQov/32G3369GHKlCk89dRTpKen8/777zNgwIAT
6kxPT2fSpElkZGQQFRXFwoULSUpK4sEHH6S4uJj+/fszd+5cfH19iYiI4NZbb+X777/nnnvuoX///tx9991kZGTg7+/P66+/TteuXfnkk094+umn8fT0pHnz5ixdupQnn3ySY8eOsXLlSh577DEmTJhQZ/8W7rkpVUrZzo4dO7j55puJi4vjX//6F7GxsWzcuJGff/6ZjRs3lrcLCQlh/fr13HnnncyaNQuAp59+mqFDhxIXF8eVV15JcnIyAOvWrWP+/PmsWbOG3377jddff524uDgAEhISmD59Ohs3bmT79u188MEHrFy5klmzZvHss8+etMZWrVrxxhtvMGzYMOLj42nbti2TJ0/mo48+YtOmTRQXFzN37tzy9n5+fqxcuZIbbriBqVOn8p///Id169Yxa9Ys7rrrLgBmzpzJd999x4YNG1i0aBE+Pj7MnDmTCRMmEB8fX6eBD7qnr5Sq4nR75PWpffv2DBo0CICPP/6Y1157jeLiYg4cOMDWrVvp3bs3ANdeey0A/fr147PPPgNgxYoV5bcvv/xygoODAVi5ciXXXHMNAQEB5Y/95ZdfuPLKK4mMjKRXr14A9OjRg1GjRmGMoVevXiQlJdWo5h07dhAZGUmXLl0AuOWWW5gzZw733XcfQHlg5+bmsmrVKsaPH1/+2OPHjwMwZMgQJk+ezB//+Mfy91afNPSVUm6hLJj37NnDrFmzWLt2LcHBwUyePLnSuHRfX18APD09KS4uLl9+sqGMzt08VZU9D4CHh0f5fQ8Pj0rPezqne36oeE+lpaUEBQWd9DzAq6++ypo1a/jmm2+Iioo6o3MFtaHdO0opt5KdnU1AQADNmzcnLS2NJUuWVPuY4cOH8/777wOwZMkSDh8+XL78iy++ID8/n7y8PD7//HOGDRtWZ7V27dqVpKQkEhISAHj33XcZMWLECe2aNWtGZGQkn3zyCWBtLDZs2ADA7t27GThwIDNnziQkJISUlBQCAwPJycmpszqdaegrpdxKnz596Nu3Lz169ODWW29lyJAh1T7mqaeeYsWKFURHR/P9998THh4OQHR0NJMnT2bAgAEMHDiQ22+/nb59+9ZZrX5+fsyfP5/x48fTq1cvPDw8mDZt2knbvv/++7z55pv06dOHHj168OWXXwLw0EMP0atXL3r27Mnw4cPp06cPF110EVu3bq2XYaGmusOTcy0mJkZiY2NdXYZStrJt2za6devm6jJUDZ3s38sYs05EYqp7rO7pK6WUjeiJXKWUOon58+cze/bsSsuGDBnCnDlzXFRR3ahR6BtjxgCzAU/gDRF5vsr6cOBtIMjR5lERWWyM8QbeAKIdr/WOiDxXh/UrpVS9mDJlClOmTHF1GXWu2u4dY4wnMAcYC3QHJhpjuldp9jfgYxHpC9wA/NexfDzgKyK9gH7An40xEXVTulJKqTNVkz79AUCCiCSKSCGwALiqShsBmjluNwf2Oy0PMMZ4AU2AQqBur+9WSilVYzUJ/bZAitP9VMcyZzOAScaYVGAx8BfH8k+BPOAAkAzMEpFDZ1OwUkqp2qtJ6J9sxv6q4zwnAm+JSBgwDnjXGOOBdZRQApwPRAIPGGM6nPACxkw1xsQaY2IzMjLO6A0opZSquZqEfirQzul+GBXdN2VuAz4GEJHVgB8QAtwIfCsiRSKSDvwKnDCOVEReE5EYEYkJDQ0983ehlGpUZsyYUT6Z2pkYPHjwadePGzeOI0eO1LasE7z11lvs3181Dt1bTUJ/LdDZGBNpjPHBOlG7qEqbZGAUgDGmG1boZziWX2wsAcAgYHtdFa+UUs5WrVp12vWLFy8mKCiozl7vdKFfUlJSZ69Tl6odsikixcaYe4DvsIZjzhORLcaYmUCsiCwCHgBeN8bcj9X1M1lExBgzB5gPbMbqJpovIhtP/kpKKbew5FE4uKlun7NNLxj7/GmbPPPMM7zzzju0a9eO0NBQ+vXrx+7du086B31aWhrTpk0jMTERgLlz5zJ4
8GCaNm1Kbm4uBw4cYMKECWRnZ5dPdzxs2DAiIiKIjY0lJCTklPPsjx07lqFDh7Jq1Sratm3Ll19+SZMmTU6o99NPPyU2NpY//elPNGnShNWrV9OtW7cazaGfkZHBtGnTyqeAfumll2o03URdqNE4fRFZjHWC1nnZk063twInVCwiuVjDNpVS6pTWrVvHggULiIuLo7i4mOjoaPr168fUqVN59dVX6dy5M2vWrOGuu+7ixx9/5N5772XEiBF8/vnnlJSUkJubW+n5PvjgA0aPHs1f//pXSkpKyM/PP+H1yubZFxEGDhzIiBEjCA4OZteuXXz44Ye8/vrr/PGPf2ThwoVMmjTphJqvv/56XnnlFWbNmkVMTEWvddkc+gCjRo06af3Tp0/n/vvvZ+jQoSQnJzN69Gi2bdtWD5/sifSKXKVUZdXskdeHX375hWuuuQZ/f38ArrzySgoKCk45B/2PP/7IO++8A1D+jVPO+vfvz6233kpRURFXX301UVFRldZXN89+Wft+/frVeG79MjWZQ3/p0qVs3bq1fHl2djY5OTkEBgae0WvVhoa+UsotVJ0P/3Rz0Fdn+PDhrFixgm+++YabbrqJhx56iJtvvrl8fU3n2ff09OTYsWNn9No1mUO/tLSU1atXn7TbqL7phGtKKZcbPnw4n3/+OceOHSMnJ4evvvoKf3//U85BP2rUqPKvJSwpKSE7u/I1n3v37qVVq1bccccd3Hbbbaxfv/6E16uLefZPN+/96ebQv+yyy3jllVfK29b3F6c409BXSrlcdHQ0EyZMICoqiuuuu648gE81B/3s2bNZvnw5vXr1ol+/fmzZsqXS8/30009ERUXRt29fFi5cyPTp0094vbqYZ3/y5MlMmzaNqKiokx4RnKr+l19+mdjYWHr37k337t159dVXz/i1a0vn01dK6Xz6DYzOp6+UUqpG9ESuUkpV4+677+bXX3+ttGz69OkNcuplDX2lFGCdaKw6gkZZ3OmLU862S167d5RS+Pn5kZWVddaBouqXiJCVlYWfn1+tn0P39JVShIWFkZqais5y6/78/PwICwur9eM19JVSeHt7ExkZ6eoy1Dmg3TtKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjGvpKKWUjNQp9Y8wYY8wOY0yCMebRk6wPN8YsN8bEGWM2GmPGOa3rbYxZbYzZYozZZIyp/Tf6KqWUOivVfkeuMcYTmANcCqQCa40xi0Rkq1OzvwEfi8hcY0x3YDEQYYzxAt4DbhKRDcaYlkBRnb8LpZRSNVKTPf0BQIKIJIpIIbAAuKpKGwGaOW43B/Y7bl8GbBSRDQAikiUiJWdftlJKqdqoSei3BVKc7qc6ljmbAUwyxqRi7eX/xbG8CyDGmO+MMeuNMQ+fZb1KKaXOQk1C35xkmVS5PxF4S0TCgHHAu8YYD6zuo6G53EEhAAAOXklEQVTAnxy/rzHGjDrhBYyZaoyJNcbEZmRknNEbUEopVXM1Cf1UoJ3T/TAqum/K3AZ8DCAiqwE/IMTx2J9FJFNE8rGOAqKrvoCIvCYiMSISExoaeubvQimlVI3UJPTXAp2NMZHGGB/gBmBRlTbJwCgAY0w3rNDPAL4Dehtj/B0ndUcAW1FKKeUS1Y7eEZFiY8w9WAHuCcwTkS3GmJlArIgsAh4AXjfG3I/V9TNZRAQ4bIx5EWvDIcBiEfmmvt6MUkqp0zNWNruPmJgYiY2NdXUZSinVoBhj1olITHXt9IpcpZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19
pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ19pZSyEQ39ulZSDEm/QvZ+V1eilFIn8KpJI2PMGGA24Am8ISLPV1kfDrwNBDnaPCoii6us3wrMEJFZdVS7+zm4Gb68Gw7EW/ebt4Ow/tBugPXTuhd4+bi2RqWUrVUb+sYYT2AOcCmQCqw1xiwSka1Ozf4GfCwic40x3YHFQITT+n8DS+qsandTXAgrX4QVs8CvOfzhZSjMg9TfIeV32PKZ1c7LD87va20AwhwbgqatXFu7UspWarKnPwBIEJFEAGPMAuAqrD33MgI0c9xuDpT3bRhjrgYSgby6KNjt7I+DL++BtM3QazyMeQECWjpW3mX9OrrPsQFYCylrYPV/oXS2tS44omIDENYfWvcEzxodgLmXkmI4shcyd0HmTusnK8H67RcEl/0dLhgHxri6UqVsrSbp0hZIcbqfCgys0mYG8L0x5i9AAHAJgDEmAHgE6yjhwVO9gDFmKjAVIDw8vIalu1hRAfz8Avw6GwJC4YYPoeu4k7dt3haaXwM9rql47IEN1gYg9XfY8zNs+tha5+0Pbfs5uoUGWr/LNyJuoCAbsnY5hbvj9qHdUFJY0S4gFEK6QNcrrKOdBTdCp0th7AvQsqPr6lfK5moS+ifbNZMq9ycCb4nIv4wxFwLvGmN6Ak8D/xaRXHOaPTwReQ14DSAmJqbqc7uf1Fj44i7I3AFRk2D0P6BJcM0f7+0H4QOtHwAROJIMqWutgExZY21MpMRa36KjtQFo1986KmjVDTw86/59lSkthex9lffWywI+50BFO+MJLTpY4d7lMut3SBdo2Qn8W1S0KymC31+Hn56D/w6CwX+BYQ+AT0D9vQel1EkZkdNnrCPEZ4jIaMf9xwBE5DmnNluAMSKS4rifCAwCFgLtHM2CgFLgSRF55VSvFxMTI7GxsbV+Q/Wq6Bj8+A/47b8QeD78YTZ0vqR+Xqsw3+o6SllTsTHIz7TW+QRCWD9Ht9BA6/aZbHTKFB2DrN1Oe+xO3TJF+RXtfJtDSGdHqHeuCPfgiDM7MZ2TBj88CRsXQLMwGP0MdL9Ku3yUqgPGmHUiElNtuxqEvhewExgF7APWAjeKyBanNkuAj0TkLWNMN2AZ0FacntwYMwPIrW70jtuG/t5VVt/9od3QbwpcOhP8mlX/uLoiAocSnY4Gfof0LSCl1vqQCypGCYUNsELZw8N6XF5G5b31sttHUqg4aDMQ1K4i0J3DPSC0boN572pY/BCkbYIOI2HsPyG0S909v1I2VGeh73iyccBLWMMx54nIM8aYmUCsiCxyjNh5HWiKlSIPi8j3VZ5jBg0x9I/nwrKZ8PtrEBQOV/4HOoxwdVWW4zmwb721AUj93dogHDtsrfNrbu2JH06CgqMVj/H2t7pfKoV7Z6sLycf/3NVeUgyx86wjp6I8GHQXjHgYfAPPXQ1KNSJ1GvrnkluFfuLPsOgv1qiUAX+GUU+Cb1NXV3VqIlbXTMoaa0NwNAWCIyvvuTdrax0BuIvcDFg2A+Leg8Dz4LJ/QM/rtMtHqTOkoX82CrLhhydg3VvWHvBVr0D7wa6tqbFLWQuLH7BGNUUMg7H/B627u7oq1ViJwNFUa6j1wc2Qlw7+IRAQYnVnBoRa19AEhIBvswaxE1LT0G+AA8Lr2a6l8NV0yNlvjTIZ+fi57fawq3b94Y7lsP5tqzvt1aEwcBqMfMTqqlKqtoqOQfo2SNtSEfJpm6HgSEUb3+Zw/OjJH+/p69gQOG8QQituB4RAQKuK257e5+Z91ZLu6Zc5dhi++yvEv2+dFL36vxBW7UZT1Yf8Q7DsaVj3tvUf6bK/Q+8JDWJvS7mQCOQcdAT7poqAz0qoGP7sHWAdQbbuCW16WlOjtO5unUsqLoT8LGvgQ9WfXOf7mdaRgfN1Kc6aBJ96g9C0VeV1dXgUod07Z2L7Yvj6fusfdOh9MOIR8PI9tzWo
E+1bD4sfhH3rIPxCGPdPaNPL1VUpd1BcCBnbK++5p222QrtM83BHsPes+B0cWTfntETgeLa1AchNP3GDUHa7bJ3zUYWzqkcR5/WBUU/UqiTt3qmJvCz49hHY9In1B3HjR3B+lKurUmXaRsNtSyH+PVg6A/43HPrfARc9Dk2CXF2dOldyM6zhvWXhfnCzdWFkabG13svPumDxgnHWTkHrntC6R/3+jRhjdTv6Na/ZFeblRxFOG4S8DMdGwXE7L90aEl7P7Lunv+ULay/y2BEY/hAMvV9nwHRnxw7Dj89A7JvQpAVc+jT0udG9RiKps1NSZF1Hkra58h58blpFm8Dzq+y997KuCm+I81XVMe3eOZXcdCvst34J50VZffete9Tf66m6dWCj9e+Xssaal2jcP62ZS1XDIWIFefo2xwlWRx98xvaKfnJPHwi9wAr1spBv3dO95qFyMxr6VYnApk9hycPWtMcjH4XB9+oeQkMkAhsWWFM65GVAzBS4+InK8/0o95CbARnbIH07pG+1gj19W+U+7oBWTnvvju6ZkM5uPwrG3WifvrPsA9aJ2p1LrL3Dq+ZYexGqYTIGoiZas5r+9Dys+Z/VXTfqSYi+uX4no1Mnl3/IsedeFuzbrbB3PrHq1xxadbdmm23VDUK7Wr/1OyXOqca9py9iDcH89nEoOW7tDQ66U0OhsUnbYs3ls/dXq6tn3L+sSehU3Tt2uCLQnX/npVe08W3mCPSuENqt4ndgGx12W490T/9IinWR1e5l0H6INWeOzuPeOLXuAZO/gc0LrWst3hgF0TfBqKesoXDqzBVkV3TFOP92nlrbp6l1xNz5ssoB36ythrsba3yhLwLr5sP3T1ozUI6bBTG36SiPxs4Y6HU9dBltfbnNb3Otk/UXPwExt+rR3akcz4WMHY49dqeAz95X0cariRXuHUZWdMm06mZNj63/rxqcxtW9c2gPfHUv7Flh/YH+4WUIbl+X5amGImOH1eWz52fr5OC4f1V8ac3ZEHH8lFg7FaWO31LqWCZVljutL3WsLy12+imB0qITl5UUVWlT7NSupGJdSVHl+6VV7xdbM5pWff7iY5CZAEeTK96bp681xXVot4pgD+0KQe013BsA+43eydxlXbzj4WXN1Bh9sx5i2p0IbP3C6vLJ3mdNKV0WytUGdGmVMHe0O+FL49yIh7c14sXDyzqy8fCyljnf9/S2bnv6WFNvt+pWEfLBEXpE1IDZr0+/ZSfrAquoG6F5mKurUe7AGGukSOfLYNUrVheG8XD8eFq/PTxOsszTaZnHSZad5rHlbc1JljndLg9jRyB7VrlfdX15YJ+qje6Jq5ppPKFvjPUlHEpV5RNgzdaplEJ3D5RSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykY09JVSykbcbhoGY0wGsPcsniIEyKyjcho6/Swq08+jMv08KjSGz6K9iIRW18jtQv9sGWNiazL/hB3oZ1GZfh6V6edRwU6fhXbvKKWUjWjoK6WUjTTG0H/N1QW4Ef0sKtPPozL9PCrY5rNodH36SimlTq0x7ukrpZQ6hUYT+saYMcaYHcaYBGPMo66ux5WMMe2MMcuNMduMMVuMMdNdXZOrGWM8jTFxxpivXV2Lqxljgowxnxpjtjv+Ri50dU2uZIy53/H/ZLMx5kNjjJ+ra6pPjSL0jTGewBxgLNAdmGiM6e7aqlyqGHhARLoBg4C7bf55AEwHtrm6CDcxG/hWRLoCfbDx52KMaQvcC8SISE/AE7jBtVXVr0YR+sAAIEFEEkWkEFgAXOXimlxGRA6IyHrH7Rys/9RtXVuV6xhjwoDLgTdcXYurGWOaAcOBNwFEpFBEjri2KpfzApoYY7wAf2C/i+upV40l9NsCKU73U7FxyDkzxkQAfYE1rq3EpV4CHgZKXV2IG+gAZADzHd1dbxhjAlxdlKuIyD5gFpAMHACOisj3rq2qfjWW0DcnWWb7YUnGmKbA
QuA+Ecl2dT2uYIy5AkgXkXWursVNeAHRwFwR6QvkAbY9B2aMCcbqFYgEzgcCjDGTXFtV/WosoZ8KtHO6H0YjP0SrjjHGGyvw3xeRz1xdjwsNAa40xiRhdftdbIx5z7UluVQqkCoiZUd+n2JtBOzqEmCPiGSISBHwGTDYxTXVq8YS+muBzsaYSGOMD9aJmEUurslljDEGq892m4i86Op6XElEHhORMBGJwPq7+FFEGvWe3OmIyEEgxRhzgWPRKGCrC0tytWRgkDHG3/H/ZhSN/MS2l6sLqAsiUmyMuQf4Duvs+zwR2eLislxpCHATsMkYE+9Y9riILHZhTcp9/AV437GDlAhMcXE9LiMia4wxnwLrsUa9xdHIr87VK3KVUspGGkv3jlJKqRrQ0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRvR0FdKKRv5/5Yq8G830UTCAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df.plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision Tree Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8343173330831328" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(dt_cv_scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Random Forest Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9223850187122359" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(rf_cv_scores)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_science/ensemble/voting.ipynb b/data_science/ensemble/voting.ipynb new file mode 100755 index 0000000..e42e8d3 --- /dev/null +++ b/data_science/ensemble/voting.ipynb @@ -0,0 
+1,279 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Voting\n", + "Voting is based on the idea that classifiers can complement each other: \n", + "it aggregates the individual classifiers' predictions to make a better prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import datasets\n", + "from sklearn import tree\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.svm import SVC\n", + "from sklearn.ensemble import VotingClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# load mnist dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "mnist = datasets.load_digits()\n", + "features, labels = mnist.data, mnist.target\n", + "X_train,X_test,y_train,y_test=train_test_split(features,labels,test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# single classifier accuracy on mnist\n", + "Build a decision tree, kNN and SVM, and check each one's accuracy on the MNIST data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "dtree = tree.DecisionTreeClassifier(\n", + " criterion=\"gini\", max_depth=8, max_features=32,random_state=35)\n", + "\n", + "dtree = dtree.fit(X_train, y_train)\n", + "dtree_predicted = dtree.predict(X_test)\n", + "\n", + "knn = KNeighborsClassifier(n_neighbors=299).fit(X_train, y_train)\n", + "knn_predicted = knn.predict(X_test)\n", + "\n", + "svm = SVC(C=0.1, gamma=0.003,\n", + " probability=True,random_state=35).fit(X_train, y_train)\n", + "svm_predicted = svm.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[accuarcy]\n", + "d-tree: 0.7972222222222223\n", + "knn : 0.8416666666666667\n", + "svm : 0.85\n" + ] + } + ], + "source": [ + "print(\"[accuarcy]\")\n", + "print(\"d-tree: \",accuracy_score(y_test, dtree_predicted))\n", + "print(\"knn : \",accuracy_score(y_test, knn_predicted))\n", + "print(\"svm : \",accuracy_score(y_test, svm_predicted))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can easily do soft voting or hard voting using sklearn's VotingClassifier. \n", + "If you want to implement soft voting from scratch, you can use predict_proba as shown below. \n", + "Below is an example of the SVM's predicted probabilities (digits 0 to 9) for two MNIST samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[9.95557918e-01 3.42018637e-04 4.57700824e-04 4.19160266e-04\n", + " 4.21146304e-04 7.99436984e-04 4.11439277e-04 6.08753549e-04\n", + " 4.33211441e-04 5.49214707e-04]\n", + " [2.86586264e-03 4.17512273e-03 4.28013091e-03 4.14650212e-03\n", + " 9.27814553e-01 2.24791840e-02 3.06764221e-03 9.50855980e-03\n", + " 1.51437526e-02 6.51868962e-03]]\n" + ] + } + ], + "source": [ + "svm_proba = svm.predict_proba(X_test)\n", + "print(svm_proba[0:2])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# hard voting\n", + "Hard voting is simply a majority vote: it collects each classifier's prediction and takes the most-voted class." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", + " if diff:\n" + ] + }, + { + "data": { + "text/plain": [ + "0.9083333333333333" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voting_clf = VotingClassifier(estimators=[\n", + " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", + " weights=[1,1,1], voting='hard').fit(X_train, y_train)\n", + "hard_voting_predicted = voting_clf.predict(X_test)\n", + "accuracy_score(y_test, hard_voting_predicted)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# soft voting\n", + "Soft voting takes each classifier's predict_proba output, sums the probabilities per class, and predicts the class with the highest total probability." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda3/envs/wikiml/lib/python3.6/site-packages/sklearn/preprocessing/label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.\n", + " if diff:\n" + ] + }, + { + "data": { + "text/plain": [ + "0.9138888888888889" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "voting_clf = VotingClassifier(estimators=[\n", + " ('decision_tree', dtree), ('knn', knn), ('svm', svm)], \n", + " weights=[1,1,1], voting='soft').fit(X_train, y_train)\n", + "soft_voting_predicted = voting_clf.predict(X_test)\n", + "accuracy_score(y_test, soft_voting_predicted)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualization\n", + "We can visualize the accuracies to check whether the voting result is more stable or better than the single-model accuracy. \n", + "It is hard to say which voting scheme is better, but we can confirm that the classifiers complement each other, \n", + "and the voting result is better in this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEepJREFUeJzt3XvQHXV9x/H3h2BEES8lqVUghiqoqVaoGbwgikpbwAo4oEK1LQ6V6QVtvc3QwTIWrVXROrViK7SKYpWLiqYYDZWKUK2YIBdJMDQTUFLaMSpSURGRb//YjZwcT/Kc58l58iQ/3q+ZzLOX39n97e5vP2fP75zdpKqQJLVll7mugCRp8gx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoN2nasVL1iwoBYvXjxXq5ekndLVV1/9napaOFW5OQv3xYsXs2rVqrlavSTtlJJ8c5xydstIUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KD5uwOVUmajsWnfmauqzAxt7ztBbO+DsNd2om0EnDbI9zu7+yWkaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDfHCYdiqtPDgLfHiWZpdX7pLUIK/cd0KtXL165SrNHq/cJalBhrskNchwl6QGGe6S1CDDXZIaNFa4Jzk8ydok65KcOmL+oiRfSHJNkuuTHDn5qkqSxjVluCeZB5wFHAEsAU5IsmSo2BuBC6vqQOB44H2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJkCHtoPPwy4bXJVlCRN1zg3Me0F3DowvgF42lCZNwGXJnkVsDtw2ERqJ0makXHCPSOm1dD4CcC5VfWuJM8AzkvypKq6d7MFJScDJwMsWrRoJvUF2rlDE7xLU9LsGKdbZgOwz8D43vxit8tJwIUAVfWfwG7AguEFVdXZVbW0qpYuXLhwZjWWJE1pnHBfCeyXZN8k8+m+MF02VOZbwPMBkjyRLtw3TrKikqTxTRnuVXUPcAqwAriR7lcxq5OckeSovtjrgFcmuQ74GHBiVQ133UiStpOxngpZVcuB5UPTTh8YXgMcPNmqSZJmyjtUJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgscI9yeFJ1iZZl+TULZR5SZI1SVYn+ehkqylJmo5dpyqQZB5wFvCbwAZgZZJlVbVmoMx+wF8AB1fV7Ul+ebYqLEma2jhX7gcB66pqfVXdDZwPHD1U5pXAWVV1O0BVfXuy1ZQkTcc44b4XcOvA+IZ+2qD9gf2TfCnJV5IcPqkKSpKmb8puGSAjptWI5ewHHArsDVyZ5ElV9f3NFpScDJwMsGjRomlXVpI0nnGu3DcA+wyM7w3cNqLMp6vqp1V1M7CWLuw3U1VnV9XSqlq6cOHCmdZZkjSFccJ9JbBfkn2TzAeOB5YNlfkU8FyAJAvoumnWT7KikqTxTRnuVXUPcAqwArgRuLCqVic5I8lRfbEVwHeTrAG+ALyhqr47W5WWJG3dOH3uVNVyYPnQtNMHhgt4bf9PkjTHvENVkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQY
a7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqVspd1ySSrJ0clWUJE3XlOGeZB5wFnAEsAQ4IcmSEeX2AF4NXDXpSkqSpmecK/eDgHVVtb6q7gbOB44eUe7NwDuAuyZYP0nSDIwT7nsBtw6Mb+in/VySA4F9quqSrS0oyclJViVZtXHjxmlXVpI0nnHCPSOm1c9nJrsA7wZeN9WCqursqlpaVUsXLlw4fi0lSdMyTrhvAPYZGN8buG1gfA/gScDlSW4Bng4s80tVSZo744T7SmC/JPsmmQ8cDyzbNLOq7qiqBVW1uKoWA18BjqqqVbNSY0nSlKYM96q6BzgFWAHcCFxYVauTnJHkqNmuoCRp+nYdp1BVLQeWD007fQtlD932akmStoV3qEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQWOFe5LDk6xNsi7JqSPmvzbJmiTXJ7ksyWMmX1VJ0rimDPck84CzgCOAJcAJSZYMFbsGWFpVvw58HHjHpCsqSRrfOFfuBwHrqmp9Vd0NnA8cPVigqr5QVT/qR78C7D3ZakqSpmOccN8LuHVgfEM/bUtOAj47akaSk5OsSrJq48aN49dSkjQt44R7RkyrkQWTlwNLgTNHza+qs6tqaVUtXbhw4fi1lCRNy65jlNkA7DMwvjdw23ChJIcBpwHPqaqfTKZ6kqSZGOfKfSWwX5J9k8wHjgeWDRZIciDwfuCoqvr25KspSZqOKcO9qu4BTgFWADcCF1bV6iRnJDmqL3Ym8BDgoiTXJlm2hcVJkraDcbplqKrlwPKhaacPDB824XpJkraBd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUFjhXuSw5OsTbIuyakj5j8wyQX9/KuSLJ50RSVJ45sy3JPMA84CjgCWACckWTJU7CTg9qp6HPBu4O2TrqgkaXzjXLkfBKyrqvVVdTdwPnD0UJmjgQ/1wx8Hnp8kk6umJGk6xgn3vYBbB8Y39NNGlqmqe4A7gD0nUUFJ0vTtOkaZUVfgNYMyJDkZOLkfvTPJ2jHWP5cWAN+ZzRVkx+3Acttn2f15++/P2w7bvP2PGafQOOG+AdhnYHxv4LYtlNmQZFfgYcD3hhdUVWcDZ49TsR1BklVVtXSu6zEX3Pb757bD/Xv7W9r2cbplVgL7Jdk3yXzgeGDZUJllwB/0w8cB/15Vv3DlLknaPqa8cq+qe5KcAqwA5gEfqKrVSc4AVlXVMuCfgfOSrKO7Yj9+NistSdq6cbplqKrlwPKhaacPDN8FvHiyVdsh7DRdSLPAbb//uj9vfzPbHntPJKk9Pn5Akhq004R7kjclef0MX/vlKeYvT/LwmdVss+UcM+Lu3R1SksVJbpjremjrZuM4JbklyYJtXMbDk/zJwPijk3x822s3GUkOSbI6ybVJnpjkdye03AOSHDkwftSoR7LsCHaacN8WVfXMKeYfWVXfn8CqjqF7RMMv6H8iKm1Xs9juHg78PNyr6raqOm6W1j
UTLwPeWVUHAI8EJhLuwAHAz8O9qpZV1dsmtOzJqqod9h9wGrAW+DzwMeD1/fTHAp8DrgauBJ7QT38kcDFwXf/vmf30O/u/jwKuAK4FbgAO6affAizoh1/bz7sB+PN+2mLgRuAcYDVwKfCgobo+k+6XQjf3y38scDnwVuCLwOuAhcAn6H5euhI4uH/t7sAH+mnXAEdvh327GLihH/7Vfr1vAD7Z79v/At4xUP5O4K/7/foV4JFz3T5muN27A5/pt+MGup/wXjgw/1DgXwe2+e19O/s83aM4LgfWA0dtp/puse0Br+zbzHV9u3pwP/1c4G+BLwDvortb/NL+GL8f+Oam9j6wnj8eOt4nAn+/lXPifODHfVs/c6g9nbiVdnQScFO/H88B3jvD4/bSfvrz++36en8OPRD4Q+47F/+lb6939HV9zdByLwCOHBg/FzgW2A34YL/ca4DnAvOBbwEb+2W9tN/W9w689j3Al/s2clw/fRfgff3xu4TuxynHzXrbmeuTbSsH86n9jn0w8FBgHfeF+2XAfv3w0+h+V7/pQG1qfPOAh206Sfu/rwNOG5i/Rz98C92daZvWuTvwkP5gHNg33HuAA/ryFwIvH1HncwcPWt+A3zcw/lHgWf3wIuDGfvitm5ZHd0V0E7D7LO/fxf1J8vi+8R7QN9T1dDeh7UYXAvv05Qt4YT/8DuCNc91GZrjdxwLnDIw/rD9hd+/H/2HgWBRwRD98MV1APgB4CnDtdqrvFtsesOdAubcArxpoh5cA8/rx9wCn98Mv6LdrONwX0j1DatP4Z4FnTXFO3DDcnvrhke0IeDTdufZL/X68kvHDfdRx243usSf799M+zH3n/7ncF66HApdsYbkvAj7UD8/vl/cguqz4YD/9CX0b2Y2BMB/Y1sFwv4guzJds2p909/4s76f/CnA72yHcd+RumUOAi6vqR1X1f/Q3TiV5CN1V8kVJrqW7EnlU/5rn0Z2cVNXPquqOoWWuBF6R5E3Ak6vqB0Pzn9Wv84dVdSfd1cch/bybq+rafvhqusY8jgsGhg8D3tvXexnw0CR7AL8FnNpPv5yuES0ac/nbYiHwabqw2LRtl1XVHdX9vHUN993qfDddYMD0tn9H83XgsCRvT3JI30Y+B7yw78J4Ad0+gW6bPzfwui9W1U/74cXbsc5bantPSnJlkq/TdUP82sBrLqqqn/XDzwY+AlBVn6ELl81U1UZgfZKnJ9mT7k3/S2z9nNiaUe3oILp9+L1+P1405vbD6OP2eLp9c1Nf5kP9tk7HZ4HnJXkg3ZNvr6iqH9Nt93kAVfUNujeo/cdY3qeq6t6qWkPXk0C/rIv66f9L94lq1u3o/cCjfqe5C/D96vrSprewqiuSPJvuBD4vyZlV9eGBIlt7kuVPBoZ/RvfuPo4fDgzvAjyjbzz3rbR7guaxVbW9n7VzB92VysF0V2Twi9u5qY38tPrLkKHpO5WquinJU+n6Tf8myaV0b8B/SvdRfuXAm/7gNt9Lv2+q6t7t/B3KltreucAxVXVdkhPprlA3GWx3MPpcGnYB8BLgG3SBXtvwdNdR7WjGT4rdwnEbvlN+Jsu9K8nlwG/TdbN8rJ81ie3O0N/take+cr8CeFGSB/VXty8E6K/ib07yYuiCMclT+tdcRtd3SJJ5SR46uMAkjwG+XVXn0N1V+xsj1nlMkgcn2Z3uI9uV06jzD4A9tjL/UuCUgfpseoNaAbxq04mU5MBprHNb3E33JfDvT+rXBDu6JI8GflRVHwHeSdcGLu//vpLNP2nt6PYA/ifJA+iu3Lfkik3zkxwBPGIL5T5J1x5O4L79sKVzYqq2PspXgeckeUT/5njsuC/cwnH7BrA4yeP6Yr9H9/3WsKnqej7wCrpPJCv6aYP7bH+6T9Jrx1jWKP8BHJtklySPZPM34Vmzw4Z7VX2NroFdS/dl0WDIvgw4Kcl1dFecm54v/2fAc/uPqVez+cdU6HbqtUmuoWtYfzdinefSNcKrgH
+qqmumUe3zgTckuSbJY0fMfzWwNMn1SdYAf9RPfzNdH+T1/c/e3jyNdW6Tqvoh8DvAa+j6MVv3ZOCrfRfYacBb+u6LS+g+ll+ytRfvYP6Srp3+G13QbclfAc9O8jW6LsBvjSpUVbfTd6FU1Vf7aSPPiar6LvClJDckOXOcylbVf9N9v3QV3RfUa+g+PY5j1HG7iy6UL+rP+XuBfxzx2uuBe5Jcl+Q1I+ZfSted8/nq/s8K6L4Andcv9wLgxKr6CV2XypL+J5YvHbPun6B7uOINdN3IVzH+ds+Yd6hK2m6SPKSq7uyv3C+me1bVxXNdr9k2sN170r1RHtz3v8+anbLfVNJO601JDqP70cClwKfmuD7byyX9jZLzgTfPdrCDV+6S1KQdts9dkjRzhrskNchwl6QGGe6S1CDDXZIaZLhLUoP+H47Jp0tra/pcAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "%matplotlib inline\n", + "\n", + "x = np.arange(5)\n", + "plt.bar(x, height= [accuracy_score(y_test, dtree_predicted),\n", + " accuracy_score(y_test, knn_predicted),\n", + " accuracy_score(y_test, svm_predicted),\n", + " accuracy_score(y_test, hard_voting_predicted),\n", + " accuracy_score(y_test, soft_voting_predicted)])\n", + "plt.xticks(x, ['decision tree','knn','svm','hard voting','soft voting']);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_science/nlp/word2vec_gensim.ipynb b/data_science/nlp/word2vec_gensim.ipynb new file mode 100644 index 0000000..5024407 --- /dev/null +++ b/data_science/nlp/word2vec_gensim.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# pretrained Word2Vec download" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2020-06-03 23:13:08-- https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\n", + "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.16.238\n", + "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.16.238|:443... 
connected.\n", + "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n", + "\n", + " The file is already fully retrieved; nothing to do.\n", + "\n" + ] + } + ], + "source": [ + "!wget -P . -c \"https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import gensim" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# load pretrained word2vec\n", + "model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('pal', 0.7476358413696289),\n", + " ('friends', 0.7098034620285034),\n", + " ('buddy', 0.6972494125366211),\n", + " ('dear_friend', 0.6960037350654602),\n", + " ('acquaintance', 0.6843010187149048)]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# similar words\n", + "model.most_similar(positive=['friend'], topn=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('queen', 0.7118192911148071)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# king + woman - man = queen\n", + "model.most_similar(positive=['king', 'woman'], negative=['man'], topn=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "300" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Word2Vec vector dimension\n", + "len(model['friend'])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "array([ 0.07080078, -0.21386719, 0.15332031, 0.09423828, -0.03442383,\n", + " 0.43359375, -0.16503906, -0.05786133, 0.17578125, -0.08203125,\n", + " 0.24511719, -0.19335938, -0.0255127 , -0.09619141, -0.125 ,\n", + " 0.02575684, 0.16796875, -0.03759766, 0.09472656, -0.04760742,\n", + " 0.20605469, 0.31835938, 0.15917969, -0.17089844, 0.09033203,\n", + " -0.1640625 , -0.15234375, 0.3125 , 0.06298828, -0.24902344,\n", + " 0.15625 , -0.04516602, -0.12890625, -0.00686646, -0.02160645,\n", + " 0.14453125, 0.2734375 , 0.12695312, 0.10742188, 0.11376953,\n", + " 0.14355469, -0.00173187, 0.22851562, -0.03515625, 0.17089844,\n", + " 0.04516602, -0.07958984, -0.08886719, -0.01342773, -0.09667969,\n", + " -0.12597656, 0.10595703, 0.15332031, -0.03808594, 0.02246094,\n", + " 0.01428223, -0.03295898, 0.20703125, -0.03417969, 0.02233887,\n", + " 0.00244141, 0.13476562, -0.01403809, 0.13378906, 0.0201416 ,\n", + " 0.14746094, 0.00759888, -0.18652344, 0.16113281, 0.109375 ,\n", + " 0.14355469, 0.01623535, 0.01867676, 0.09179688, -0.33789062,\n", + " 0.19335938, -0.29101562, -0.00860596, 0.10644531, 0.359375 ,\n", + " 0.25585938, -0.03320312, 0.15625 , -0.24316406, -0.06738281,\n", + " 0.09033203, -0.125 , 0.21777344, -0.02380371, -0.06445312,\n", + " -0.14355469, 0.05664062, -0.12597656, 0.02172852, 0.03833008,\n", + " -0.17578125, -0.08349609, 0.21386719, -0.01855469, -0.23535156,\n", + " -0.14746094, -0.16113281, -0.03125 , -0.10107422, 0.07080078,\n", + " 0.01135254, -0.04370117, 0.07666016, 0.16503906, 0.04541016,\n", + " -0.13867188, 0.13085938, 0.13378906, -0.14453125, 0.12792969,\n", + " -0.06787109, -0.04296875, -0.03369141, 0.10302734, 0.22949219,\n", + " 0.14160156, -0.01153564, -0.00086212, -0.10449219, -0.03710938,\n", + " 0.01928711, 0.16699219, -0.06079102, 0.09814453, 0.0703125 ,\n", + " -0.39648438, -0.23242188, -0.04077148, 0.09570312, -0.0546875 ,\n", + " -0.09814453, 0.09082031, 0.03588867, 0.09228516, 0.3125 ,\n", + " 0.10595703, 0.18847656, 
-0.11230469, 0.00842285, 0.08935547,\n", + " 0.04663086, -0.25 , -0.03369141, 0.03808594, -0.03710938,\n", + " 0.42773438, 0.10839844, -0.01391602, -0.01965332, -0.04296875,\n", + " -0.11035156, 0.0390625 , 0.04541016, -0.20019531, -0.14355469,\n", + " -0.14257812, 0.03662109, 0.25 , 0.3671875 , -0.12304688,\n", + " -0.0859375 , 0.24902344, -0.21582031, 0.02648926, 0.17871094,\n", + " 0.29296875, 0.21582031, 0.1015625 , 0.00167084, -0.07177734,\n", + " 0.03686523, 0.22851562, -0.125 , 0.17285156, 0.22265625,\n", + " 0.21191406, 0.03686523, 0.09570312, -0.00344849, 0.13183594,\n", + " -0.23925781, 0.00576782, 0.27148438, 0.10400391, 0.0098877 ,\n", + " -0.24511719, 0.21777344, -0.03027344, 0.23046875, 0.11816406,\n", + " 0.1640625 , -0.00109863, 0.00349426, -0.02197266, -0.09179688,\n", + " -0.10351562, 0.06933594, -0.13476562, -0.06201172, 0.14355469,\n", + " -0.10888672, -0.11328125, 0.2109375 , -0.10839844, -0.18261719,\n", + " -0.06689453, -0.265625 , -0.13378906, -0.04296875, -0.17773438,\n", + " 0.00689697, -0.00982666, -0.00640869, -0.12792969, 0.08203125,\n", + " -0.01367188, 0.02734375, 0.12597656, -0.00772095, -0.04614258,\n", + " -0.12255859, 0.16210938, 0.28320312, 0.04296875, -0.05175781,\n", + " -0.16210938, 0.14648438, -0.18359375, -0.24511719, 0.22167969,\n", + " 0.0546875 , -0.10302734, -0.07763672, -0.33984375, -0.05908203,\n", + " -0.0022583 , -0.11962891, -0.3046875 , 0.02233887, 0.02941895,\n", + " 0.37695312, -0.01721191, -0.05932617, 0.30273438, -0.13574219,\n", + " 0.14746094, 0.17089844, 0.16015625, 0.21484375, 0.01013184,\n", + " 0.06738281, -0.12109375, -0.12304688, -0.20117188, 0.02880859,\n", + " -0.00662231, -0.20410156, 0.02001953, -0.15136719, 0.16699219,\n", + " 0.14160156, -0.02331543, 0.14550781, -0.13476562, 0.04785156,\n", + " 0.14160156, 0.03808594, -0.12109375, 0.02770996, -0.0123291 ,\n", + " -0.20410156, -0.06445312, 0.06079102, -0.07519531, -0.28125 ,\n", + " 0.18261719, -0.25390625, -0.0456543 , 0.14160156, -0.0546875 ,\n", 
+ " -0.01477051, -0.38085938, 0.14355469, 0.12255859, 0.14941406,\n", + " -0.03320312, 0.19433594, -0.34375 , -0.24902344, -0.00331116,\n", + " -0.05639648, -0.00079727, -0.21679688, -0.01977539, 0.10644531],\n", + " dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# print word2vec\n", + "model['friend']" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_science/svm/svm.ipynb b/data_science/svm/svm.ipynb new file mode 100755 index 0000000..db0733a --- /dev/null +++ b/data_science/svm/svm.ipynb @@ -0,0 +1,401 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.metrics import classification_report\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.svm import SVC" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [], + "source": [ + "# load iris data\n", + "dataset = load_iris()\n", + "\n", + "# use 80% as train data, 20% as test data\n", + "X_train,X_test,y_train,y_test=train_test_split(dataset.data,dataset.target,test_size=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Find best hyperparamters\n", + "RBF kernel SVM has two parameters.\n", + "1. 
C (cost): The C parameter trades off correct classification of training examples against maximization of the decision function’s margin. For larger values of C, a smaller margin will be accepted if the decision function is better at classifying all training points correctly. \n", + "\n", + "2. gamma: the gamma parameter defines how far the influence of a single training example reaches, with low values meaning ‘far’ and high values meaning ‘close’. The gamma parameters can be seen as the inverse of the radius of influence of samples selected by the model as support vectors.\n", + "\n", + "reference:\n", + "http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Grid Search\n", + "find best hyperparameter using grid search." + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "def svc_param_selection(X, y, nfolds):\n", + " svm_parameters = [\n", + " {'kernel': ['rbf'],\n", + " 'gamma': [0.00001,0.0001, 0.001, 0.01, 0.1, 1],\n", + " 'C': [0.01, 0.1, 1, 10, 100, 1000]\n", + " }\n", + " ]\n", + " \n", + " clf = GridSearchCV(SVC(), svm_parameters, cv=10)\n", + " clf.fit(X_train, y_train)\n", + " print(clf.best_params_)\n", + " \n", + " return clf" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}\n" + ] + } + ], + "source": [ + "clf = svc_param_selection(X_train, y_train, 10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 1.00 1.00 7\n", + " 1 1.00 1.00 1.00 13\n", + " 2 1.00 1.00 1.00 10\n", + "\n", + "avg / total 
1.00 1.00 1.00 30\n", + "\n", + "\n", + "accuracy : 1.0\n" + ] + } + ], + "source": [ + "y_true, y_pred = y_test, clf.predict(X_test)\n", + "\n", + "print(classification_report(y_true, y_pred))\n", + "print()\n", + "print(\"accuracy : \"+ str(accuracy_score(y_true, y_pred)) )" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ground_truthprediction
011
111
222
300
411
522
622
722
811
900
1011
1100
1200
1311
1411
1522
1611
1722
1811
1922
2000
2111
2200
2300
2411
2511
2622
2711
2822
2922
\n", + "
" + ], + "text/plain": [ + " ground_truth prediction\n", + "0 1 1\n", + "1 1 1\n", + "2 2 2\n", + "3 0 0\n", + "4 1 1\n", + "5 2 2\n", + "6 2 2\n", + "7 2 2\n", + "8 1 1\n", + "9 0 0\n", + "10 1 1\n", + "11 0 0\n", + "12 0 0\n", + "13 1 1\n", + "14 1 1\n", + "15 2 2\n", + "16 1 1\n", + "17 2 2\n", + "18 1 1\n", + "19 2 2\n", + "20 0 0\n", + "21 1 1\n", + "22 0 0\n", + "23 0 0\n", + "24 1 1\n", + "25 1 1\n", + "26 2 2\n", + "27 1 1\n", + "28 2 2\n", + "29 2 2" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Visualize true value with prediction value in pandas dataframe.\n", + "comparison = pd.DataFrame({'prediction':y_pred, 'ground_truth':y_true}) \n", + "comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}