{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib ipympl\n",
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib.colors import TABLEAU_COLORS, XKCD_COLORS, to_rgba_array\n",
    "\n",
    "import mpl_interactions.ipyplot as iplt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "N = 50\n",
    "x = np.random.rand(N)\n",
    "\n",
    "\n",
    "def f_y(x, tau):\n",
    "    \"\"\"Return noisy ``sin(x * tau) ** 2`` values, one per entry of ``x``.\"\"\"\n",
    "    # Size the noise from ``x`` instead of the global ``N``: other cells in this\n",
    "    # notebook rebind ``N`` while this figure's slider keeps calling this function.\n",
    "    noise = np.random.randn(*np.shape(x)) * 0.01\n",
    "    return np.sin(x * tau) ** 2 + noise\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "controls = iplt.scatter(x, f_y, tau=(1, 2 * np.pi, 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using functions and broadcasting\n",
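    "You can also use multiple functions by calling `iplt.scatter` more than once and passing the `controls` object returned by the first call to the later calls. If there are fewer `x` inputs than `y` inputs then the `x` input will be broadcast to fit the `y` inputs. Similarly `y` inputs can be broadcast to fit `x`. You can also choose colors and sizes for each scatter."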
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "N = 50\n",
    "x = np.random.rand(N)\n",
    "\n",
    "\n",
    "def f_y1(x, tau):\n",
    "    \"\"\"Return ``sin(x * tau) ** 2`` plus low-amplitude Gaussian noise.\"\"\"\n",
    "    # The noise is sized from ``x`` rather than the global ``N`` so the callback\n",
    "    # keeps working after another cell rebinds ``N``.\n",
    "    return np.sin(x * tau) ** 2 + np.random.randn(*np.shape(x)) * 0.01\n",
    "\n",
    "\n",
    "def f_y2(x, tau):\n",
    "    \"\"\"Return ``cos(x * tau) ** 2`` plus Gaussian noise ten times larger than ``f_y1``'s.\"\"\"\n",
    "    return np.cos(x * tau) ** 2 + np.random.randn(*np.shape(x)) * 0.1\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "controls = iplt.scatter(x, f_y1, tau=(1, 2 * np.pi, 100), c=\"blue\", s=5)\n",
    "# Hand the existing ``controls`` to the second call so both scatters follow ``tau``.\n",
    "_ = iplt.scatter(x, f_y2, controls=controls, c=\"red\", s=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions for both x and y\n",
    "\n",
    "The function for `y` should accept `x` and then any parameters that you will be varying. The function for `x` should accept only the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "N = 50\n",
    "\n",
    "\n",
    "def f_x(mean):\n",
    "    \"\"\"Return ``N`` uniform samples from ``[mean, mean + 1)``.\"\"\"\n",
    "    # NOTE: there is no array input to take a length from, so this still reads\n",
    "    # the global ``N`` at call time.\n",
    "    return np.random.rand(N) + mean\n",
    "\n",
    "\n",
    "def f_y(x, mean):\n",
    "    \"\"\"Return uniform samples shifted down by ``mean``, one per entry of ``x``.\n",
    "\n",
    "    Only the shape of ``x`` is used; its values are ignored.\n",
    "    \"\"\"\n",
    "    # Match the length of ``x`` instead of reading the global ``N`` so the x and y\n",
    "    # arrays always agree.\n",
    "    return np.random.rand(*np.shape(x)) - mean\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "controls = iplt.scatter(f_x, f_y, mean=(0, 1, 100), s=None, c=np.random.randn(N))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter3.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using functions for other attributes\n",
    "\n",
    "You can also use functions to dynamically update other scatter attributes such as the size (`s`), color (`c`), edge color (`edgecolors`), and `alpha`.\n",
    "\n",
    "The function for `alpha` needs to accept the parameters but not the xy positions, because it affects every point. The functions for `s`, `c`, and `edgecolors` should all accept `x, y, <rest of parameters>`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "N = 50\n",
    "mean = 0\n",
    "x = np.random.rand(N) + mean - 0.5\n",
    "\n",
    "\n",
    "def f(x, mean):\n",
    "    \"\"\"Return fresh uniform samples from ``[mean - 0.5, mean + 0.5)``, one per entry of ``x``.\"\"\"\n",
    "    # Sized from ``x`` rather than the global ``N``, which other cells rebind while\n",
    "    # this figure's slider is still live.\n",
    "    return np.random.rand(*np.shape(x)) + mean - 0.5\n",
    "\n",
    "\n",
    "def c_func(x, y, mean):\n",
    "    \"\"\"Use each point's x position as its color value.\"\"\"\n",
    "    return x\n",
    "\n",
    "\n",
    "def s_func(x, y, mean):\n",
    "    \"\"\"Return marker sizes that grow as ``x`` approaches zero.\"\"\"\n",
    "    # The small offset avoids dividing by exactly zero at ``x == 0``.\n",
    "    return np.abs(40 / (x + 0.001))\n",
    "\n",
    "\n",
    "def ec_func(x, y, mean):\n",
    "    \"\"\"Pick a single edge color at random, ``\"black\"`` or ``\"red\"``, on every call.\"\"\"\n",
    "    if np.random.rand() > 0.5:\n",
    "        return \"black\"\n",
    "    else:\n",
    "        return \"red\"\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sliders = iplt.scatter(\n",
    "    x,\n",
    "    f,\n",
    "    mean=(0, 1, 100),\n",
    "    c=c_func,\n",
    "    s=s_func,\n",
    "    edgecolors=ec_func,\n",
    "    alpha=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter4.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modifying the colors of individual points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "N = 500\n",
    "x = np.random.rand(N) - 0.5\n",
    "y = np.random.rand(N) - 0.5\n",
    "\n",
    "\n",
    "def f(mean):\n",
    "    \"\"\"Return random ``(x, y)`` samples shifted by ``mean``; ``y`` has ten times the spread.\n",
    "\n",
    "    This helper is not passed to the scatter call at the end of this cell.\n",
    "    \"\"\"\n",
    "    xs = mean + (np.random.rand(N) - 0.5)\n",
    "    ys = mean + 10 * (np.random.rand(N) - 0.5)\n",
    "    return xs, ys\n",
    "\n",
    "\n",
    "def threshold(x, y, mean):\n",
    "    \"\"\"Return one RGBA row per point: black, with opacity falling off with ``|y - mean|``.\n",
    "\n",
    "    Points within 0.01 of ``mean`` stay fully opaque. Every other point gets an\n",
    "    alpha of ``0.8`` minus its distance (scaled by the largest distance), clipped\n",
    "    to ``[0, 1]``.\n",
    "    \"\"\"\n",
    "    distance = np.abs(y - mean)\n",
    "    # The mask uses the raw distances, so build it before the in-place scaling.\n",
    "    far = ~(distance < 0.01)\n",
    "    distance /= distance.max()\n",
    "\n",
    "    rgba = np.zeros((len(x), 4))\n",
    "    rgba[:, -1] = 1\n",
    "    rgba[far, -1] = np.clip(0.8 - distance[far], 0, 1)\n",
    "    return rgba\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# ``alpha=None`` is passed so the per-point alpha channel from ``threshold`` is what varies.\n",
    "sliders = iplt.scatter(x, y, c=threshold, alpha=None, mean=(0, 1, 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter5.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting it together - Wealth of Nations\n",
    "Using `iplt.scatter` we can recreate the interactive [wealth of nations](https://observablehq.com/@mbostock/the-wealth-health-of-nations) plot using Matplotlib!\n",
    "\n",
    "\n",
    "The data preprocessing was taken from an [example notebook](https://github.com/bqplot/bqplot/blob/55152feb645b523faccb97ea4083ca505f26f6a2/examples/Applications/Wealth%20Of%20Nations/Bubble%20Chart.ipynb) from the [bqplot](https://github.com/bqplot/bqplot) library. If you are working in Jupyter notebooks then you should definitely check out bqplot!\n",
    "\n",
    "\n",
    "### Data preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this cell was taken wholesale from the bqplot example\n",
    "# bqplot is under the apache license, see their license file here:\n",
    "# https://github.com/bqplot/bqplot/blob/55152feb645b523faccb97ea4083ca505f26f6a2/LICENSE\n",
    "data = pd.read_json(\"nations.json\")\n",
    "\n",
    "\n",
    "def clean_data(data):\n",
    "    \"\"\"Drop every row whose income, lifeExpectancy or population entry has length <= 4.\n",
    "\n",
    "    Returns the filtered frame; ``DataFrame.drop`` returns a new object, so the\n",
    "    frame passed in is not modified.\n",
    "    \"\"\"\n",
    "    for column in [\"income\", \"lifeExpectancy\", \"population\"]:\n",
    "        data = data.drop(data[data[column].apply(len) <= 4].index)\n",
    "    return data\n",
    "\n",
    "\n",
    "def extrap_interp(data):\n",
    "    \"\"\"Resample one series onto the yearly grid 1800, 1801, ..., 2008.\n",
    "\n",
    "    ``data`` is converted to a 2-D array; column 0 supplies the sample positions\n",
    "    and column 1 the values (presumably ``[year, value]`` pairs - confirm against\n",
    "    nations.json). Returns the interpolated values as a 1-D array.\n",
    "    \"\"\"\n",
    "    data = np.array(data)\n",
    "    x_range = np.arange(1800, 2009, 1.0)\n",
    "    # np.interp holds the first/last value for grid points outside the sampled range.\n",
    "    y_range = np.interp(x_range, data[:, 0], data[:, 1])\n",
    "    return y_range\n",
    "\n",
    "\n",
    "def extrap_data(data):\n",
    "    \"\"\"Replace the three series columns with their yearly resampled arrays.\n",
    "\n",
    "    Assigns the columns on the frame it is given and returns that same frame.\n",
    "    \"\"\"\n",
    "    for column in [\"income\", \"lifeExpectancy\", \"population\"]:\n",
    "        data[column] = data[column].apply(extrap_interp)\n",
    "    return data\n",
    "\n",
    "\n",
    "data = clean_data(data)\n",
    "data = extrap_data(data)\n",
    "# Overall minimum and maximum of each quantity across all rows and all years.\n",
    "income_min, income_max = np.min(data[\"income\"].apply(np.min)), np.max(data[\"income\"].apply(np.max))\n",
    "life_exp_min, life_exp_max = np.min(data[\"lifeExpectancy\"].apply(np.min)), np.max(\n",
    "    data[\"lifeExpectancy\"].apply(np.max)\n",
    ")\n",
    "pop_min, pop_max = np.min(data[\"population\"].apply(np.min)), np.max(\n",
    "    data[\"population\"].apply(np.max)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define functions to provide the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def x(year):\n",
    "    \"\"\"Return the income of every row of ``data`` for ``year``.\n",
    "\n",
    "    Each row stores one value per year starting at 1800, so the year is turned\n",
    "    into the offset ``year - 1800``.\n",
    "    \"\"\"\n",
    "    return data[\"income\"].apply(lambda series: series[year - 1800])\n",
    "\n",
    "\n",
    "def y(x, year):\n",
    "    \"\"\"Return the life expectancy of every row of ``data`` for ``year`` (``x`` is unused).\"\"\"\n",
    "    return data[\"lifeExpectancy\"].apply(lambda series: series[year - 1800])\n",
    "\n",
    "\n",
    "def s(x, y, year):\n",
    "    \"\"\"Return marker sizes proportional to population in ``year``, scaled so ``pop_max`` maps to 6000.\"\"\"\n",
    "    pop = data[\"population\"].apply(lambda series: series[year - 1800])\n",
    "    return 6000 * pop.values / pop_max\n",
    "\n",
    "\n",
    "regions = data[\"region\"].unique().tolist()\n",
    "# Build the region -> color lookup once instead of rebuilding the palette list and\n",
    "# scanning ``regions`` for every row. The modulo cycles the palette so more regions\n",
    "# than Tableau colors no longer raises IndexError.\n",
    "palette = list(TABLEAU_COLORS)\n",
    "region_colors = {region: palette[i % len(palette)] for i, region in enumerate(regions)}\n",
    "c = data[\"region\"].map(region_colors).values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Marvel at data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 4.8))\n",
    "# ``year`` takes one value per year from 1800 through 2008, matching the\n",
    "# ``year - 1800`` indexing used by the data functions.\n",
    "controls = iplt.scatter(\n",
    "    x,\n",
    "    y,\n",
    "    year=np.arange(1800, 2009),\n",
    "    s=s,\n",
    "    c=c,\n",
    "    edgecolors=\"k\",\n",
    "    play_buttons=True,\n",
    "    play_button_pos=\"left\",\n",
    "    slider_formats=\"{:d}\",\n",
    ")\n",
    "\n",
    "# Axis limits: log-scaled income on x, life expectancy from 0 to 100 on y.\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlim((200, income_max * 1.05))\n",
    "ax.set_ylim((0, 100))\n",
    "\n",
    "# Axis labels share one font size; the final assignment to ``_`` hides the cell output.\n",
    "fs = 15\n",
    "ax.set_xlabel(\"Income\", fontsize=fs)\n",
    "_ = ax.set_ylabel(\"Life Expectancy\", fontsize=fs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](../_static/images/scatter6.png)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
