diff --git a/Reconstruct_and_RoBERTa_baseline_train_dev_dataset.ipynb b/Reconstruct_and_RoBERTa_baseline_train_dev_dataset.ipynb
index 44057e8436ffecffd1e2969fe64ae221caba5534..483eb658a3af375d40a990a937316c477be2c2a4 100644
--- a/Reconstruct_and_RoBERTa_baseline_train_dev_dataset.ipynb
+++ b/Reconstruct_and_RoBERTa_baseline_train_dev_dataset.ipynb
@@ -27,7 +27,33 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
+      "Collecting tensorboardx\n",
+      "  Downloading tensorboardX-2.5-py2.py3-none-any.whl (125 kB)\n",
+      "\u001b[K     |████████████████████████████████| 125 kB 15.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from tensorboardx) (1.21.2)\n",
+      "Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from tensorboardx) (1.16.0)\n",
+      "Requirement already satisfied: protobuf>=3.8.0 in /opt/conda/lib/python3.8/site-packages (from tensorboardx) (3.18.1)\n",
+      "Installing collected packages: tensorboardx\n",
+      "Successfully installed tensorboardx-2.5\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install tensorboardx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/",
@@ -36,15 +62,331 @@
     "id": "hYhFR7nSYOjG",
     "outputId": "23ed0686-29d3-45ff-dc22-b2fe54e86ec4"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
+      "Collecting simpletransformers\n",
+      "  Downloading simpletransformers-0.63.4-py3-none-any.whl (248 kB)\n",
+      "\u001b[K     |████████████████████████████████| 248 kB 26.8 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting tensorflow\n",
+      "  Downloading tensorflow-2.8.0-cp38-cp38-manylinux2010_x86_64.whl (497.6 MB)\n",
+      "\u001b[K     |██████████████████████▍         | 348.8 MB 92.8 MB/s eta 0:00:02"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "IOPub data rate exceeded.\n",
+      "The notebook server will temporarily stop sending output\n",
+      "to the client in order to avoid crashing it.\n",
+      "To change this limit, set the config variable\n",
+      "`--NotebookApp.iopub_data_rate_limit`.\n",
+      "\n",
+      "Current values:\n",
+      "NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
+      "NotebookApp.rate_limit_window=3.0 (secs)\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[K     |████████████████████████████████| 497.6 MB 80.8 MB/s eta 0:00:011\n",
+      "\u001b[?25hCollecting sentencepiece\n",
+      "  Downloading sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n",
+      "\u001b[K     |████████████████████████████████| 1.2 MB 82.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: tqdm>=4.47.0 in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (4.62.3)\n",
+      "Collecting transformers>=4.6.0\n",
+      "  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)\n",
+      "\u001b[K     |████████████████████████████████| 3.5 MB 50.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting tokenizers\n",
+      "  Downloading tokenizers-0.11.6-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)\n",
+      "\u001b[K     |████████████████████████████████| 6.5 MB 38.2 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting streamlit\n",
+      "  Downloading streamlit-1.6.0-py2.py3-none-any.whl (9.7 MB)\n",
+      "\u001b[K     |████████████████████████████████| 9.7 MB 33.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: regex in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (2021.10.8)\n",
+      "Requirement already satisfied: scipy in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (1.6.3)\n",
+      "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (1.0)\n",
+      "Collecting wandb>=0.10.32\n",
+      "  Downloading wandb-0.12.11-py2.py3-none-any.whl (1.7 MB)\n",
+      "\u001b[K     |████████████████████████████████| 1.7 MB 72.7 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (2.26.0)\n",
+      "Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (1.21.2)\n",
+      "Collecting datasets\n",
+      "  Downloading datasets-1.18.3-py3-none-any.whl (311 kB)\n",
+      "\u001b[K     |████████████████████████████████| 311 kB 74.4 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting seqeval\n",
+      "  Downloading seqeval-1.2.2.tar.gz (43 kB)\n",
+      "\u001b[K     |████████████████████████████████| 43 kB 50.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting pandas\n",
+      "  Downloading pandas-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.7 MB)\n",
+      "\u001b[K     |████████████████████████████████| 11.7 MB 66.2 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: tensorboard in /opt/conda/lib/python3.8/site-packages (from simpletransformers) (2.6.0)\n",
+      "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.41.0)\n",
+      "Requirement already satisfied: absl-py>=0.4.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (0.14.1)\n",
+      "Collecting flatbuffers>=1.12\n",
+      "  Downloading flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n",
+      "Collecting tensorflow-io-gcs-filesystem>=0.23.1\n",
+      "  Downloading tensorflow_io_gcs_filesystem-0.24.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.1 MB)\n",
+      "\u001b[K     |████████████████████████████████| 2.1 MB 73.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting keras<2.9,>=2.8.0rc0\n",
+      "  Downloading keras-2.8.0-py2.py3-none-any.whl (1.4 MB)\n",
+      "\u001b[K     |████████████████████████████████| 1.4 MB 17.6 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting wrapt>=1.11.0\n",
+      "  Downloading wrapt-1.13.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (84 kB)\n",
+      "\u001b[K     |████████████████████████████████| 84 kB 52.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting astunparse>=1.6.0\n",
+      "  Downloading astunparse-1.6.3-py2.py3-none-any.whl (12 kB)\n",
+      "Collecting libclang>=9.0.1\n",
+      "  Downloading libclang-13.0.0-py2.py3-none-manylinux1_x86_64.whl (14.5 MB)\n",
+      "\u001b[K     |████████████████████████████████| 14.5 MB 72.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting tf-estimator-nightly==2.8.0.dev2021122109\n",
+      "  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)\n",
+      "\u001b[K     |████████████████████████████████| 462 kB 76.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: setuptools in /opt/conda/lib/python3.8/site-packages (from tensorflow) (58.2.0)\n",
+      "Requirement already satisfied: six>=1.12.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.16.0)\n",
+      "Collecting keras-preprocessing>=1.1.1\n",
+      "  Downloading Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)\n",
+      "\u001b[K     |████████████████████████████████| 42 kB 41.2 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting gast>=0.2.1\n",
+      "  Downloading gast-0.5.3-py3-none-any.whl (19 kB)\n",
+      "Requirement already satisfied: typing-extensions>=3.6.6 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (3.10.0.2)\n",
+      "Collecting termcolor>=1.1.0\n",
+      "  Downloading termcolor-1.1.0.tar.gz (3.9 kB)\n",
+      "Collecting h5py>=2.9.0\n",
+      "  Downloading h5py-3.6.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)\n",
+      "\u001b[K     |████████████████████████████████| 4.5 MB 65.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting google-pasta>=0.1.1\n",
+      "  Downloading google_pasta-0.2.0-py3-none-any.whl (57 kB)\n",
+      "\u001b[K     |████████████████████████████████| 57 kB 64.7 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting opt-einsum>=2.3.2\n",
+      "  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n",
+      "\u001b[K     |████████████████████████████████| 65 kB 67.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting tensorboard\n",
+      "  Downloading tensorboard-2.8.0-py3-none-any.whl (5.8 MB)\n",
+      "\u001b[K     |████████████████████████████████| 5.8 MB 74.3 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: protobuf>=3.9.2 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (3.18.1)\n",
+      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /opt/conda/lib/python3.8/site-packages (from astunparse>=1.6.0->tensorflow) (0.37.0)\n",
+      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (0.6.1)\n",
+      "Requirement already satisfied: google-auth<3,>=1.6.3 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (1.35.0)\n",
+      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (1.8.0)\n",
+      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (0.4.6)\n",
+      "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (3.3.4)\n",
+      "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.8/site-packages (from tensorboard->simpletransformers) (2.0.2)\n",
+      "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard->simpletransformers) (4.7.2)\n",
+      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard->simpletransformers) (4.2.4)\n",
+      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard->simpletransformers) (0.2.8)\n",
+      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard->simpletransformers) (1.3.0)\n",
+      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard->simpletransformers) (0.4.8)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->simpletransformers) (2021.5.30)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->simpletransformers) (2.0.0)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->simpletransformers) (3.1)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->simpletransformers) (1.26.7)\n",
+      "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard->simpletransformers) (3.1.1)\n",
+      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.8/site-packages (from transformers>=4.6.0->simpletransformers) (21.0)\n",
+      "Requirement already satisfied: sacremoses in /opt/conda/lib/python3.8/site-packages (from transformers>=4.6.0->simpletransformers) (0.0.46)\n",
+      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.8/site-packages (from transformers>=4.6.0->simpletransformers) (5.4.1)\n",
+      "Requirement already satisfied: filelock in /opt/conda/lib/python3.8/site-packages (from transformers>=4.6.0->simpletransformers) (3.3.0)\n",
+      "Collecting huggingface-hub<1.0,>=0.1.0\n",
+      "  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)\n",
+      "\u001b[K     |████████████████████████████████| 67 kB 72.6 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging>=20.0->transformers>=4.6.0->simpletransformers) (2.4.7)\n",
+      "Requirement already satisfied: psutil>=5.0.0 in /opt/conda/lib/python3.8/site-packages (from wandb>=0.10.32->simpletransformers) (5.8.0)\n",
+      "Collecting docker-pycreds>=0.4.0\n",
+      "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
+      "Collecting setproctitle\n",
+      "  Downloading setproctitle-1.2.2-cp38-cp38-manylinux1_x86_64.whl (36 kB)\n",
+      "Collecting shortuuid>=0.5.0\n",
+      "  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)\n",
+      "Requirement already satisfied: Click!=8.0.0,>=7.0 in /opt/conda/lib/python3.8/site-packages (from wandb>=0.10.32->simpletransformers) (8.0.1)\n",
+      "Collecting pathtools\n",
+      "  Downloading pathtools-0.1.2.tar.gz (11 kB)\n",
+      "Requirement already satisfied: python-dateutil>=2.6.1 in /opt/conda/lib/python3.8/site-packages (from wandb>=0.10.32->simpletransformers) (2.8.2)\n",
+      "Collecting GitPython>=1.0.0\n",
+      "  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)\n",
+      "\u001b[K     |████████████████████████████████| 181 kB 70.6 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting yaspin>=1.0.0\n",
+      "  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)\n",
+      "Collecting sentry-sdk>=1.0.0\n",
+      "  Downloading sentry_sdk-1.5.6-py2.py3-none-any.whl (144 kB)\n",
+      "\u001b[K     |████████████████████████████████| 144 kB 83.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting promise<3,>=2.0\n",
+      "  Downloading promise-2.3.tar.gz (19 kB)\n",
+      "Collecting gitdb<5,>=4.0.1\n",
+      "  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n",
+      "\u001b[K     |████████████████████████████████| 63 kB 47.6 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting smmap<6,>=3.0.1\n",
+      "  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n",
+      "Collecting xxhash\n",
+      "  Downloading xxhash-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n",
+      "\u001b[K     |████████████████████████████████| 212 kB 81.2 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting multiprocess\n",
+      "  Downloading multiprocess-0.70.12.2-py38-none-any.whl (128 kB)\n",
+      "\u001b[K     |████████████████████████████████| 128 kB 72.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting pyarrow!=4.0.0,>=3.0.0\n",
+      "  Downloading pyarrow-7.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)\n",
+      "\u001b[K     |████████████████████████████████| 26.7 MB 18.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting fsspec[http]>=2021.05.0\n",
+      "  Downloading fsspec-2022.2.0-py3-none-any.whl (134 kB)\n",
+      "\u001b[K     |████████████████████████████████| 134 kB 85.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting dill\n",
+      "  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)\n",
+      "\u001b[K     |████████████████████████████████| 86 kB 74.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting aiohttp\n",
+      "  Downloading aiohttp-3.8.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.3 MB)\n",
+      "\u001b[K     |████████████████████████████████| 1.3 MB 76.6 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting aiosignal>=1.1.2\n",
+      "  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n",
+      "Collecting async-timeout<5.0,>=4.0.0a3\n",
+      "  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
+      "Collecting frozenlist>=1.1.1\n",
+      "  Downloading frozenlist-1.3.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n",
+      "\u001b[K     |████████████████████████████████| 158 kB 81.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->datasets->simpletransformers) (21.2.0)\n",
+      "Collecting yarl<2.0,>=1.0\n",
+      "  Downloading yarl-1.7.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (308 kB)\n",
+      "\u001b[K     |████████████████████████████████| 308 kB 81.1 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting multidict<7.0,>=4.5\n",
+      "  Downloading multidict-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (121 kB)\n",
+      "\u001b[K     |████████████████████████████████| 121 kB 81.8 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.8/site-packages (from pandas->simpletransformers) (2021.3)\n",
+      "Requirement already satisfied: joblib in /opt/conda/lib/python3.8/site-packages (from sacremoses->transformers>=4.6.0->simpletransformers) (1.1.0)\n",
+      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->simpletransformers) (3.0.0)\n",
+      "Collecting watchdog\n",
+      "  Downloading watchdog-2.1.6-py3-none-manylinux2014_x86_64.whl (76 kB)\n",
+      "\u001b[K     |████████████████████████████████| 76 kB 60.9 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting blinker\n",
+      "  Downloading blinker-1.4.tar.gz (111 kB)\n",
+      "\u001b[K     |████████████████████████████████| 111 kB 76.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting validators\n",
+      "  Downloading validators-0.18.2-py3-none-any.whl (19 kB)\n",
+      "Requirement already satisfied: tornado>=5.0 in /opt/conda/lib/python3.8/site-packages (from streamlit->simpletransformers) (6.1)\n",
+      "Collecting semver\n",
+      "  Downloading semver-2.13.0-py2.py3-none-any.whl (12 kB)\n",
+      "Collecting astor\n",
+      "  Downloading astor-0.8.1-py2.py3-none-any.whl (27 kB)\n",
+      "Collecting tzlocal\n",
+      "  Downloading tzlocal-4.1-py3-none-any.whl (19 kB)\n",
+      "Collecting base58\n",
+      "  Downloading base58-2.1.1-py3-none-any.whl (5.6 kB)\n",
+      "Collecting altair>=3.2.0\n",
+      "  Downloading altair-4.2.0-py3-none-any.whl (812 kB)\n",
+      "\u001b[K     |████████████████████████████████| 812 kB 79.3 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting pympler>=0.9\n",
+      "  Downloading Pympler-1.0.1-py3-none-any.whl (164 kB)\n",
+      "\u001b[K     |████████████████████████████████| 164 kB 81.5 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting pydeck>=0.1.dev5\n",
+      "  Downloading pydeck-0.7.1-py2.py3-none-any.whl (4.3 MB)\n",
+      "\u001b[K     |████████████████████████████████| 4.3 MB 70.7 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting importlib-metadata>=1.4\n",
+      "  Downloading importlib_metadata-4.11.2-py3-none-any.whl (17 kB)\n",
+      "Requirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.8/site-packages (from streamlit->simpletransformers) (8.2.0)\n",
+      "Requirement already satisfied: toml in /opt/conda/lib/python3.8/site-packages (from streamlit->simpletransformers) (0.10.2)\n",
+      "Requirement already satisfied: entrypoints in /opt/conda/lib/python3.8/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (0.3)\n",
+      "Collecting toolz\n",
+      "  Downloading toolz-0.11.2-py3-none-any.whl (55 kB)\n",
+      "\u001b[K     |████████████████████████████████| 55 kB 63.4 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: jsonschema>=3.0 in /opt/conda/lib/python3.8/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (4.0.1)\n",
+      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from altair>=3.2.0->streamlit->simpletransformers) (3.0.1)\n",
+      "Collecting zipp>=0.5\n",
+      "  Downloading zipp-3.7.0-py3-none-any.whl (5.3 kB)\n",
+      "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=3.0->altair>=3.2.0->streamlit->simpletransformers) (0.18.0)\n",
+      "Requirement already satisfied: traitlets>=4.3.2 in /opt/conda/lib/python3.8/site-packages (from pydeck>=0.1.dev5->streamlit->simpletransformers) (5.1.0)\n",
+      "Collecting ipywidgets>=7.0.0\n",
+      "  Downloading ipywidgets-7.6.5-py2.py3-none-any.whl (121 kB)\n",
+      "\u001b[K     |████████████████████████████████| 121 kB 84.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: ipykernel>=5.1.2 in /opt/conda/lib/python3.8/site-packages (from pydeck>=0.1.dev5->streamlit->simpletransformers) (6.4.1)\n",
+      "Requirement already satisfied: ipython-genutils in /opt/conda/lib/python3.8/site-packages (from ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.2.0)\n",
+      "Requirement already satisfied: debugpy<2.0,>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.5.0)\n",
+      "Requirement already satisfied: ipython<8.0,>=7.23.1 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (7.28.0)\n",
+      "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.1.3)\n",
+      "Requirement already satisfied: jupyter-client<8.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (7.0.6)\n",
+      "Requirement already satisfied: pickleshare in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.7.5)\n",
+      "Requirement already satisfied: backcall in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.2.0)\n",
+      "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (4.8.0)\n",
+      "Requirement already satisfied: pygments in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (2.10.0)\n",
+      "Requirement already satisfied: decorator in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (5.1.0)\n",
+      "Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.18.0)\n",
+      "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (3.0.20)\n",
+      "Collecting jupyterlab-widgets>=1.0.0\n",
+      "  Downloading jupyterlab_widgets-1.0.2-py3-none-any.whl (243 kB)\n",
+      "\u001b[K     |████████████████████████████████| 243 kB 81.0 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: nbformat>=4.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (5.1.3)\n",
+      "Collecting widgetsnbextension~=3.5.0\n",
+      "  Downloading widgetsnbextension-3.5.2-py2.py3-none-any.whl (1.6 MB)\n",
+      "\u001b[K     |████████████████████████████████| 1.6 MB 60.4 MB/s eta 0:00:01\n",
+      "\u001b[?25hRequirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.8/site-packages (from jedi>=0.16->ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.8.2)\n",
+      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->altair>=3.2.0->streamlit->simpletransformers) (2.0.1)\n",
+      "Requirement already satisfied: jupyter-core>=4.6.0 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (4.8.1)\n",
+      "Requirement already satisfied: nest-asyncio>=1.5 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.5.1)\n",
+      "Requirement already satisfied: pyzmq>=13 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (22.3.0)\n",
+      "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.8/site-packages (from pexpect>4.3->ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.7.0)\n",
+      "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython<8.0,>=7.23.1->ipykernel>=5.1.2->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.2.5)\n",
+      "Requirement already satisfied: notebook>=4.4.1 in /opt/conda/lib/python3.8/site-packages (from widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (6.4.1)\n",
+      "Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.11.0)\n",
+      "Requirement already satisfied: nbconvert in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (6.2.0)\n",
+      "Requirement already satisfied: terminado>=0.8.3 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.12.1)\n",
+      "Requirement already satisfied: argon2-cffi in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (21.1.0)\n",
+      "Requirement already satisfied: Send2Trash>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.8.0)\n",
+      "Requirement already satisfied: cffi>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.14.6)\n",
+      "Requirement already satisfied: pycparser in /opt/conda/lib/python3.8/site-packages (from cffi>=1.0.0->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (2.20)\n",
+      "Requirement already satisfied: defusedxml in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.7.1)\n",
+      "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.5.4)\n",
+      "Requirement already satisfied: jupyterlab-pygments in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.1.2)\n",
+      "Requirement already satisfied: pandocfilters>=1.4.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (1.5.0)\n",
+      "Requirement already satisfied: bleach in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (4.1.0)\n",
+      "Requirement already satisfied: testpath in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.5.0)\n",
+      "Requirement already satisfied: mistune<2,>=0.8.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.8.4)\n",
+      "Requirement already satisfied: webencodings in /opt/conda/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets>=7.0.0->pydeck>=0.1.dev5->streamlit->simpletransformers) (0.5.1)\n",
+      "Collecting backports.zoneinfo\n",
+      "  Downloading backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl (74 kB)\n",
+      "\u001b[K     |████████████████████████████████| 74 kB 58.5 MB/s eta 0:00:01\n",
+      "\u001b[?25hCollecting pytz-deprecation-shim\n",
+      "  Downloading pytz_deprecation_shim-0.1.0.post0-py2.py3-none-any.whl (15 kB)\n",
+      "Collecting tzdata\n",
+      "  Downloading tzdata-2021.5-py2.py3-none-any.whl (339 kB)\n",
+      "\u001b[K     |████████████████████████████████| 339 kB 62.2 MB/s eta 0:00:01\n",
+      "\u001b[?25hBuilding wheels for collected packages: termcolor, promise, pathtools, seqeval, blinker\n",
+      "  Building wheel for termcolor (setup.py) ... \u001b[?25ldone\n",
+      "\u001b[?25h  Created wheel for termcolor: filename=termcolor-1.1.0-py3-none-any.whl size=4847 sha256=3ece8a62835e9d7dd9791f70ed28b1e3c95a92827dd9eb17fefa31935224b9df\n",
+      "  Stored in directory: /tmp/pip-ephem-wheel-cache-dlo3tsqx/wheels/a0/16/9c/5473df82468f958445479c59e784896fa24f4a5fc024b0f501\n",
+      "  Building wheel for promise (setup.py) ... \u001b[?25ldone\n",
+      "\u001b[?25h  Created wheel for promise: filename=promise-2.3-py3-none-any.whl size=21502 sha256=a8d4205eca5eb956ec27379f492743ab576a1bdc6bf2af77f30c1f10a8e2ea84\n",
+      "  Stored in directory: /tmp/pip-ephem-wheel-cache-dlo3tsqx/wheels/54/aa/01/724885182f93150035a2a91bce34a12877e8067a97baaf5dc8\n",
+      "  Building wheel for pathtools (setup.py) ... \u001b[?25ldone\n",
+      "\u001b[?25h  Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8807 sha256=aa094f3a6bdb346ad900a98238b5714ae8af375c7e4aba9cb38ad86c3ce2e108\n",
+      "  Stored in directory: /tmp/pip-ephem-wheel-cache-dlo3tsqx/wheels/4c/8e/7e/72fbc243e1aeecae64a96875432e70d4e92f3d2d18123be004\n",
+      "  Building wheel for seqeval (setup.py) ... \u001b[?25ldone\n",
+      "\u001b[?25h  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16181 sha256=12e46e3f110593059f01255ee98fe0f9420ec215eca85bf76482b129596f0835\n",
+      "  Stored in directory: /tmp/pip-ephem-wheel-cache-dlo3tsqx/wheels/ad/5c/ba/05fa33fa5855777b7d686e843ec07452f22a66a138e290e732\n",
+      "  Building wheel for blinker (setup.py) ... \u001b[?25ldone\n",
+      "\u001b[?25h  Created wheel for blinker: filename=blinker-1.4-py3-none-any.whl size=13478 sha256=098831030bd0adc9bc71e2cda47db0f0aaf34363f1f8c91e2f327837814eb52b\n",
+      "  Stored in directory: /tmp/pip-ephem-wheel-cache-dlo3tsqx/wheels/b7/a5/68/fe632054a5eadd531c7a49d740c50eb6adfbeca822b4eab8d4\n",
+      "Successfully built termcolor promise pathtools seqeval blinker\n",
+      "Installing collected packages: multidict, frozenlist, yarl, widgetsnbextension, tzdata, smmap, jupyterlab-widgets, backports.zoneinfo, async-timeout, aiosignal, zipp, toolz, termcolor, pytz-deprecation-shim, pandas, ipywidgets, gitdb, fsspec, dill, aiohttp, yaspin, xxhash, watchdog, validators, tzlocal, tokenizers, shortuuid, setproctitle, sentry-sdk, semver, pympler, pydeck, pyarrow, promise, pathtools, multiprocess, importlib-metadata, huggingface-hub, GitPython, docker-pycreds, blinker, base58, astor, altair, wrapt, wandb, transformers, tf-estimator-nightly, tensorflow-io-gcs-filesystem, tensorboard, streamlit, seqeval, sentencepiece, opt-einsum, libclang, keras-preprocessing, keras, h5py, google-pasta, gast, flatbuffers, datasets, astunparse, tensorflow, simpletransformers\n",
+      "  Attempting uninstall: tensorboard\n",
+      "    Found existing installation: tensorboard 2.6.0\n",
+      "    Uninstalling tensorboard-2.6.0:\n",
+      "      Successfully uninstalled tensorboard-2.6.0\n",
+      "Successfully installed GitPython-3.1.27 aiohttp-3.8.1 aiosignal-1.2.0 altair-4.2.0 astor-0.8.1 astunparse-1.6.3 async-timeout-4.0.2 backports.zoneinfo-0.2.1 base58-2.1.1 blinker-1.4 datasets-1.18.3 dill-0.3.4 docker-pycreds-0.4.0 flatbuffers-2.0 frozenlist-1.3.0 fsspec-2022.2.0 gast-0.5.3 gitdb-4.0.9 google-pasta-0.2.0 h5py-3.6.0 huggingface-hub-0.4.0 importlib-metadata-4.11.2 ipywidgets-7.6.5 jupyterlab-widgets-1.0.2 keras-2.8.0 keras-preprocessing-1.1.2 libclang-13.0.0 multidict-6.0.2 multiprocess-0.70.12.2 opt-einsum-3.3.0 pandas-1.4.1 pathtools-0.1.2 promise-2.3 pyarrow-7.0.0 pydeck-0.7.1 pympler-1.0.1 pytz-deprecation-shim-0.1.0.post0 semver-2.13.0 sentencepiece-0.1.96 sentry-sdk-1.5.6 seqeval-1.2.2 setproctitle-1.2.2 shortuuid-1.0.8 simpletransformers-0.63.4 smmap-5.0.0 streamlit-1.6.0 tensorboard-2.8.0 tensorflow-2.8.0 tensorflow-io-gcs-filesystem-0.24.0 termcolor-1.1.0 tf-estimator-nightly-2.8.0.dev2021122109 tokenizers-0.11.6 toolz-0.11.2 transformers-4.16.2 tzdata-2021.5 tzlocal-4.1 validators-0.18.2 wandb-0.12.11 watchdog-2.1.6 widgetsnbextension-3.5.2 wrapt-1.13.3 xxhash-3.0.0 yarl-1.7.2 yaspin-2.1.0 zipp-3.7.0\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
+     ]
+    }
+   ],
    "source": [
     "!pip install simpletransformers tensorflow\n",
-    "!pip install tensorboardx"
+    "# !pip install tensorboardx"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {
     "id": "RJC8wj73Zd_p"
    },
@@ -61,7 +403,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -69,7 +411,15 @@
     "id": "bsX3b7ZNYVZe",
     "outputId": "845660e8-c68b-4a52-d9ce-3c06bf7356d8"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Cuda available?  True\n"
+     ]
+    }
+   ],
    "source": [
     "# prepare logger\n",
     "logging.basicConfig(level=logging.INFO)\n",
@@ -85,16 +435,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "!pip install tensorflow"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -102,7 +443,30 @@
     "id": "HpRLLRzkTwdL",
     "outputId": "9dc072ea-e419-4bc1-ad99-507cdd4e1394"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Found GPU at: /device:GPU:0\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-03-02 10:07:29.225077: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-03-02 10:07:29.226362: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:29.227434: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:29.228165: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:34.041835: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:34.042377: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:34.042908: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
+      "2022-03-02 10:07:34.043421: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /device:GPU:0 with 7001 MB memory:  -> device: 0, name: Quadro M4000, pci bus id: 0000:00:05.0, compute capability: 5.2\n"
+     ]
+    }
+   ],
    "source": [
     "if cuda_available:\n",
     "  import tensorflow as tf\n",
@@ -126,7 +490,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -134,7 +498,15 @@
     "id": "UW903YxwThrH",
     "outputId": "4dc91901-fa9f-446a-a883-dca331443d3d"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Fetching https://raw.githubusercontent.com/Perez-AlmendrosC/dontpatronizeme/master/semeval-2022/dont_patronize_me.py\n"
+     ]
+    }
+   ],
    "source": [
     "module_url = f\"https://raw.githubusercontent.com/Perez-AlmendrosC/dontpatronizeme/master/semeval-2022/dont_patronize_me.py\"\n",
     "module_name = module_url.split('/')[-1]\n",
@@ -147,7 +519,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {
     "id": "PRxm0179aqzw"
    },
@@ -162,7 +534,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {
     "id": "gcDThFWVBxGb"
    },
@@ -173,7 +545,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {
     "id": "3Ay5_5Y0ThrI"
    },
@@ -184,7 +556,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 10,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -192,7 +564,16 @@
     "id": "2r3USK4eThrJ",
     "outputId": "53bbe18a-47df-4079-d28a-cf890c08b306"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Map of label to numerical label:\n",
+      "{'Unbalanced_power_relations': 0, 'Shallow_solution': 1, 'Presupposition': 2, 'Authority_voice': 3, 'Metaphors': 4, 'Compassion': 5, 'The_poorer_the_merrier': 6}\n"
+     ]
+    }
+   ],
    "source": [
     "dpm.load_task1()\n",
     "dpm.load_task2(return_one_hot=True)"
@@ -209,7 +590,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "metadata": {
     "id": "8AReWYHYOUqx"
    },
@@ -221,7 +602,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/",
@@ -230,14 +611,83 @@
     "id": "a-_ADoJAOWJA",
     "outputId": "85dbe757-4ee5-4887-deac-60185515e141"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>par_id</th>\n",
+       "      <th>label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>4341</td>\n",
+       "      <td>[1, 0, 0, 1, 0, 0, 0]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>4136</td>\n",
+       "      <td>[0, 1, 0, 0, 0, 0, 0]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>10352</td>\n",
+       "      <td>[1, 0, 0, 0, 0, 1, 0]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>8279</td>\n",
+       "      <td>[0, 0, 0, 1, 0, 0, 0]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>1164</td>\n",
+       "      <td>[1, 0, 0, 1, 1, 1, 0]</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   par_id                  label\n",
+       "0    4341  [1, 0, 0, 1, 0, 0, 0]\n",
+       "1    4136  [0, 1, 0, 0, 0, 0, 0]\n",
+       "2   10352  [1, 0, 0, 0, 0, 1, 0]\n",
+       "3    8279  [0, 0, 0, 1, 0, 0, 0]\n",
+       "4    1164  [1, 0, 0, 1, 1, 1, 0]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trids.head()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "metadata": {
     "id": "7IfCZjwQ16MS"
    },
@@ -260,7 +710,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {
     "id": "BOxDR1H2g_3p"
    },
@@ -283,7 +733,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {
     "id": "8e3E08Yown5p"
    },
@@ -294,27 +744,126 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>par_id</th>\n",
+       "      <th>text</th>\n",
+       "      <th>label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>4341</td>\n",
+       "      <td>The scheme saw an estimated 150,000 children f...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>4136</td>\n",
+       "      <td>Durban 's homeless communities reconciliation ...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>10352</td>\n",
+       "      <td>The next immediate problem that cropped up was...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>8279</td>\n",
+       "      <td>Far more important than the implications for t...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>1164</td>\n",
+       "      <td>To strengthen child-sensitive social protectio...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "  par_id                                               text  label\n",
+       "0   4341  The scheme saw an estimated 150,000 children f...      1\n",
+       "1   4136  Durban 's homeless communities reconciliation ...      1\n",
+       "2  10352  The next immediate problem that cropped up was...      1\n",
+       "3   8279  Far more important than the implications for t...      1\n",
+       "4   1164  To strengthen child-sensitive social protectio...      1"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trdf1.head()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 18,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "8375"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trdf1.shape[0]"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0    7581\n",
+       "1     794\n",
+       "Name: label, dtype: int64"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trdf1[\"label\"].value_counts()"
    ]
@@ -619,6 +1168,165 @@
     "print(sent)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Word embedding replacement"
+   ]
+  },
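+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To counter the heavy class imbalance (794 positive vs. 7,581 negative paragraphs), the positively-labelled (PCL) examples are paraphrased with nlpaug's `ContextualWordEmbsAug`: a pre-trained BERT model substitutes a fraction of the tokens in each paragraph with contextually predicted alternatives, and the augmented copies are appended to the training data."
+   ]
+  },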
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
+      "Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (1.21.2)\n",
+      "Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (2.26.0)\n",
+      "Requirement already satisfied: nlpaug in /opt/conda/lib/python3.8/site-packages (1.1.10)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests) (1.26.7)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests) (2021.5.30)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests) (2.0.0)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests) (3.1)\n",
+      "Requirement already satisfied: pandas>=1.2.0 in /opt/conda/lib/python3.8/site-packages (from nlpaug) (1.4.1)\n",
+      "Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/lib/python3.8/site-packages (from pandas>=1.2.0->nlpaug) (2.8.2)\n",
+      "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.8/site-packages (from pandas>=1.2.0->nlpaug) (2021.3)\n",
+      "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.8/site-packages (from python-dateutil>=2.8.1->pandas>=1.2.0->nlpaug) (1.16.0)\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
+      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install numpy requests nlpaug\n",
+    "!pip install torch>=1.6.0 transformers>=4.11.3 sentencepiece\n",
+    "!pip install nltk>=3.4.5\n",
+    "!pip install gensim>=4.1.2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nlpaug.augmenter.word as naw\n",
+    "import nlpaug.augmenter.sentence as nas "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<module 'torch.cuda' from '/opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py'>\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = torch.cuda"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wordEmbeddingModel = naw.ContextualWordEmbsAug(\n",
+    "    model_path='bert-base-uncased', action=\"substitute\", \n",
+    "    aug_p=0.2, device='cuda',\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "par_id  text                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    label\n",
+       "1       We 're living in times of absolute insanity , as I 'm pretty sure most people are aware . For a while , waking up every day to check the news seemed to carry with it the same feeling of panic and dread that action heroes probably face when they 're trying to decide whether to cut the blue or green wire on a ticking bomb -- except the bomb 's instructions long ago burned in a fire and imminent catastrophe seems the likeliest outcome . It 's hard to stay that on-edge for that long , though , so it 's natural for people to become inured to this constant chaos , to slump into a malaise of hopelessness and pessimism .            0        1\n",
+       "6149    \"VIENTIANE , LAO PDR , May 17 -- Nine labour ministers of the Association of South East Asian Nations ( ASEAN ) , assembled in this capital city of land-locked Lao People 's Democratic Republic for their 24th biennial meeting , unanimously adopted the proposal of the Philippines to finalize by September 2016 , at the earliest , or by April 2017 , at the latest , the draft ASEAN instrument on the protection and promotion of the rights of migrant workers . \"\" This is a breakthrough in the negotiations for the instrument , more than 85 percent of which is already finished , \"\" Baldoz added.At the meeting , Bald ... Read more\"  0        1\n",
+       "6143    \"In an email sent by To on September 14 and sent to HKFP by the third student , To told the class : \"\" My intention was to give an example of how internet information needs to be verified in light of being credible and ... how the judge may look at it in the eyes of the vulnerable . \"\"\"                                                                                                                                                                                                                                                                                                                                                         0        1\n",
+       "6144    A hospital in Bangladesh near the Burmese border reported that refugees were arriving with bullet wounds , and the country plans to open another refugee camp to ease pressure on one that already has 50,000 inhabitants .                                                                                                                                                                                                                                                                                                                                                                                                                             0        1\n",
+       "6145    \"\"\" Women in need deserve laws that are in the best interest of their physical and emotional well-being , and that take into consideration their unborn child , \"\" added Aden .\"                                                                                                                                                                                                                                                                                                                                                                                                                                                                        0        1\n",
+       "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        ..\n",
+       "3533    LONDON -- In a cramped apartment in an industrial zone in south London , Sandra Rumkiene recounts her struggles to bring up a baby as one of a growing number of poor families forced to live in temporary housing .                                                                                                                                                                                                                                                                                                                                                                                                                                    0        1\n",
+       "3532    Help is yet to come to residents of Zabzugu in the Northern Region , a week after a heavy downpour rendered them homeless .                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             0        1\n",
+       "3531    A country ? ? ? s core economic aim should be to achieve an increasingly prosperous middle class . Debt allows wealthier , often older , households to lend money to less affluent , often younger , households . This allows meaningful wealth accumulations among huge numbers of previously poor families as home ownership provokes savings .                                                                                                                                                                                                                                                                                                       0        1\n",
+       "3530    According to a new report by Statistics Canada , close to 30 per cent of Inuit children across Canada live in homes in need of major repair , compared to less than eight per cent of non-Aboriginal children . ( Katherine Barton/CBC )                                                                                                                                                                                                                                                                                                                                                                                                                0        1\n",
+       "999     Which leads us to the other side of the coin , and an area economists and demographers in Australia need to watch closely , given the uneasy rise in anti-immigration sentiment . Can boundless immigration continue to be used to prop up an economy that requires constant growth to prop up its housing ponzi ? Which begs the greater question -- what if the immigrants do n't come to fill the gap ?                                                                                                                                                                                                                                              0        1\n",
+       "Length: 9169, dtype: int64"
+      ]
+     },
+     "execution_count": 33,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trdf1_word_embedding_subst = trdf1.copy()\n",
+    "pat_sent = trdf1.loc[trdf1['label'] == 1]\n",
+    "# single augmentation pass over the positive-class sentences; raise the range to oversample further\n",
+    "for i in range(1):\n",
+    "    print(f\"Epoch {i}\")\n",
+    "    # TODO: break this down into batches -- a single full pass currently takes ~17 minutes\n",
+    "    pat_sent_synonym = pat_sent.copy()\n",
+    "    pat_sent_synonym['text'] = pat_sent_synonym['text'].apply(lambda x: wordEmbeddingModel.augment(x))\n",
+    "    trdf1_word_embedding_subst = pd.concat([trdf1_word_embedding_subst, pat_sent_synonym], ignore_index=True)\n"
+   ]
+  },
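+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The loop above calls `wordEmbeddingModel.augment` once per positive sentence, which is what makes a single pass take ~17 minutes. A possible speed-up -- a sketch only, assuming the `nlpaug` augmenter accepts a list of texts and returns a list of the same length -- is to augment the sentences in chunks:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: chunked augmentation of the positive-class sentences.\n",
+    "# Assumes wordEmbeddingModel.augment(list_of_str) returns a list of augmented strings\n",
+    "# (true for recent nlpaug releases); verify against the installed version.\n",
+    "def augment_in_chunks(texts, chunk_size=256):\n",
+    "    augmented = []\n",
+    "    for start in range(0, len(texts), chunk_size):\n",
+    "        augmented.extend(wordEmbeddingModel.augment(texts[start:start + chunk_size]))\n",
+    "    return augmented\n",
+    "\n",
+    "# pat_sent_chunked = pat_sent.copy()\n",
+    "# pat_sent_chunked['text'] = augment_in_chunks(pat_sent['text'].tolist())"
+   ]
+  },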
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "par_id  text                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    label\n",
+       "1       We 're living in times of absolute insanity , as I 'm pretty sure most people are aware . For a while , waking up every day to check the news seemed to carry with it the same feeling of panic and dread that action heroes probably face when they 're trying to decide whether to cut the blue or green wire on a ticking bomb -- except the bomb 's instructions long ago burned in a fire and imminent catastrophe seems the likeliest outcome . It 's hard to stay that on-edge for that long , though , so it 's natural for people to become inured to this constant chaos , to slump into a malaise of hopelessness and pessimism .            0        1\n",
+       "6149    \"VIENTIANE , LAO PDR , May 17 -- Nine labour ministers of the Association of South East Asian Nations ( ASEAN ) , assembled in this capital city of land-locked Lao People 's Democratic Republic for their 24th biennial meeting , unanimously adopted the proposal of the Philippines to finalize by September 2016 , at the earliest , or by April 2017 , at the latest , the draft ASEAN instrument on the protection and promotion of the rights of migrant workers . \"\" This is a breakthrough in the negotiations for the instrument , more than 85 percent of which is already finished , \"\" Baldoz added.At the meeting , Bald ... Read more\"  0        1\n",
+       "6143    \"In an email sent by To on September 14 and sent to HKFP by the third student , To told the class : \"\" My intention was to give an example of how internet information needs to be verified in light of being credible and ... how the judge may look at it in the eyes of the vulnerable . \"\"\"                                                                                                                                                                                                                                                                                                                                                         0        1\n",
+       "6144    A hospital in Bangladesh near the Burmese border reported that refugees were arriving with bullet wounds , and the country plans to open another refugee camp to ease pressure on one that already has 50,000 inhabitants .                                                                                                                                                                                                                                                                                                                                                                                                                             0        1\n",
+       "6145    \"\"\" Women in need deserve laws that are in the best interest of their physical and emotional well-being , and that take into consideration their unborn child , \"\" added Aden .\"                                                                                                                                                                                                                                                                                                                                                                                                                                                                        0        1\n",
+       "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        ..\n",
+       "3533    LONDON -- In a cramped apartment in an industrial zone in south London , Sandra Rumkiene recounts her struggles to bring up a baby as one of a growing number of poor families forced to live in temporary housing .                                                                                                                                                                                                                                                                                                                                                                                                                                    0        1\n",
+       "3532    Help is yet to come to residents of Zabzugu in the Northern Region , a week after a heavy downpour rendered them homeless .                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             0        1\n",
+       "3531    A country ? ? ? s core economic aim should be to achieve an increasingly prosperous middle class . Debt allows wealthier , often older , households to lend money to less affluent , often younger , households . This allows meaningful wealth accumulations among huge numbers of previously poor families as home ownership provokes savings .                                                                                                                                                                                                                                                                                                       0        1\n",
+       "3530    According to a new report by Statistics Canada , close to 30 per cent of Inuit children across Canada live in homes in need of major repair , compared to less than eight per cent of non-Aboriginal children . ( Katherine Barton/CBC )                                                                                                                                                                                                                                                                                                                                                                                                                0        1\n",
+       "999     Which leads us to the other side of the coin , and an area economists and demographers in Australia need to watch closely , given the uneasy rise in anti-immigration sentiment . Can boundless immigration continue to be used to prop up an economy that requires constant growth to prop up its housing ponzi ? Which begs the greater question -- what if the immigrants do n't come to fill the gap ?                                                                                                                                                                                                                                              0        1\n",
+       "Length: 9169, dtype: int64"
+      ]
+     },
+     "execution_count": 35,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trdf1_word_embedding_subst.value_counts()"
+   ]
+  },
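+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`value_counts()` on the full DataFrame above counts each unique (par_id, text, label) row, so every count is 1 and the useful information is mainly the overall length. To see the effect of the oversampling on the class balance directly, a label-only count is more readable (a small optional check, not part of the original run):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# class balance of the word-embedding-augmented training set\n",
+    "trdf1_word_embedding_subst['label'].value_counts()"
+   ]
+  },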
   {
    "cell_type": "markdown",
    "metadata": {
@@ -630,7 +1338,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {
     "id": "T6FLgB6KxGI2"
    },
@@ -653,7 +1361,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -661,14 +1369,25 @@
     "id": "YbB9GdzJxRAH",
     "outputId": "c78e311e-9502-4644-b6f7-0c64f64aa66f"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "2094"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "len(rows)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 18,
    "metadata": {
     "id": "vhBhTRIyxSaQ"
    },
@@ -688,7 +1407,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
@@ -707,7 +1426,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 20,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/",
@@ -716,7 +1435,20 @@
     "id": "mpSqMp3d8iYu",
     "outputId": "037d4f45-eab5-4f04-e9a5-1aa64c46323d"
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0    1270\n",
+      "1     635\n",
+      "Name: label, dtype: int64\n",
+      "0    318\n",
+      "1    159\n",
+      "Name: label, dtype: int64\n"
+     ]
+    }
+   ],
    "source": [
     "non_pat_set1 = train_set1.loc[train_set1['label'] == 0]\n",
     "non_split = int(0.8 * len(non_pat_set1))\n",
@@ -733,14 +1465,268 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "precision: 0.6875, recall: 0.6875, f1: 0.6875\n"
+     ]
+    }
+   ],
+   "source": [
+    "def precision(tp, fp):\n",
+    "    return tp / (tp + fp)\n",
+    "\n",
+    "def recall(tp, fn):\n",
+    "    return tp / (tp + fn)\n",
+    "\n",
+    "def f1(precision, recall):\n",
+    "    return 2 * precision * recall / (precision + recall)\n",
+    "\n",
+    "# worked example with tp=110, fp=50, fn=50\n",
+    "p = precision(110, 50)\n",
+    "r = recall(110, 50)\n",
+    "f_score = f1(p, r)\n",
+    "\n",
+    "print(f\"precision: {p}, recall: {r}, f1: {f_score}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
    "metadata": {},
    "outputs": [],
+   "source": [
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to acquire lock 140191560837776 on /root/.cache/huggingface/transformers/733bade19e5f0ce98e6531021dd5180994bb2f7b8bd7e80c7968805834ba351e.35205c6cfc956461d8515139f0f8dd5d207a2f336c0c3a83b4bc8dca3518e37b.lock\n",
+      "DEBUG:filelock:Lock 140191560837776 acquired on /root/.cache/huggingface/transformers/733bade19e5f0ce98e6531021dd5180994bb2f7b8bd7e80c7968805834ba351e.35205c6cfc956461d8515139f0f8dd5d207a2f336c0c3a83b4bc8dca3518e37b.lock\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "37d78efeefea49d3a4722fbdcbabda77",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to release lock 140191560837776 on /root/.cache/huggingface/transformers/733bade19e5f0ce98e6531021dd5180994bb2f7b8bd7e80c7968805834ba351e.35205c6cfc956461d8515139f0f8dd5d207a2f336c0c3a83b4bc8dca3518e37b.lock\n",
+      "DEBUG:filelock:Lock 140191560837776 released on /root/.cache/huggingface/transformers/733bade19e5f0ce98e6531021dd5180994bb2f7b8bd7e80c7968805834ba351e.35205c6cfc956461d8515139f0f8dd5d207a2f336c0c3a83b4bc8dca3518e37b.lock\n",
+      "DEBUG:filelock:Attempting to acquire lock 140186325815696 on /root/.cache/huggingface/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n",
+      "DEBUG:filelock:Lock 140186325815696 acquired on /root/.cache/huggingface/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8db0da33eceb4a06bda678613430eaa1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to release lock 140186325815696 on /root/.cache/huggingface/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n",
+      "DEBUG:filelock:Lock 140186325815696 released on /root/.cache/huggingface/transformers/51ba668f7ff34e7cdfa9561e8361747738113878850a7d717dbc69de8683aaad.c7efaa30a0d80b2958b876969faa180e485944a849deee4ad482332de65365a7.lock\n",
+      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight']\n",
+      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+      "DEBUG:filelock:Attempting to acquire lock 140186326952640 on /root/.cache/huggingface/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n",
+      "DEBUG:filelock:Lock 140186326952640 acquired on /root/.cache/huggingface/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "08feb2e7595c417ba050f0ef49737b1f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to release lock 140186326952640 on /root/.cache/huggingface/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n",
+      "DEBUG:filelock:Lock 140186326952640 released on /root/.cache/huggingface/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab.lock\n",
+      "DEBUG:filelock:Attempting to acquire lock 140186326953072 on /root/.cache/huggingface/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n",
+      "DEBUG:filelock:Lock 140186326953072 acquired on /root/.cache/huggingface/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "99f4ee1482e74b6da2c30615d71e777f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to release lock 140186326953072 on /root/.cache/huggingface/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n",
+      "DEBUG:filelock:Lock 140186326953072 released on /root/.cache/huggingface/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b.lock\n",
+      "DEBUG:filelock:Attempting to acquire lock 140186326632528 on /root/.cache/huggingface/transformers/d53fc0fa09b8342651efd4073d75e19617b3e51287c2a535becda5808a8db287.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock\n",
+      "DEBUG:filelock:Lock 140186326632528 acquired on /root/.cache/huggingface/transformers/d53fc0fa09b8342651efd4073d75e19617b3e51287c2a535becda5808a8db287.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f87e2940eb07436e97515641b144f34d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "DEBUG:filelock:Attempting to release lock 140186326632528 on /root/.cache/huggingface/transformers/d53fc0fa09b8342651efd4073d75e19617b3e51287c2a535becda5808a8db287.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock\n",
+      "DEBUG:filelock:Lock 140186326632528 released on /root/.cache/huggingface/transformers/d53fc0fa09b8342651efd4073d75e19617b3e51287c2a535becda5808a8db287.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730.lock\n",
+      "/opt/conda/lib/python3.8/site-packages/simpletransformers/classification/classification_model.py:585: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels.\n",
+      "  warnings.warn(\n",
+      "INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "491c750115804e81919d6af748f95994",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1905 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "74fca7da98d948a4b96b03a3e704a448",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Epoch:   0%|          | 0/5 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e1e497570a184a72893e365905354672",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Running Epoch 0 of 5:   0%|          | 0/15 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "RuntimeError",
+     "evalue": "CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 7.94 GiB total capacity; 7.22 GiB already allocated; 12.19 MiB free; 7.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_412/2705655345.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m     \u001b[0mhyperparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_batch_size\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m     \u001b[0mtask1_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"roberta\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"roberta-base\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_set1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_set1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhyperparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m     \u001b[0mpreds_task1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtask1_model\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m     \u001b[0;31m# all_preds.append(preds_task1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/tmp/ipykernel_412/2705655345.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, model_name, train, val, hyperparams)\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m     \u001b[0;31m# train model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m     \u001b[0mtask1_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'text'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_df\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'text'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     39\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mtask1_model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/simpletransformers/classification/classification_model.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(self, train_df, multi_label, output_dir, show_running_loss, args, eval_df, verbose, **kwargs)\u001b[0m\n\u001b[1;32m    603\u001b[0m         \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmakedirs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexist_ok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    604\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 605\u001b[0;31m         global_step, training_details = self.train(\n\u001b[0m\u001b[1;32m    606\u001b[0m             \u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    607\u001b[0m             \u001b[0moutput_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/simpletransformers/classification/classification_model.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, train_dataloader, output_dir, multi_label, show_running_loss, eval_df, test_df, verbose, **kwargs)\u001b[0m\n\u001b[1;32m    878\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    879\u001b[0m                     \u001b[0;32mwith\u001b[0m \u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 880\u001b[0;31m                         loss, *_ = self._calculate_loss(\n\u001b[0m\u001b[1;32m    881\u001b[0m                             \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    882\u001b[0m                             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/simpletransformers/classification/classification_model.py\u001b[0m in \u001b[0;36m_calculate_loss\u001b[0;34m(self, model, inputs, loss_fct, num_labels, args)\u001b[0m\n\u001b[1;32m   2283\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2284\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_calculate_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fct\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2285\u001b[0;31m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2286\u001b[0m         \u001b[0;31m# model outputs are always tuple in pytorch-transformers (see doc)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2287\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1202\u001b[0m         \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_return_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1204\u001b[0;31m         outputs = self.roberta(\n\u001b[0m\u001b[1;32m   1205\u001b[0m             \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1206\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    848\u001b[0m             \u001b[0mpast_key_values_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpast_key_values_length\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    849\u001b[0m         )\n\u001b[0;32m--> 850\u001b[0;31m         encoder_outputs = self.encoder(\n\u001b[0m\u001b[1;32m    851\u001b[0m             \u001b[0membedding_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    852\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mextended_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    524\u001b[0m                 )\n\u001b[1;32m    525\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 526\u001b[0;31m                 layer_outputs = layer_module(\n\u001b[0m\u001b[1;32m    527\u001b[0m                     \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    528\u001b[0m                     \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    410\u001b[0m         \u001b[0;31m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    411\u001b[0m         \u001b[0mself_attn_past_key_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpast_key_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 412\u001b[0;31m         self_attention_outputs = self.attention(\n\u001b[0m\u001b[1;32m    413\u001b[0m             \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    414\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    337\u001b[0m         \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    338\u001b[0m     ):\n\u001b[0;32m--> 339\u001b[0;31m         self_outputs = self.self(\n\u001b[0m\u001b[1;32m    340\u001b[0m             \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    341\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/transformers/models/roberta/modeling_roberta.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    267\u001b[0m         \u001b[0;31m# This is actually dropping out entire tokens to attend to, which might\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    268\u001b[0m         \u001b[0;31m# seem a bit unusual, but is taken from the original Transformer paper.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 269\u001b[0;31m         \u001b[0mattention_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    270\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    271\u001b[0m         \u001b[0;31m# Mask heads if we want to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/modules/dropout.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mdropout\u001b[0;34m(input, p, training, inplace)\u001b[0m\n\u001b[1;32m   1167\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0.0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1168\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dropout probability has to be between 0 and 1, \"\u001b[0m \u001b[0;34m\"but got {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1171\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 7.94 GiB total capacity; 7.22 GiB already allocated; 12.19 MiB free; 7.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
+     ]
+    }
+   ],
    "source": [
     "hyperparams = {\n",
-    "    \"learning_rate\": 0.0005,\n",
+    "    \"learning_rate\": 5e-5,\n",
     "    \"train_batch_size\": 8,\n",
-    "    \"num_train_epochs\": 1,\n",
+    "    \"num_train_epochs\": 5,\n",
     "    \"optimizer\": \"AdamW\", # Adafactor\n",
     "    \"scheduler\":\"linear_schedule_with_warmup\",\n",
     "    \"evaluate_during_training\": True,\n",
@@ -748,6 +1734,7 @@
     "}\n",
     "\n",
     "def preprocess(data, use_synonyms=False, use_embedding=False, use_translate=False):\n",
+    "\n",
     "    if use_synonyms:\n",
     "        data = apply_synonyms(data)\n",
     "    if use_embedding:\n",
@@ -769,7 +1756,7 @@
     "                                    model_name, \n",
     "                                    args = task1_model_args, \n",
     "                                    num_labels=2, \n",
-    "                                    use_cuda=cuda_available,)\n",
+    "                                    use_cuda=cuda_available)\n",
     "\n",
     "    \n",
     "    # train model\n",
@@ -781,8 +1768,92 @@
     "    preds_task1, _ = task1_model.predict(tedf1.text.tolist())\n",
     "    return preds_task1\n",
     "\n",
-    "task1_model = train_model(\"roberta\", \"roberta-base\", training_set1, val_set1, hyperparams)\n",
-    "preds_task1 = test_model(task1_model)"
+    "# all_preds = []\n",
+    "# NOTE: train_batch_size=128 exceeds the 8 GiB of GPU memory available here (see the CUDA OOM traceback in this cell's output)\n",
+    "for bs in [128]:\n",
+    "    hyperparams[\"train_batch_size\"] = bs \n",
+    "    task1_model = train_model(\"roberta\", \"roberta-base\", training_set1, val_set1, hyperparams)\n",
+    "    preds_task1 = test_model(task1_model)\n",
+    "    # all_preds.append(preds_task1)\n",
+    "    del task1_model\n",
+    "    torch.cuda.empty_cache()\n",
+    "    torch.cuda.synchronize()"
+   ]
+  },
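+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The run above fails with a CUDA out-of-memory error: `roberta-base` with `train_batch_size=128` does not fit on this 8 GiB GPU. A possible workaround -- a sketch, assuming `train_model` forwards these keys to the `simpletransformers` model args the same way it forwards `train_batch_size` -- is to keep the effective batch size at 128 while accumulating gradients over smaller per-step batches:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: effective batch size 128 = 16 examples per step x 8 accumulation steps.\n",
+    "hyperparams[\"train_batch_size\"] = 16\n",
+    "hyperparams[\"gradient_accumulation_steps\"] = 8\n",
+    "\n",
+    "torch.cuda.empty_cache()\n",
+    "task1_model = train_model(\"roberta\", \"roberta-base\", training_set1, val_set1, hyperparams)\n",
+    "preds_task1 = test_model(task1_model)"
+   ]
+  },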
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'hyperparams' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_2545/189035120.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mall_preds_scheduler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mscheduler\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"linear_schedule_with_warmup\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"polynomial_decay_schedule_with_warmup\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"constant_schedule_with_warmup\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mhyperparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"scheduler\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0mtask1_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"roberta\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"roberta-base\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_set1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_set1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhyperparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mpreds_task1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtask1_model\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'hyperparams' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "# NOTE: relies on hyperparams, train_model and test_model defined above; the NameError recorded\n",
+    "# in this cell's output most likely came from running it in a fresh kernel before re-running those cells.\n",
+    "all_preds_scheduler = []\n",
+    "for scheduler in [\"linear_schedule_with_warmup\", \"polynomial_decay_schedule_with_warmup\", \"constant_schedule_with_warmup\"]:\n",
+    "    hyperparams[\"scheduler\"] = scheduler\n",
+    "    task1_model = train_model(\"roberta\", \"roberta-base\", training_set1, val_set1, hyperparams)\n",
+    "    preds_task1 = test_model(task1_model)\n",
+    "    all_preds_scheduler.append(preds_task1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "410"
+      ]
+     },
+     "execution_count": 45,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(preds_task1[preds_task1 == 1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0    1895\n",
+       "1     199\n",
+       "Name: label, dtype: int64"
+      ]
+     },
+     "execution_count": 48,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tedf1[\"label\"].value_counts()"
    ]
   },
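+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The model predicts 410 positives while the held-out set (`tedf1`) contains 199. To turn these counts into the usual metrics, the predictions can be scored against `tedf1['label']` with the `precision`/`recall`/`f1` helpers defined earlier (a sketch, assuming `preds_task1` is the array returned by `predict` and is aligned with the rows of `tedf1`):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Score the held-out-set predictions against the gold labels.\n",
+    "gold = tedf1['label'].values\n",
+    "tp = int(((preds_task1 == 1) & (gold == 1)).sum())\n",
+    "fp = int(((preds_task1 == 1) & (gold == 0)).sum())\n",
+    "fn = int(((preds_task1 == 0) & (gold == 1)).sum())\n",
+    "\n",
+    "p = precision(tp, fp)\n",
+    "r = recall(tp, fn)\n",
+    "print(f\"precision: {p:.3f}, recall: {r:.3f}, f1: {f1(p, r):.3f}\")"
+   ]
+  },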
   {