{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 6 - Machine Learning with Quantum Annealing (QBoost)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/OpenJij/OpenJijTutorial/blob/master/source/ja/006-Machine_Learning_by_QA.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial covers machine learning as an example application of optimization by quantum annealing.  \n", "In the first half, we perform clustering using PyQUBO and OpenJij.  \n", "In the second half, we perform QBoost, an ensemble learning method, using PyQUBO and a D-Wave sampler." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clustering\n", "\n", "Clustering is the task of dividing the given data into $n$ clusters, where $n$ is assumed to be given externally. For simplicity, we consider the case of two clusters here.\n", "\n", "### Importing the required libraries\n", "\n", "We import scikit-learn, a machine learning library, together with the other libraries we need." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Import libraries\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from sklearn import cluster\n", "import pandas as pd\n", "from scipy.spatial import distance_matrix\n", "from pyqubo import Array, Constraint, Placeholder, solve_qubo\n", "import openjij as oj\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating artificial data\n", "\n", "Here we generate artificial data on a two-dimensional plane that are clearly linearly separable." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = []\n", "label = []\n", "for i in range(100):\n", "    # Draw a uniform random number in [0, 1)\n", "    p = np.random.uniform(0, 1)\n", "    # Assign class 1 if p > 0.5, otherwise class -1\n", "    cls = 1 if p > 0.5 else -1\n", "    # Draw a point from a normal distribution centered at (cls, cls)\n", "    data.append(np.random.normal(0, 0.5, 2) + np.array([cls, cls]))\n", "    label.append(cls)\n", "# Arrange the data as a DataFrame\n", "df1 = pd.DataFrame(data, columns=[\"x\", \"y\"], index=range(len(data)))\n", "df1[\"label\"] = label" ] }, { "cell_type": "code", "execution_count": 3, 
"metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZmElEQVR4nO3df7CcVX3H8c939/5ITBDShEpJiNEGbIFBGG5VvDOOgm2pYhyl/gbGapt/wMEZp4ktdaxDWxuc+kcLHXqLjmOLOoyRhhEdxQaHQoHhhglpCKKpM8iNv0IaYoLJzb27p3/kLuzduz+e3X2e55znOe/XX9zde3fPLnC+53y/54c55wQAiE/FdwMAAH4QAAAgUgQAAIgUAQAAIkUAAIBIjfhuQD/WrFnjNmzY4LsZAFAou3btes45d2br44UKABs2bND09LTvZgBAoZjZM+0eJwUEAJEiAABApAgAABApAgAARIoAAACRIgAAKIxDx2b1xLPP69CxWd9NKYVCLQMFEK8duw9o6/Y9Gq1UNFev65arL9Kmi9f6blahMQMAELxDx2a1dfsenZir6+jsvE7M1bVl+x5mAkMiAAAI3szh4xqtLO6uRisVzRw+7qlF5UAAABC8dauWa65eX/TYXL2udauWe2pRORAAAARv9cpx3XL1RVo2WtFp4yNaNlrRLVdfpNUrx303rdAoAgMohE0Xr9XkxjWaOXxc61Ytp/NPAQEAQGGsXjlOx58iUkAAECkCAABEigAAAJEiAABApAgAABApAgAARIoAAACRIgAAQKQIAAAQKQIAAESKAAAAkSIAAECkCAAAECkCAABEigAAAJEiAABApAgAABApAgAARIoAAACR8hYAzOwcM7vfzPaZ2ZNmdqOvtgCxO3RsVk88+7wOHZv13ZShlOVz5MXnpfDzkj7hnHvczE6TtMvM7nPO7fPYJiA6O3Yf0NbtezRaqWiuXtctV1+kTRev9d2svpXlc+TJ2wzAOfcz59zjC/98VNJTkvi3BeTo0LFZbd2+Ryfm6jo6O68Tc3Vt2b6ncCPosnyOvAVRAzCzDZIukfRom+c2m9m0mU0fPHgw97YBZTZz+LhGK4u7gdFKRTOHj2fyflmlaPL+HGXhMwUkSTKzlZK2S/q4c+5Xrc8756YkTUnSxMSEy7l5QKmtW7Vcc/X6osfm6nWtW7U89ffKMkWT5+coE68zADMb1anO/07n3Dd8tgWI0eqV47rl6ou0bLSi08ZHtGy0oluuvkirV46n+j5Zp2jy+hxl420GYGYm6QuSnnLOfd5XO4DYbbp4rSY3rtHM4eNat2p5Jp1mI0VzQi+N0hspmrTeL4/PUTY+U0CTkq6V9D9mtnvhsb90zn3LX5OAOK1eOZ5ph5lXiibrz1E23gKAc+5BSebr/QHkp5Gi2dJSA6Cz9st7ERhAHEjRhIcAACA3pGjCEsQ+ACBrHBEALMUMAKUX2xEBh47NkmZBIgQAlFrz+vPGEsQt2/docuOaUnaOsQU7DIcUEEotpiMCOA8H/SIAoNRiOiIgpmCHdBAAUGoxHREQU7BDOqgBoPRiWX/OZiv0iwCAKMSy/jyWYId0EACAkokl2GF41AAARGeQjYFl3EzIDABAVAbZK1HW/RXMAABEY5C9EmXeX0EAABCNQfZKlHl/BQEA8KiMeeWQDbJXosz7KwgAgCc7dh/Q5LaduuaORzW5bafu2X3Ad5NKb5CNgWXeTGjOOd9tSGxiYsJNT0/7bgYwtEPHZjW5badOzL00slw2WtFDWy8vRccSukFOTC3yKatmtss5N9H6OKuAAA/yuCQdnQ2yV6KM+ytIAQEelDmvjOIgAAAehJZX7lWMDqVYHUo7yoIUEOBJFuf2DJKn7rXJKZRNUKG0o0woAgMlMUgH2asYHUqxOpR2FFWnIjApIKAEBt2t2muTUz+boLJMz5R5M
5ZPpICAEhh0VVGvYnTSYnXW6RmK5tlgBgAkFHIBctAOslcxOkmxOo+zckIrmpcFMwAggdALkMPcBtarGN3r+bz2NHDZTfoIAEAPzSPcRie3ZfseTW5cE1QnNEwH2WuTU7fns0jPdFrNlMVmrCLv8B0WAQDooUi7dn3sVk37LuI8Z1uhz+yyRgAAevBdgCzCCDWt9Eyes62izOyyRAAAekh7hNuPIo1Q05h95DnbKtLMLisEACABHwXIEEeoWc5GDh2b1ZHjczpZy2e25XtmFwICAJBQ3vn10EaoWc5Gml+7Vq9rtGpaNlLNdLbVOrM7Wavp+jdvTP19QuZ1H4CZfdHMfmlme322AwhRSCPUTmv99//i6NB7I1pfe74uVUy67UOX6KGtl2ea8tp08Vo9tPVy/dmbXi3JNPXAj6O6nMf3RrAvSbrScxuAtnxv/App81O7oxgk6W3/+F9D32jW7rXHqlWdvnwst8/6z9/fr9n58l363ovXFJBz7gEz2+CzDUA7oRRf+609ZJWjbzcbaRzMdrI2L2nw+kS/M51BP2Onvwst1Zan4GsAZrZZ0mZJWr9+vefWIAaDFF+zLI4mrT20Bq1PXXW+Ljz79FTa1Jovn63VZc5ptvbSacKDdpr9rLIaNDB3+7uQUm15Cz4AOOemJE1Jp46D9twcRKDfEWEIs4V2Qeumu/dqxVhVNedSaVPzbGTFWFVX3fqg1BQAhuk0k8x0Bl0V1evv2gWgT739/BdPGi3zLCD4AADkrZ8RYShLNdsFLUl64WQt1TY1z0bS3hvRa6YzaKomyd81B6C9B47o5nv3eU//5YEAALToJyURSv64XdBqVobD2QZN1ST9u0b73zf1sPeAnhffy0C/KulhSa8xsxkz+6jP9gANjeWB//6nr++6FDGU/HHziqEV49Ulz2fVptUrx/Xac87IpXMcdFVUP38X28UzvlcBfcDn+wPdJCm++jwmolWvNEboI9gkhfRBZx1J/y6UgJ4X7gQGUpD3gW1J3q8Ih8g1ZFFI7/fzN36/jDWATncCUwMAukjaieR5TETSzrLfNvkKGFkU0vsNKFkuoU1DVv9uCABAByEs72yV1aojn5817UJ6v99Ru9+/+Zv79NDWy4Po/LP8d+P7KAggSHncczuILIqUST5r87EYaR+RkXbevd/vKOTCb9b/HTIDANrod1SaV/okiyJlr8/aPAI9MV+Tc07LR0dSG42mXUjv9zsKufCb9TJjAgDQRj+dQp7pkyxWHXX7rO3SI5J0dHa4839apbmnoN/vKKSVXK2yDk4EAKCNpJ2Cj53AaW/A6vZZn3j2ebl655WCaY5G0yyk9/sd+bjwJ4msgxMBAOggSafgaydw2quOOn3WFWPVRQe+tepnNJr3KqN+v6O8L/xJKsvgRAAAuujVKYScP+5Xu8/6wsmalo1WXjz6ueFlo1XV5RKPRkNcUVUkWQUnVgEBQwjp0hZp6SU2w67YaRfIxkcquv3aSxPf1hXqiiowAwCGFkr+uHWU/d6JdbpremaoUXenHPSbzjsz8WuEcmAeliIAACnwnT9uV4z+8sM/kaShi9PDBrgypcnKhhQQUAKd7uxtNszmpmFO/QwtTYaXMAMASqDXfQCS31F3KGkyLMYMACiBdqPs6y5bH9Sou9MsIu2jJZAcMwCgJNqNsm+84rygR90sD/WLAACUSGsx2ndxuptQ7lOOGSkgAF6EfApnLAgAALxgeah/BAAAXrA81D9qAAA6yvoAN5aH+kUAANBWXit0Qi5Ulx0pIABLcIBbHHoGADP7mJmtyqMxAMLACp04JJkBvELSY2Z2l5ldaWaWdaMA+MUKnTj0DADOub+SdK6kL0j6sKQfmdnfmdlvZ9w2AJ6wQicOiYrAzjlnZj+X9HNJ85JWSfq6md3nnNuSZQMB+MEKnfLrGQDM7EZJ10l6TtIdkv7cOTdnZhVJP5JEAABKihU65ZZkBvAbkt7tnHum+UHnXN3MrsqmWQDyl
vel7fCvZwBwzn26y3NPpdscAD5wKmec2AcARI41//EiAAAlMMylKqz5jxdHQQAFN2z6hjX/8fI6A1jYWPa0me03s0/6bAtQRGmkb1jzHy9vMwAzq0q6TdLvS5rRqd3G9zjn9vlqE1A0jfRN40Yt6aX0TT8dOGv+4+QzBfQ6Sfudcz+WJDP7mqR3SiIAAAmlmb5hzX98fKaA1kp6tunnmYXHFjGzzWY2bWbTBw8ezK1xQBGQvsEwgi8CO+emJE1J0sTEhPPcHCA4pG8wKJ8B4ICkc5p+XrfwGIA+kb7BIHymgB6TdK6ZvcrMxiS9X9I9HtsDAFHxNgNwzs2b2Q2SviOpKumLzrknfbUH6Bdn56DovNYAnHPfkvQtn20ABsHZOSgDjoIA+sTZOSgLAgDQJ87OQVkQAIA+cXYOyoIAEIhhTnNEvth8hbIIfiNYDCgoFg+br1AGzAA8K2pBkRnLqZnAa885g84fhcUMwLO0TnPMEzMWoByYAaRgmNFw0QqKRZ2xAFiKADCkHbsPaHLbTl1zx6Oa3LZT9+zu7zijohUUWQIJlAcpoCE0j4YbKZwt2/docuOa0l7GUbQZC4DOmAEMIc3RcFEKikWbsaSJwjfKhhnAEPodDYd6eFi/7SrSjCUtFL5RRgSAITRGw1taOoZ2HWKoHcig7Yrp/Pm0Un1AaAgAQ0oyGs6qAxl2RkHHlkwRl+oCSRAAUtBrNJxFB5LGjIKOLRkK3ygrisA5GKQD6VZwTGstPh1bMjEXvlFuzABy0E+tQOo9uu82cm88nyQt1G+7Qi1i5yHGwjfKjwCQk6QdSJK8fKeR+94DR/S+qYf7SgslbVfeRewQg01MhW/EgRRQjpKs9U+yt6BdSuJTbz9fN9+7b6C0UK925X38w52PPKPLPvuf+uC/PjLQ7moAyTADCEzSvHzryD3Lgm6exeI7H3lGN/3HXknSyVpNEiuTgKwwAwhMPwXH5pF7lgXdvIrFh47N6jPf3Lfk8apZYc8aYvcwQsYMIECDFByTFHQ75dV75dv7LRYPaubwcY1VTSfnFz8+VyvmyqRQN/8BDQSAQA1ScOwWODp1Rkk7qTxWwaxbtVzzdbfk8U+/44LCpX/YZIciIAVUMu0Kup2KuPt/cTRRcbeRxpCU6YF1zemvFeNVjY1U9LfvulAfesMrM3m/LHFsNoqAGUAEOhVxdz/7fM/ibvMM4WStphvecq4++Pr1mQWBsqy3Z5MdioAZQAQ6dUYXn3NG106qdeYwO+/0D/f9UG/8+2yXZhblaOxu2D2MImAGEJgsNkB1KuJufMVpXYu77WYOkjQ7Xy9NPjvLDWdlmc2gvAgAAcly1cjkxjWaunZCktMFZ5/+YmfUrZNqN3NoKMOhcXms0mH3MEJGCigQWe62bdxbfP2dj2vzv+3SQ/ufW/R8c8qled16Y+YwPmJLXrPo+WwutweYAQQjq922/SxH7DQinty4Rl959Ce69f79Gqtmtw8gTxyFDRAAgpHVqpGkHV2vQPGxK06t/gkxnz1IHp9VOgApoGBktWokaUeX9BC6LFbnDHNcQiO9dc0dj/Z1cByrdABmAEHJYtVI0mMcfI2IhynEDrvbllU6iJ2XAGBm75H015J+V9LrnHPTPtoRoixWjbTr6FrTJnmd99Ns2A48jTw+q3QQM18zgL2S3i3pXzy9f3SaO7pOo+68R8Qzh49rpLJ4hVG1Yok7cPL4wHC81ACcc08555728d55CPkI4F7LH/Pchbv3wBEdm60teuyF2Zr2HjiS6O/J4wPDCb4GYGabJW2WpPXr13tuTW+hHwH85E+PqGKLR90+lj8eOjarm+9deva/JN187z5deeFZ5PGBjGUWAMzse5LOavPUTc65HUlfxzk3JWlKkiYmJpaeFRyQ0I8A3rH7gLZ8/QnNzi/+GtNKmyRZjtn4nSPH59oeMyGRxwfyklkAcM69NavXDlWam4vSPqOmEZxaO//xkXTSJklmPotPFq2r1uGYC
fL4QD6CTwEVSVpFySzSSO2C08tGq7r92kv1pvPOHOq1k8x82v3OaNU0PiLJSbM1p2Wjp0pS5PGBfPhaBvouSf8k6UxJ95rZbufcH/poS5rSWEqZVRqpXXCqy+mCs18+8Gs2JJn5tPudZSNV3fahS3T68jGtGKvqhZO1oWY8WZ7sCZSRlwDgnLtb0t15vV+eHcOwRcmszqjJcp1/kplPp99pPpl0GKEX34EQlT4F5KNjGKYomeXa9qxWzCQJLlkGoNCL70CoSh0AitgxZL0jN40VM+1mVEmCS1YBiJM9gcGUOgAUtWMIeW17txlVkuCSxZJNdgQDgyn1aaBF7hjyvhc3ye7lUC9RYUcwMJhSzwB8HHBWREnrJM1HQ7c+7vs7DXnWBISq1AFAomPopZ86yYqxqk7MLZ5RnZira8VYNbf2dsOOYKA/pQ8AUjE7htZCa1ZLWfupk7xwsqbxqmm29tJu4vGq6YWTiw90A1AMUQSAomlNybz30nW6a9dMoqWszYFCUs+g0U+dZN2q5bKKSU0BwCpWiJoKgKUIAAE5dGxWT/70yIsHtjVG5V9+5CeS1Nel7ifma3LOafnoSNeg0U+dhJoKUC4EgEA0Ou+K2ZID21olvdRdko7Ozkvqvv+hnzoJNRWgPAgAAWjuvJPodql7u+OVpd77H/qpkxSxpgJgqVLvAyiKRufd6mWjVS0brei6y9b3XOPeLpffrCj7HwDkhxlAANp13uMjFd1+7aW64OyXa/XKcd14xXld0y6t+fl2NQBG7QCaEQAC0Km42nxOf5K0S2t+Xuq9CghAvAgAgUiruNoaKOj4AXRCAAiIr+IqF6kAcSIARKi5w39w/3NcpAJEigAQmcUXs9dUd9JczRXmvgQA6WEZaERaj3OenXeaqy3edNbYLwCg/AgAAUpyNv8gOu03aMZ+ASAepIBSklYhNcs7jNvtNxipSNVKRWNVzvYBYkMASEFanXbWdxh32m/g82wfViAB/hAAhpRmp53HHcad9hv46HyznO0A6I0awBAOHZvV/T/4papmix4ftJDa6Wz+FWPVVGsCed833E6o9wsDMWEGMKDG6HWksvRGrEELqe1SNO+dWKerbn2wdKPkPGY7ALojAAyg0/HNK8aqqjk3VCG1OUWzYqyqq259MLOagE/93EQGIBukgAbQbjnlivGqPrPpAj209fKhR+iNFM0LJ2tL3if0dfpJl7A2Zju9jrkGkB1mAANoN3qt1Z3e8ju/mWoHVrRRcq+ibuuKH24XA/xiBjCAPEev1795o8ZHwh8l9yrq7th9QJPbduqaOx7V5Ladumf3AUlhFKSBWDEDGFDWo9fm0bTktPlNr9YHX78+2I6yW1FXUqb7GwAMhhnAELIavbY7s+e27+9P9T3S1i1d1a5mEnotA4gBASBARewwu6XFilbLAGJBCihARe0wu+0ybncEBekfwC8vAcDMPifpHZJOSvpfSX/inHveR1tCVOQOs9OtZqz4AcJjzrnev5X2m5r9gaSdzrl5M9smSc65rb3+bmJiwk1PT2fevlBwUBqANJjZLufcROvjXmYAzrnvNv34iKQ/9tGO0Pm6IxhAHEIoAn9E0rc7PWlmm81s2symDx48mGOziimry2QAlE9mMwAz+56ks9o8dZNzbsfC79wkaV7SnZ1exzk3JWlKOpUCyqCppcHxygD6kVkAcM69tdvzZvZhSVdJusL5KESUTNaXyQAoHy8pIDO7UtIWSZucc7/20YayKeLeAQB++aoB3CrpNEn3mdluM7vdUztKo6h7BwD442sV0EYf71tmRd47AMAPdgKXCJutAPSDAFAy7B0AkFQI+wAAAB4QAAAgUgQAAIgUAQAAIkUAAIBIeTkOelBmdlDSMym+5BpJz6X4emXB99Ie38tSfCfthfa9vNI5d2brg4UKAGkzs+l2Z2THju+lPb6XpfhO2ivK90IKCAAiRQAAgEjFHgCmfDcgUHwv7fG9LMV30l4hvpeoawAAE
LPYZwAAEC0CAABEKvoAYGafM7MfmNkeM7vbzM7w3aYQmNl7zOxJM6ubWfDL2bJkZlea2dNmtt/MPum7PSEwsy+a2S/NbK/vtoTEzM4xs/vNbN/C/z83+m5TN9EHAEn3SbrQOXeRpB9K+gvP7QnFXknvlvSA74b4ZGZVSbdJ+iNJ50v6gJmd77dVQfiSpCt9NyJA85I+4Zw7X9IbJF0f8n8v0QcA59x3nXPzCz8+Immdz/aEwjn3lHPuad/tCMDrJO13zv3YOXdS0tckvdNzm7xzzj0g6f98tyM0zrmfOeceX/jno5KekrTWb6s6iz4AtPiIpG/7bgSCslbSs00/zyjg/6ERDjPbIOkSSY96bkpHUdwIZmbfk3RWm6ducs7tWPidm3Rq+nZnnm3zKcn3AqB/ZrZS0nZJH3fO/cp3ezqJIgA4597a7Xkz+7CkqyRd4SLaGNHre4Ek6YCkc5p+XrfwGNCWmY3qVOd/p3PuG77b0030KSAzu1LSFkmbnHO/9t0eBOcxSeea2avMbEzS+yXd47lNCJSZmaQvSHrKOfd53+3pJfoAIOlWSadJus/MdpvZ7b4bFAIze5eZzUi6TNK9ZvYd323yYWGBwA2SvqNTBb27nHNP+m2Vf2b2VUkPS3qNmc2Y2Ud9tykQk5KulXT5Qn+y28ze5rtRnXAUBABEihkAAESKAAAAkSIAAECkCAAAECkCAABEigAAAJEiAABApAgAwBDM7PcW7pJYZmYrFs6Av9B3u4Ak2AgGDMnM/kbSMknLJc045z7ruUlAIgQAYEgLZwQ9JumEpDc652qemwQkQgoIGN5qSSt16kypZZ7bAiTGDAAYkpndo1M3hb1K0m85527w3CQgkSjuAwCyYmbXSZpzzn1l4f7g/zazy51zO323DeiFGQAARIoaAABEigAAAJEiAABApAgAABApAgAARIoAAACRIgAAQKT+HwcTvkm8///xAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# データセットの可視化\n", "df1.plot(kind='scatter', x=\"x\", y=\"y\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "今回は、以下のハミルトニアンを最小化することでクラスタリングを行います。\n", "\n", "$$\n", "H = - \\sum_{i, j} \\frac{1}{2}d_{i,j} (1 - \\sigma _i \\sigma_j)\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$i, j$はサンプルの番号、$d_{i,j}$は2つのサンプル間の距離、$\\sigma_i=\\{-1,1\\}$は2つのクラスターのどちらかに属しているかを表すスピン変数です。 \n", "このハミルトニアンの和の各項は \n", "\n", "- $\\sigma_i = \\sigma_j $のとき、0\n", "- $\\sigma_i \\neq \\sigma_j $のとき、$d_{i,j}$ \n", "\n", "となります。右辺のマイナスに注意すると、ハミルトニアン全体では「異なるクラスに属しているサンプル同士の距離を最大にする$\\{\\sigma _1, \\sigma _2 \\ldots \\}$の組を選びなさい」という問題に帰着することがわかります。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyQUBOによるクラスタリング\n", "\n", "まずは、PyQUBOで上のハミルトニアンを定式化します。そして`solve_qubo`を用いてシミュレーテッドアニーリング(SA)を行います。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def clustering_pyqubo(df):\n", " # 距離行列の作成\n", " d_ij = distance_matrix(df, df)\n", " # スピン変数の設定\n", " spin = Array.create(\"spin\", shape= len(df), vartype=\"SPIN\")\n", " # 全ハミルトニアンの設定\n", " H = - 0.5* sum(\n", " [d_ij[i,j]* (1 - spin[i]* spin[j]) for i in range(len(df)) for j in range(len(df))]\n", " )\n", " # コンパイル\n", " model = H.compile()\n", " # QUBOに変換\n", " qubo, offset = model.to_qubo()\n", " # SAで解を求める\n", " raw_solution = solve_qubo(qubo, num_reads=10)\n", " # 解を見やすい形にデコード\n", " decoded_sample = model.decode_sample(raw_solution, vartype=\"SPIN\")\n", " # ラベルを抽出\n", " labels = [decoded_sample.array(\"spin\", idx) for idx in range(len(df))]\n", " return labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "実行および解の確認を行います。" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 
1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":15: DeprecationWarning: Call to deprecated function (or staticmethod) solve_qubo. (You should use simulated annealing sampler of dwave-neal directly.) -- Deprecated since version 0.4.0.\n", " raw_solution = solve_qubo(qubo, num_reads=10)\n" ] } ], "source": [ "labels =clustering_pyqubo(df1[[\"x\", \"y\"]])\n", "print(\"label\", labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可視化をしてみましょう。" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVwUlEQVR4nO3dX4glZ5nH8d8zk3jRJCCZGRCSzGlhRTYEQaYJK7uwsHoxhmWDsoLSWSIRZp1RUBBEd272pq8Er6KGBkWxD0pAJYKROAEhu6BijwR3khgJkhlHFuzMXLhhhDAzz15UH+b06apz6s9bVe9b9f3AYdLV3ee852g/9dbzPO9b5u4CAKTrSN8DAAA0QyAHgMQRyAEgcQRyAEgcgRwAEndXHy96/PhxX19f7+OlASBZFy9efNPdTywe7yWQr6+va3d3t4+XBoBkmdnlvOOkVgAgcQRyAEgcgRwAEkcgB4DEEcgBIHEEcgDdmk6l9XXpyJHs3+m07xElr5f2QwAjNZ1KZ85IN25kX1++nH0tSZub/Y0rcczIAXTn/Pk7QXzmxo3sOGojkAPozpUr1Y6jFAI5gO6cPFntOEohkAPoztaWtLZ28NjaWnYctRHIAXRnc1Pa3pYmE8ks+3d7m0JnQ3StAOjW5iaBOzBm5ACQOAI5ACSOQA4AiSOQA0DiCOQAkDgCOQAkjkAOAIkjkANA4gjkAJA4AjkAJI5ADgCJI5ADQOII5ACQOAI5ACSOQA4AiSOQA0DiCOQAkDgCOQAkjkAOAIlrHMjN7EEz+7mZvWJmL5vZ50IMDEBF06m0vi4dOZL9O532PaJ6hvI+OhTi5ss3JX3B3X9jZvdKumhmF9z9lQDPDaCM6VQ6c0a6cSP7+vLl7GsprRsdD+V9dMzcPewTmj0r6Sl3v1D0MxsbG767uxv0dYFRW1/Pgt6iyUR6442uR1PfUN5HS8zsortvLB4PmiM3s3VJ75f0q5zvnTGzXTPb3dvbC/myAK5cqXa8iTZTH12+jwEJFsjN7B5JP5D0eXf/y+L33X3b3TfcfePEiROhXhaAJJ08We14XbPUx+XLkvud1EeoYN7V+xiYIIHczO5WFsSn7v7DEM8JoIKtLWlt7eCxtbXseEjnz9/JX8/cuJEdD6Gr9zEwIbpWTNI3Jb3q7l9tPiQAlW1uStv
bWS7ZLPt3ezt8gbDt1EdX72NgGhc7zewfJP2XpP+RdHv/8H+4+3NFv0OxE0gUxcheFRU7G7cfuvt/S7KmzwMgAVtbB9sDJVIfEWBlJ4DySH1EKcSCIABjsrlJ4I4MM3Kkg6XbQC4COdLQdv9ybDhpoQICOdLQdv9yTMZ20kJjBHKkYUxLt8d00kIQBHKkYUxLt8d00kIQBHKkYUxLt8d00kIQBHKkYUz9y2M6aSEI+siRjrH0L8/e4/nzWTrl5MksiI/hvaMWAjkQo7GctBAEqRUAaarTaz/Q/nxm5ADSU+fengO+H2jwe3aWwTa2ABqps53uALbg7eSenQDQiTq99gPuzyeQA00NNO8atTq99gPuzyeQA02wL0o/6vTaD7g/n0AONMG+KP2os0BswIvKKHYCTRw5ks3EF5lJt28fPg40QLETaMOA865IB4EcaCKmvOuqomssRdlYxjEk7t7549SpUw4Mxs6O+2Tibpb9u7PT/fPt7LivrblniZ7ssbZ253dXfb8rsYwjUZJ2PSemkiMHYrK4+lDKZvirinKrFrvEshgmlnEkihw5kIK6XTCrFruUXQzTdtpjwIty+kQgB2JSN9CtKrqWKcp20RNPcbgVBHKMS+yFtrqBblXRtUxRtoue+JiKw0OSlzhv+0GxE71IodDWZIyriqSrvm928HVnD7Mw763sOFBIBcVOAjnGYzLJD1STSd8jO6ivQBf68+nyfYzk5EAgB7qacaYq5BVLl1c/KVxpBVIUyMmRYzz6LLTFnpuXwu5F0uUeNOx3Qx85RqRuj3aqr9unLvegGdF+N/SRA33tfhfbjLHNq4PZcxdNENu4+qGlkUCOkdnczFYQ3r6d/dvFjDimRTBt9orPP3eettoM81oaJemtt+JMYbUgSCA3s2+Z2Z/N7FKI5wMGJaYZY9HVwRNPNJ+h5z33TJtXP7MrrWPHDh6/dm00N/kINSP/tqTTgZ4LCCeGImNMi2CKrgJu3Wo+Qy96brP2r342N6V77jl8fCRFzyCB3N1flHQ9xHMBwcRyG7aqufk2Tz5lrgLqBr8qVx513+Oy34sphdW1vJ7EOg9J65IuLfn+GUm7knZPnjzZfsMlUHWBSwyLSvJ6omf976G2yF18/lC99WX7uev2fa/6vVQWfDWgthcErQrk8w8WBKETVRYAxbKopCgYhRzT/Anr6NGwwa/MybBuwF31e22fBCNAIMf4VAkYsczmik4+bY2pjxNY3RW2ZX5vdiKZD+IDWu1ZFMhpP8RwVSkyxpJfLZPDDjmmPnrr63bxlPm9WXvpZHK4l33Ahc9Q7Yffk/QLSe81s6tm9qkQzws0UiVIxdIiWNQTPS/0mLrura/bxZPiibkredP0th+kVhCdWHLks7Gkmh4oWzCuW1gu+3uxpMoCE7sfAit03bVS5vVi6KQpq42TYdX3n/JJsAQCOcYrxmDY1hVAn++1jf3Mq3xGsXetBPjfhkCOcYopZTKvjUv/vt9r6P3eq35GMadTAv1vUxTI2cYWw7a+nr+J02SSFfb60sbWq6ve63SadW1cuSLdd1/2vevXs+Lp1lbzImfoz7rqZxTzdraBPhu2scU4Vele6HJflja6ZJa918XtCq5dyx7u4bYuCL2nTNXPKJbOozxtd9HkTdPbfpBaQWfKXm53nZZo4/WWvddVK0ZDpSBC5uhD5MhjSKO5B0v7iBw5RqnsH3cf+dXQhcll73VVEG+Sy25T3a6VmArb7q3nyAnkGL4yf9xDuTFz0Xst2lOl6kkr1kCZgha7Vih2AlK8RdFQzJZ/v8w9RMd479HIUOwElonp5g+LRddz55oXYSeT5d8rE4xju/co7sibprf9ILWCKMWQNiizX3idAl6IHO1Q0k8JE6kVIAFFKZ5FdVI+833kdXrHh55+SgCpFSAFZfuK6/QfN93lMKb0Ew4gkAMxKbt4pY9FLn3sXY5SCORATMrsR97nLLhoVt/lqlgcQiAHYpI36z17Nu5Z8OL
y/1BL/lEaxU4AzVAE7QzFTgDtGNtt1SJEIAfQTMy7Do4EgRxAM7Ql9o5ADgxd2x0ltCX27q6+BwCgRYsbXc06SqSwgXZzk8DdI2bkwJCx0dUoEMiBIaOjZBQI5MCQ0VEyCgRyYMjoKBkFAjkwZHSUjAJdK8DQ0VEyeMzIgSFhF8JRYkYODEVXPeOIDjNyYCjoGR8tAjkQi6ZpEXrGR4tADsQgxM0Z6BkfrSCB3MxOm9lrZva6mX0pxHMCoxIiLULP+Gg1DuRmdlTS1yR9WNJDkj5hZg81fV5gVEKkRegZH60QXSuPSHrd3f8gSWb2fUmPSXolwHMD43DyZP7t0qqmRegZH6UQqZX7Jf1x7uur+8cOMLMzZrZrZrt7e3sBXhYYENIiaKCzYqe7b7v7hrtvnDhxoquXBdJAWgQNhEit/EnSg3NfP7B/DEAVpEVQU4gZ+a8lvcfM3m1m75D0cUk/DvC8AIASGgdyd78p6bOSnpf0qqRn3P3lps8LdIK9STAAQfZacffnJD0X4rmAzrA3CQaClZ0YL/YmwUAQyDFe7E2CgSCQY7zYmwQDQSAPiLpZYliEg4EgkAcSYvM6dIxFOBgIAnkgKdbNuIJQFrTfeEO6fTv7lyCOBHGrt0BSq5vReQcMBzPyfU1np6nVzVK8ggCQj0CuMPnt1OpmqV1BAChGIFeY2WlqdbPUriAAFCOQK9zsNKW6WWpXEKFQ4MUQEchVbXYaayCoOq7UriBCoEUUg+XunT9OnTrlMdnZcV9bc8/+vLPH2lp2vM7PdS3WccVmMjn4Gc0ek0nfIwPKkbTrOTGVGbnKz07b6PQIMcOnA6UcCrwYKsuCfLc2NjZ8d3e389dt6siRbA63yCzLi1e12MstZXnqqimO0OMaqvX1/PsbTyZZTQOInZlddPeNxePMyCuo2umxarYdaiZNB0o5Yy3wYvgI5BVUCQRlCmtFl/SXL1dLt1QdV4zF2i6MscCLkchLnLf9iK3YWcXOTlYcM8v+LSoolimsFf2MWfXCZZlxdV0ULftZAShHBcVOcuQtKZO3zsuRm+X/Xog8bpc54nPnpKefPvhe6uT/AdxBjrxjZfLWeZf6RefVEJ0VXXVtTKeHg7hEJw3QFgJ5S8rmrRdXg04m+c8XonDZVVH0/Pl2T0hdG3NdAWkgkLekbmFt1QmgKKiUCTZddW0sC9apddKwGhRJyEuct/1IudjZhaIiYVGx8uzZ8kXMLgqQy4q4qRU8WQ2KmIhiZ/qKipVHj0q3bh0+Pl/EnE6zlMeVK9mseGurvaJjURH305+Wvv71dl6zLSy2Qkwodg5AUcoiL4jP/3xeeuDxx6Xjx9tJEeSllb773fSCuMRiK6SBQJ6QouBx9Ojyn89bQSpJ1661l+9NaUvfZVgNihQQyFvQVpdDUVA5c2Z5sFlWfBxCS2CbXSWsBkUS8hLnbT+GXOxsc/Xkzo77sWN3nvfYsYOF0KIiZlHBbr4ImSq28MWYiG1su9HWlrKzPPe1a3eO/fWvd/57PpWxtZW93myG+uijh2fs81LO97KFL8A2tsG11eVQdnl90da4TzwhPfPMwRPB7HsppwroKsGY0LXSkba6HMoury+aoT73nPTmm9LOTpz53rp5brpKAAJ5cG11OZQNWKsCflvdJE0Kjk1WT9JVAohiZxvaWD1ZtqjXx0rEpgXHpmNmu1yMhdpY2WlmH5P0n5L+VtIj7l4q8T3kHHmbFldnPvpoljKZX60phbl9XBVNt8clzw2U01aO/JKkj0p6seHzoITFzpTvfOdwOkLqvu85L4gvO76IPDfQTKNA7u6vuvtroQYTk9i3Ll3WdtflqsrpNDth5DEjzw10obNip5mdMbNdM9vd29vr6mVriX3r0um0eLbb9X7fy/Yedy/Xz83qSaCZlYHczF4ws0s5j8eqvJC7b7v7hrtvnDhxov6IOxDzIpPZSaZI03REmSuR+Z9ZlT4pe2IZyt4sQB9WBnJ3/5C7P5zzeLa
LAfYh1C3R2kjPFG2AJTVPR5S5Eln8mVXIcwPto488R4jiW1vpmWUnk6bpiDJXIstOJIvIcwPdaBTIzewjZnZV0gck/cTMng8zrH6FKL61lZ4pOplMJs3TEWWuRFZdlcy21K2b5469yAxEKa+5vO1HnQVBXS/6aPp6Zu3sNNjmbn9lFua0ueCInQyB5VSwICiJQJ7iH3jbAa+Nk1qZz7nvEwkwZkkH8hT/wGM/+Sy7wfOqk0RbJ5K2rmKAoSgK5ElsY5vqEu4ub3hcdVxdL+Mvo+lSf2Dokt7GNtUl3F33RpctFMbaJ88KT6CeJAI5f+CrVWl3bLo3SltY4QnUk0RqRYo3TRGLKmmJu+6Sbt06/LNHj0o3b7YxOgAhJJ1akdJcwr2Y6jh3rr0e6SqrUfOC+LLjAOKWTCBPTV6q4xvfKJf6mD8BHD+ePVYF/yp1hMkk/2eLjgOIG4E8sFkQfvzx1UvZ8wqMiyeAa9eyx6rgX6WOQM0BGBYCeUDzQbisMjdPnlfUXVKlUEhRERiWZIqdKSgqOC6zWIws6pmfF3v/PIB2JF/sTEHVbW7z0hlleuNj758H0C0CeUDLAuxkIp09uzqdkZe/nkcuG8Ciu/oewJBsbTVf+j77uVnP/H33ZV9fv07/PIB8BPKAFoNw3cC7uUmwBlAeqZXA+lq4xA0ZgPEikCdqcdHQk0+Gv60cgDQQyBOUt2jo7bcP/kwMuxkC6AaBPEFlb4BctR0SQJoI5C1pM2ddNkDTbw6MA4F8TqjgW2Vv8DrKBGj6zYHxIJDvCxl8274DT96iobvvlo4d62fvFDpmgH6x18q+kPeL7OIeo7HcaCPW+38CQ8ReKwVms8miza7qFAyLUh9HjoSbtcZyo41Y7/8JjMmoA3mZbWfrFAyL9ku5dWt4fd5V7kwEoB2jDuSr2vjqFgwX9/s+evTwzwxl1lrlzkQA2jHqQL5s1ti0YDif+ijKi8c6a61SvORuQ0D/Rh3Ii2aNswJnqLxzSrPWVd07i0Fe4m5DQN9Gvfth0bazIWeT06n01luHj8c6a11VvJz/vGZBfnu7emcPgHBGPSNv+96Vs9nttWsHjx87Fu+sdVnxkg4VIE70kbcoZG96V5aN+cqV9vvjARSjj7wHKbbmLSteppTrB8aEQN6iFAPfsnQTHSpAnBoFcjP7ipn9zsx+a2Y/MrN3BhrXIKQa+IpWjbZdUwBQT9MZ+QVJD7v7+yT9XtKXmw9pOIYY+GLZGgDAHY3aD939Z3Nf/lLSvzYbzvBwI2UAbQuZI39S0k+LvmlmZ8xs18x29/b2Ar7s8LAtLIAqVs7IzewFSe/K+dZ5d392/2fOS7opqTDkuPu2pG0paz+sNdoRWNwWdrboRmJmDyBf4z5yM/ukpH+X9EF3L3EnyfH0kdeRYu85gG4U9ZE3ypGb2WlJX5T0j2WDOJZLsfccQL+a5sifknSvpAtm9pKZPR1gTKOWYu85gH417Vr5m1ADQaaLjbwADAsrOyMzxN5zAO0a9Ta2saL3HEAVzMgBIHEEcgBIHIEcABJHIAeAxBHIASBxvdzqzcz2JOUsRK/tuKQ3Az7fUPC55ONzOYzPJF9sn8vE3U8sHuwlkIdmZrt5+w+MHZ9LPj6Xw/hM8qXyuZBaAYDEEcgBIHFDCeTbfQ8gUnwu+fhcDuMzyZfE5zKIHDkAjNlQZuQAMFoEcgBI3GACuZl9xcx+Z2a/NbMfmdk7+x5TDMzsY2b2spndNrPo26jaZGanzew1M3vdzL7U93hiYGbfMrM/m9mlvscSEzN70Mx+bmav7P/9fK7vMS0zmEAu6YKkh939fZJ+L+nLPY8nFpckfVTSi30PpE9mdlTS1yR9WNJDkj5hZg/1O6oofFvS6b4HEaGbkr7g7g9J+jtJn4n5/y+DCeTu/jN3v7n/5S8lPdDneGLh7q+6+2t9jyMCj0h63d3/4O5vS/q+pMd6HlPv3P1FSdf7Hkds3P1/3f03+//9f5JelXR/v6M
qNphAvuBJST/texCIyv2S/jj39VVF/IeJeJjZuqT3S/pVz0MplNQdgszsBUnvyvnWeXd/dv9nziu7LJp2ObY+lflcAFRnZvdI+oGkz7v7X/oeT5GkArm7f2jZ983sk5L+WdIHfUQN8qs+F0iS/iTpwbmvH9g/BuQys7uVBfGpu/+w7/EsM5jUipmdlvRFSf/i7jdW/TxG59eS3mNm7zazd0j6uKQf9zwmRMrMTNI3Jb3q7l/tezyrDCaQS3pK0r2SLpjZS2b2dN8DioGZfcTMrkr6gKSfmNnzfY+pD/uF8M9Kel5Z4eoZd3+531H1z8y+J+kXkt5rZlfN7FN9jykSfy/p3yT90348ecnMHu17UEVYog8AiRvSjBwARolADgCJI5ADQOII5ACQOAI5ACSOQA4AiSOQA0Di/h/iW8RVqG82PAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for idx, label in enumerate(labels):\n", " if label:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"b\") \n", " else:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"r\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Openjijのソルバーを用いたクラスタリング\n", "\n", "次はQUBOの定式化にPyQUBOを用いて、ソルバー部分をOpenJijにしてクラスタリングを行います。" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "def clustering_openjij(df):\n", " # 距離行列の作成\n", " d_ij = distance_matrix(df, df)\n", " # スピン変数の設定\n", " spin = Array.create(\"spin\", shape= len(df), vartype=\"SPIN\")\n", " # 全ハミルトニアンの設定\n", " H = - 0.5* sum(\n", " [d_ij[i,j]* (1 - spin[i]* spin[j]) for i in range(len(df)) for j in range(len(df))]\n", " )\n", " # コンパイル\n", " model = H.compile()\n", " # QUBOに変換\n", " qubo, offset = model.to_qubo()\n", " # OpenJijのSAをサンプラーに設定\n", " sampler = oj.SASampler(num_reads=10, num_sweeps=100)\n", " # サンプラーで解を求める\n", " response = sampler.sample_qubo(qubo)\n", " # 生データの抽出\n", " raw_solution = dict(zip(response.indices, response.states[np.argmin(response.energies)]))\n", " # 解を見やすい形にデコード\n", " decoded_sample = model.decode_sample(raw_solution, vartype=\"SPIN\")\n", " # ラベルを抽出\n", " labels = [int(decoded_sample.array(\"spin\", idx) ) for idx in range(len(df))]\n", " return labels, sum(response.energies)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "実行および解の確認を行います。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label [1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1]\n", "energy 
-147945.96434710253\n" ] } ], "source": [ "labels, energy =clustering_openjij(df1[[\"x\", \"y\"]])\n", "print(\"label\", labels)\n", "print(\"energy\", energy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "こちらも、可視化をしてみましょう。" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVx0lEQVR4nO3dX4glZ5nH8d8znXjRJCCZGRCSzGlhRTYEQaYJK7uwsHoxhmWDsoLSWSIRZp1RUBBEd272pq8Er6KGBkWxD0pAJYKROAEhu6Bij4g7SYwEyYwjC3ZmLtwwQpiZZy9qjnP6dNU59eetqvet+n7gMOnq7nPec7Sfeut5nvctc3cBANJ1pO8BAACaIZADQOII5ACQOAI5ACSOQA4Aiburjxc9duyYb2xs9PHSAJCsCxcuvOHuxxeP9xLINzY2tLe318dLA0CyzOxS3nFSKwCQOAI5ACSOQA4AiSOQA0DiCOQAkDgCOYBOTafSxoZ05Ej273Ta94jS10v7IYBxmk6l06el69ezry9dyr6WpK2t/saVOmbkADpz7tydID5z/Xp2HPURyAF05vLlasdRDoEcQGdOnKh2HOUQyAF0ZntbWl8/eGx9PTuO+gjkADqztSXt7EiTiWSW/buzQ6GzKbpWAHRqa4vAHRozcgBIHIEcABJHIAeAxBHIASBxBHIASByBHAASRyAHgMQRyAEgcQRyAEgcgRwAEkcgB4DEEcgBIHEEcgBIHIEcABJHIAeAxBHIASBxBHIASByBHAASRyAHgMQ1DuRm9qCZ/dTMXjazl8zsMyEGBqCa6VTa2JCOHMn+nU77HlE9Q3kfXQpx8+Ubkj7n7r8ys3slXTCz8+7+coDnBlDCdCqdPi1dv559felS9rWU1o2Oh/I+umbuHvYJzZ6V9JS7ny/6mc3NTd/b2wv6usCYbWxkQW/RZCK9/nrXo6lvKO+jLWZ2wd03F48HzZGb2Yak90r6Rc73TpvZnpnt7e/vh3xZYPQuX652vIk2Ux9dvo8hCRbIzeweSd+T9Fl3//Pi9919x9033X3z+PHjoV4WgKQTJ6odr2uW+rh0SXK/k/oIFcy7eh9DEySQm9ndyoL41N2/H+I5AZS3vS2trx88tr6eHQ/p3Lk7+euZ69ez4yF09T6GJkTXikn6uqRX3P3LzYcEoKqtLWlnJ8slm2X/7uyELxC2nfro6n0MTeNip5n9g6T/kvQ/km7dPvwf7v5c0e9Q7ATSRDGyX0XFzsbth+7+35Ks6fMAiN/29sH2QInURwxY2QmgNFIfcQqxIAjAiGxtEbhjw4wcyWDpNpCPQI4ktN2/HBtOWqiCQI4ktN2/HJOxnbTQHIEcSRjT0u0xnbQQBoEcSRjT0u0xnbQQBoEcSRjT0u0xnbQQBoEcSRhT//KYTloIgz5yJGMs/cuz93juXJZOOXEiC+JjeO+oh0AORGgsJy2EQWoFQJLq9NoPtT+fGTmA5NS5t+eQ7wca/J6dZbCNLYAm6mynO4QteDu5ZycAdKFOr/2Q+/MJ5EBDQ827xqxOr/2Q+/MJ5EAD7IvSjzq99kPuzyeQAw2wL0o/6iwQG/KiMoqdQANHjmQz8UVm0q1bh48DTVDsBFow5Lwr0kEgBxqIKe+6qugaS1E2lnEMirt3/jh58qQDQ
7G76z6ZuJtl/+7udv98u7vu6+vuWaIne6yv3/ndVd/vSizjSJWkPc+JqeTIgYgsrj6Ushn+qqLcqsUusSyGiWUcqSJHDiSgbhfMqsUuZRfDtJ32GPKinD4RyIGI1A10q4quZYqyXfTEUxxuB4EcoxJ7oa1uoFtVdC1TlO2iJz6m4vCg5CXO235Q7EQfUii0NRnjqiLpqu+bHXzd2cMszHsrOw4UU0Gxk0CO0ZhM8gPVZNL3yA7qK9CF/ny6fB9jOTkQyDF6Xc04UxXyiqXLq58UrrRCKQrk5MgxGn0W2mLPzUth9yLpcg8a9rthrxWMSN0e7VRft09d7kEzpv1u6CPH6PW1+11sM8Y2rw5mz100P2zj6oeWRgI5RmZrK1tBeOtW9m8XM+KYFsG02Ss+/9x52mozzGtplKQ334wzhdWGIIHczL5hZn8ys4shng8YkphmjEVXB0880XyGnvfcM21e/cyutI4ePXj86tXx3OQj1Iz8m5JOBXouIJgYiowxLYIpugq4ebP5DL3ouc3av/rZ2pLuuefw8bEUPYMEcnd/UdK1EM8FhBLLbdiq5ubbPPmUuQqoG/yqXHnUfY/Lfi+mFFbn8noS6zwkbUi6uOT7pyXtSdo7ceJE+w2XGL2qC1xiWFSS1xM9638PtUXu4vOH6q0v289dt+971e+lsuCrCbW9IGhVIJ9/sCAIXaiyACiWRSVFwSjkmOZPWGtrYYNfmZNh3YC76vfaPgnGgECO0akSMGKZzRWdfNoaUx8nsLorbMv83uxEMh/Eh7TasyiQ036IwapSZIwlv1omhx1yTH301tft4inze7P20snkcC/7kAufodoPvyPpZ5LebWZXzOwTIZ4XaKJKkIqlRbCoJ3pe6DF13Vtft4snxRNzZ/Km6W0/SK0gNrHkyGdjSTU9ULZgXLewXPb3YkmVhSZ2PwSW67prpczrxdBJU1YbJ8Oq7z/lk2AZBHKMVozBsK0rgD7faxv7mVf5jGLvWgnxvw2BHKMUU8pkXhuX/n2/19D7vVf9jGJOp4T636YokLONLQZtYyN/E6fJJCvs9aWNrVdXvdfpNOvauHxZuu++7HvXrmXF0+3t5kXO0J911c8o5u1sQ302bGOLUarSvdDlvixtdMkse6+L2xVcvZo93MNtXRB6T5mqn1EsnUd5Wu+iyZumt/0gtYKulL3c7jot0cbrLXuvq1aMhkpBhMzRh8iRx5BGcw+X9hE5coxR2T/uPvKroQuTy97rqiDeJJfdprpdKzEVtt3bz5ETyDF4Zf64h3Jj5qL3WrSnStWTVqyBMgVtdq1Q7AQUb1E0FLPl3y9zD9Ex3ns0NhQ7gSViuvnDYtH17NnmRdjJZPn3ygTj2O49ijl50/S2H6RWEKMY0gZl9guvk1sNkaMdSvopZSK1AsSvKMWzqE7KZ76PvE7v+NDTTykgtQIkoGxfcZ3+46a7HMaUfsJBBHIgImUXr/SxyKWPvctRDoEciEiZ/cj7nAUXzeq7XBWLwwjkQETyZr1nzsQ9C15c/h9qyT/Ko9gJoBGKoN2h2AmgFaO7rVqECOQAGol518GxIJADaIS2xP4RyIGBa7ujhLbE/t3V9wAAtGdxo6tZR4kUNtBubRG4+8SMHBgwNroaBwI5MGB0lIwDgRwYMDpKxoFADgwYHSXjQCAHBoyOknGgawUYODpKho8ZOTAg7EI4TszIgYHoqmcc8WFGDgwEPePjRSAHItE0LULP+HgRyIEIhLg5Az3j4xUkkJvZKTN71cxeM7MvhHhOYExCpEXoGR+vxoHczNYkfUXSByU9JOljZvZQ0+cFxiREWoSe8fEK0bXyiKTX3P33kmRm35X0mKSXAzw3MAonTuTfLq1qWoSe8XEKkVq5X9If5r6+cvvYAWZ22sz2zGxvf38/wMsCw0FaBE10Vux09x1333T3zePHj3f1skASSIugiRCplT9KenDu6wduHwNQAWkR1BViRv5LSe8ys3ea2dskfVTSDwM8L
wCghMaB3N1vSPq0pOclvSLpGXd/qenzAl1gbxIMQZC9Vtz9OUnPhXguoCvsTYKhYGUnRou9STAUBHKMFnuTYCgI5Bgt9ibBUBDIQ6JylhQW4WAoCOShhNi+Dp1iEQ6GgkAeSoqVM64gtLUlvf66dOtW9i9BHCniVm+hpFY5o/cOGAxm5DNNZ6epVc5SvIIAkItALoXJb6dWOUvtCgJAIQK5FGZ2mlrlLLUrCACFCORSuNlpSpWz1K4gQqHAiwEikEvVZqexBoKq40rtCiIEWkQxVO7e+ePkyZMeld1d9/V19+zPO3usr2fH6/xc12IdV2wmk4Of0ewxmfQ9MqAUSXueE1OZkUvlZ6dtdHqEmOHTgVIOBV4MlGVBvlubm5u+t7fX+es2duRINodbZJblxata7OWWsjx11RRH6HEN1cZG/h2OJ5OspgFEzswuuPvm4nFm5FVU7fRYNdsONZOmA6WcsRZ4MXgE8iqqBIIyhbWiS/pLl6qlW6qOK8ZibRfGWODFOOQlztt+RFfsrGJ3NyuOmWX/FhUUyxTWin7GrHrhssy4ui6Klv2sAJSigmInOfK2lMlb5+XIzfJ/L0Qet8sc8dmz0tNPH3wvdfL/AP6KHHnXyuSt8y71i06sIToruuramE4PB3GJThqgJQTytpTNWy+uBp1M8p8vROGyq6LouXPtnpC6Nua6ApJAIG9L3cLaqhNAUVApE2y66tpYFqxT66RhNShSkJc4b/uRdLGzC0VFwqJi5Zkz5YuYXRQglxVxUyt4shoUERHFzgEoKlaurUk3bx4+Pl/EnE6zlMfly9mseHu7vaJjURH3k5+UvvrVdl6zLSy2QkQodg5BUcoiL4jP/3xeeuDxx6Vjx9pJEeSllb797fSCuMRiKySBQJ6SouCxtrb85/NWkErS1avt5XtT2tJ3GVaDIgEE8ja01eVQFFROn14ebJYVH4fQEthmVwmrQZGCvMR5249BFzvbXD25u+t+9Oid5z169GAhtKiIWVSwmy9CpootfDEiYhvbjrS1pewsz3316p1jf/nLnf+eT2Vsb2evN5uhPvro4Rn7vJTzvWzhC9C1ElxbXQ5ll9cXbY37xBPSM88cPBHMvpdyqoCuEowIXStdaavLoezy+qIZ6nPPSW+8Ie3uxpnvrZvnpqsEIJAH11aXQ9mAtSrgt9VN0qTg2GT1JF0lAMXOVrSxerJsUa+PlYhNC45Nx8x2uRgJtbGy08w+Iuk/Jf2tpEfcvVTie9A58jYtrs589NEsZTK/WlMKc/u4Kppuj0ueGyilrRz5RUkflvRiw+dBGYudKd/61uF0hNR933NeEF92fBF5bqCRRoHc3V9x91dDDSYqsW9duqztrstVldNpdsLIY0aeG+hAZ8VOMzttZntmtre/v9/Vy9YT+9al02nxbLfr/b6X7T3uXq6fm9WTQCMrA7mZvWBmF3Mej1V5IXffcfdNd988fvx4/RF3IeZFJrOTTJGm6YgyVyLzP7MqfVL2xDKUvVmAHqwM5O7+AXd/OOfxbBcD7EWoW6K1kZ4p2gBLap6OKHMlsvgzq5DnBlpHH3meEMW3ttIzy04mTdMRZa5Elp1IFpHnBjrRKJCb2YfM7Iqk90n6kZk9H2ZYPQtRfGsrPVN0MplMmqcjylyJrLoqmW2pWzfPHXuRGYhRXnN5249aC4K6XvTR9PXM2tlpsM3d/soszGlzwRE7GQJLqWBBUBqBPMU/8LYDXhsntTKfc98nEmDE0g7kKf6Bx37yWXaD51UnibZOJG1dxQADURTI09jGNtUl3F3e8LjquLpexl9G06X+wMClvY1tqku4u+6NLlsojLVPnhWeQC1pBHL+wFer0u7YdG+UtrDCE6gljdSKFG+aIhZV0hJ33SXdvHn4Z9fWpBs32hgdgADSTq1IaS7hXkx1nD3bXo90ldWoeUF82XEAUUsnkKcmL9Xxta+VS33MnwCOHcseq4J/lTrCZJL/s0XHAUSNQB7aLAg//vjqp
ex5BcbFE8DVq9ljVfCvUkeg5gAMCoE8pPkgXFaZmyfPK+ouqVIopKgIDEo6xc4UFBUcl1ksRhb1zM+LvX8eQCvSL3amoOo2t3npjDK98bH3zwPoFIE8pGUBdjKRzpxZnc7Iy1/PI5cNYMFdfQ9gULa3my99n/3crGf+vvuyr69do38eQC4CeUiLQbhu4N3aIlgDKI3USmh9LVzihgzAaBHIU7W4aOjJJ8PfVg5AEgjkKcpbNPTWWwd/JobdDAF0gkCeorI3QK7aDgkgSQTytrSZsy4boOk3B0aBQD4vVPCtsjd4HWUCNP3mwGgQyGdCBt+278CTt2jo7rulo0f72TuFjhmgV+y1MhPyfpFd3GM0lhttxHr/T2CA2GulyGw2WbTZVZ2CYVHq48iRcLPWWG60Eev9P4ERGXcgL7PtbJ2CYdF+KTdvDq/Pu8qdiQC0YtyBfFUbX92C4eJ+32trh39mKLPWKncmAtCKcQfyZbPGpgXD+dRHUV481llrleIldxsCejfuQF40a5wVOEPlnVOata7q3lkM8hJ3GwJ6Nu7dD4u2nQ05m5xOpTffPHw81lnrquLl/Oc1C/I7O9U7ewAEM+4Zedv3rpzNbq9ePXj86NF4Z63Lipd0qABRoo+8TSF707uybMyXL7ffHw+gEH3kfUixNW9Z8TKlXD8wIgTyNqUY+Jalm+hQAaLUKJCb2ZfM7Ldm9hsz+4GZvT3QuIYh1cBXtGq07ZoCgFqazsjPS3rY3d8j6XeSvth8SAMyxMAXy9YAAP6qUfuhu/9k7sufS/rXZsMZIG6kDKBlIXPkT0r6cdE3zey0me2Z2d7+/n7Alx0gtoUFUMHKGbmZvSDpHTnfOufuz97+mXOSbkgqjDjuviNpR8raD2uNdgwWt4WdLbqRmNkDyNW4j9zMPi7p3yW9391L3EhyRH3kdaTYew6gE0V95I1y5GZ2StLnJf1j2SCOFVLsPQfQq6Y58qck3SvpvJn92syeDjCmcUux9xxAr5p2rfxNqIHgti428gIwKKzsjM0Qe88BtGrc29jGit5zABUwIweAxBHIASBxBHIASByBHAASRyAHgMT1cqs3M9uXlLMOvbZjkt4I+HxDweeSj8/lMD6TfLF9LhN3P754sJdAHpqZ7eXtPzB2fC75+FwO4zPJl8rnQmoFABJHIAeAxA0lkO/0PYBI8bnk43M5jM8kXxKfyyBy5AAwZkOZkQPAaBHIASBxgwnkZvYlM/utmf3GzH5gZm/ve0wxMLOPmNlLZnbLzKJvo2qTmZ0ys1fN7DUz+0Lf44mBmX3DzP5kZhf7HktMzOxBM/upmb18++/nM32PaZnBBHJJ5yU97O7vkfQ7SV/seTyxuCjpw5Je7HsgfTKzNUlfkfRBSQ9J+piZPdTvqKLwTUmn+h5EhG5I+py7PyTp7yR9Kub/vwwmkLv7T9z9xu0vfy7pgT7HEwt3f8XdX+17HBF4RNJr7v57d39L0nclPdbzmHrn7i9Kutb3OGLj7v/r7r+6/d//J+kVSff3O6pigwnkC56U9OO+B4Go3C/pD3NfX1HEf5iIh5ltSHqvpF/0PJRCSd0hyMxekPSOnG+dc/dnb//MOWWXRdMux9anMp8LgOrM7B5J35P0WXf/c9/jKZJUIHf3Dyz7vpl9XNI/S3q/j6hBftXnAknSHyU9OPf1A7ePAbnM7G5lQXzq7t/vezzLDCa1YmanJH1e0r+4+/VVP4/R+aWkd5nZO83sbZI+KumHPY8JkTIzk/R1Sa+4+5f7Hs8qgwnkkp6SdK+k82b2azN7uu8BxcDMPmRmVyS9T9KPzOz5vsfUh9uF8E9Lel5Z4eoZd3+p31H1z8y+I+lnkt5tZlfM7BN9jykSfy/p3yT90+148msze7TvQRVhiT4AJG5IM3IAGCUCOQAkjkAOAIkjkANA4gjkAJA4AjkAJI5ADgCJ+39i6MRVJzJlmwAAAABJRU5ErkJggg==\n", "text/plain": [ 
"
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for idx, label in enumerate(labels):\n", " if label:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"b\") \n", " else:\n", " plt.scatter(df1.loc[idx][\"x\"], df1.loc[idx][\"y\"], color=\"r\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyQUBOのSAのときと同様の結果を得ることができました。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## QBoost\n", "\n", "QBoostは量子アニーリングを用いたアンサンブル学習の一つです。アンサンブル学習は弱い予測器を多数用意して、その予測器の各予測結果の組み合わせて最終的な予測結果を得ます。 \n", "QBoostでは、与えられた学習データに対して最適な学習器の組み合わせを、量子アニーリングを用いて最適化を行います。今回は分類問題を扱います。 \n", "\n", "$D$個の学習データの集合を$\\{\\vec x^{(d)}\\}(d=1, ..., D)$、対応するラベルを$\\{y^{(d)}\\}(d=1, ..., D), y^{(d)}\\in \\{-1, 1\\}$とします。また、$N$個の弱学習器の(関数の)集合を$\\{C_i\\}(i=1, ..., N)$とします。あるデータ$\\vec x^{(d)}$ に対して、$C_i(\\vec x^{(d)})\\in \\{-1, 1\\}$です。このとき、最終的な分類のラベルは以下のようになります。\n", "\n", "$${\\rm sgn}\\left( \\sum_{i=1}^{N} w_i C_i({\\vec x}^{(d)})\\right)$$\n", "\n", "ここで$w_i\\in\\{0, 1\\} (i=1, ..., N)$とします。これは各予測器の重み(予測器を最終的な予測に採用するか採用しないかのbool値)です。 \n", "QBoostではこの$w_i$を、弱学習器の個数を刈り込みつつ最終的な予測が教師データに一致するように組み合わせの最適化を行います。\n", "この問題におけるハミルトニアンは、以下のようになります。\n", "\n", "$$H(\\vec w) = \\sum_{d=1}^{D} \\left( \\frac{1}{N}\\sum_{i=1}^{N} w_i C_i(\\vec x^{(d)})-y^{(d)} \\right)^2 + \\lambda \\sum _i^N w_i$$\n", "\n", "このハミルトニアンの第一項目は、弱分類器と正解ラベルの差を表しています。第二項目は最終的な分類器に採用する弱分類器の個数の程度を表しています。$\\lambda$は第二項の弱分類器の個数がハミルトニアンにどのくらい影響するかを調節する正則化パラメータにです。 \n", "第一項をコスト(目的関数)、第二項を制約として、このハミルトニアンの最適化を行います。量子アニーリングで最小化することにより、学習データに最も適合するような弱分類器の組み合わせを得ることができます。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### スクリプト\n", "\n", "それでは実際にQBoostを試してみましょう。学習データにはscikit-learnの癌識別のデータセットを使用します。簡単のために、学習に用いるのは\"0\"と\"1\"の2つの文字種のみとします。" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# ライブラリをインポート\n", "import pandas as pd \n", "from scipy import stats \n", "from sklearn import datasets\n", "from 
sklearn import metrics" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# Load the data\n", "cancerdata = datasets.load_breast_cancer()\n", "# Number of training samples (the remaining samples are used for validation)\n", "num_train = 450" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration purposes, we consider the case where the data contains noise features: 30 columns of uniform random numbers are appended to the 30 original features." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(569, 60)\n" ] } ], "source": [ "data_noisy = np.concatenate((cancerdata.data, np.random.rand(cancerdata.data.shape[0], 30)), axis=1)\n", "print(data_noisy.shape)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Convert the labels from {0, 1} to {-1, 1}\n", "labels = (cancerdata.target-0.5) * 2" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# Split the dataset into training and validation sets\n", "X_train = data_noisy[:num_train, :]\n", "X_test = data_noisy[num_train:, :]\n", "y_train = labels[:num_train]\n", "y_test = labels[num_train:]" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Aggregate the weak learners' predictions: +1 if their mean is positive, otherwise -1\n", "def aggre_mean(Y_list):\n", "    return ((np.mean(Y_list, axis=0)>0)-0.5) * 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Building the weak learners\n", "\n", "We build the weak classifiers with scikit-learn. Here we use decision stumps, that is, decision trees of depth one. Because they serve as weak classifiers, the feature used for the split is chosen at random (`splitter=\"random\"`). You can think of this as a random forest made of depth-one trees." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Import the required library\n", "from sklearn.tree import DecisionTreeClassifier as DTC\n", "\n", "# Number of weak classifiers\n", "num_clf = 32\n", "# Number of samples drawn for each weak classifier in bootstrap sampling\n", "sample_train = 40\n", "# Set up the models\n", "models = [DTC(splitter=\"random\",max_depth=1) for i in range(num_clf)]\n", "for model in models:\n", "    # Draw training samples at random (with replacement)\n", "    train_idx = np.random.choice(np.arange(X_train.shape[0]), sample_train)\n", "    # 
Fit the decision stump to the features and labels\n", "    model.fit(X=X_train[train_idx], y=y_train[train_idx])\n", "y_pred_list_train = []\n", "for model in models:\n", "    # Predict with the fitted model\n", "    y_pred_list_train.append(model.predict(X_train))\n", "y_pred_list_train = np.asanyarray(y_pred_list_train)\n", "# Majority vote over all weak classifiers\n", "y_pred_train = np.sign(np.sum(y_pred_list_train, axis=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us look at the accuracy when all weak learners are used in the final classifier. From here on, we call this combination the baseline." ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9327731092436975\n" ] } ], "source": [ "y_pred_list_test = []\n", "for model in models:\n", "    # Run on the validation data\n", "    y_pred_list_test.append(model.predict(X_test))\n", "\n", "y_pred_list_test = np.array(y_pred_list_test)\n", "y_pred_test = np.sign(np.sum(y_pred_list_test,axis=0))\n", "# Compute the accuracy of the predictions\n", "acc_test_base = metrics.accuracy_score(y_true=y_test, y_pred=y_pred_test)\n", "print(acc_test_base)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Define a class that performs QBoost\n", "class QBoost():\n", "    def __init__(self, y_train, ys_pred):\n", "        self.num_clf = ys_pred.shape[0]\n", "        # Define the binary variables\n", "        self.Ws = Array.create(\"weight\", shape = self.num_clf, vartype=\"BINARY\")\n", "        # Define the regularization strength (hyperparameter) as a PyQUBO Placeholder\n", "        self.param_lamda = Placeholder(\"norm\")\n", "        # Hamiltonian for the combination of weak classifiers\n", "        self.H_clf = sum( [ (1/self.num_clf * sum([W*C for W, C in zip(self.Ws, y_clf)])- y_true)**2 for y_true, y_clf in zip(y_train, ys_pred.T)\n", "                           ])\n", "        # Regularization term, registered as a constraint\n", "        self.H_norm = Constraint(sum([W for W in self.Ws]), \"norm\")\n", "        # Total Hamiltonian\n", "        self.H = self.H_clf + self.H_norm * self.param_lamda\n", "        # Compile the model\n", "        self.model = self.H.compile()\n", "    # Convert the model to a QUBO\n", "    def to_qubo(self, norm_param=1):\n", "        # Set the value of the hyperparameter\n", "        self.feed_dict = {'norm': norm_param}\n", "        return self.model.to_qubo(feed_dict=self.feed_dict)" ] 
}, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "qboost = QBoost(y_train=y_train, ys_pred=y_pred_list_train)\n", "# Build the QUBO with lambda=3\n", "qubo = qboost.to_qubo(3)[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Running QBoost with the D-Wave sampler" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Import the required libraries\n", "from dwave.system.samplers import DWaveSampler\n", "from dwave.system.composites import EmbeddingComposite" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dw = DWaveSampler(endpoint='https://cloud.dwavesys.com/sapi/',\n", "                  token='xxxx',\n", "                  solver='DW_2000Q_VFYC_6')\n", "sampler = EmbeddingComposite(dw)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sample with the D-Wave sampler\n", "sampleset = sampler.sample_qubo(qubo, num_reads=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Inspect the results\n", "print(sampleset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Decode each sample with PyQUBO\n", "decoded_solutions = []\n", "brokens = []\n", "energies = []\n", "\n", "decoded_sol = qboost.model.decode_dimod_response(sampleset, feed_dict=qboost.feed_dict)\n", "for d_sol, broken, energy in decoded_sol:\n", "    decoded_solutions.append(d_sol)\n", "    brokens.append(broken)\n", "    energies.append(energy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us check the accuracy on the training and validation data when using the combinations of weak classifiers obtained by D-Wave." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "accs_train_Dwaves = []\n", "accs_test_Dwaves = []\n", "for decoded_solution in decoded_solutions:\n", "    idx_clf_DWave=[]\n", "    for key, val in decoded_solution[\"weight\"].items():\n", "        if val == 1:\n", "            idx_clf_DWave.append(int(key))\n", "    y_pred_train_DWave = 
np.sign(np.sum(y_pred_list_train[idx_clf_DWave, :], axis=0))\n", "    y_pred_test_DWave = np.sign(np.sum(y_pred_list_test[idx_clf_DWave, :], axis=0))\n", "    acc_train_DWave = metrics.accuracy_score(y_true=y_train, y_pred=y_pred_train_DWave)\n", "    acc_test_DWave = metrics.accuracy_score(y_true=y_test, y_pred=y_pred_test_DWave)\n", "    accs_train_Dwaves.append(acc_train_DWave)\n", "    accs_test_Dwaves.append(acc_test_DWave)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the accuracy (vertical axis) against the energy (horizontal axis)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(12, 8))\n", "plt.scatter(energies, accs_train_Dwaves, label=\"train\" )\n", "plt.scatter(energies, accs_test_Dwaves, label=\"test\")\n", "plt.xlabel(\"energy: D-wave\")\n", "plt.ylabel(\"accuracy\")\n", "plt.title(\"relationship between energy and accuracy\")\n", "plt.grid()\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"base accuracy is {}\".format(acc_test_base))\n", "print(\"max accuracy of QBoost is {}\".format(max(accs_test_Dwaves)))\n", "print(\"average accuracy of QBoost is {}\".format(np.mean(np.asarray(accs_test_Dwaves))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "D-Wave can draw hundreds of samples or more in a short time. By using the sample that gives the highest accuracy, it is possible to build a classifier that is more accurate than the baseline." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.3" } }, "nbformat": 4, "nbformat_minor": 2 }