{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ICNN Dual Solver"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we explore how to learn a solution of the Kantorovich dual problem by parameterizing the two dual potentials $f$ and $g$ with two input convex neural networks ({class}`~ott.solvers.nn.icnn.ICNN`) {cite}`amos:17`, a method developed by {cite}`makkuva:20`. For more insights on the approach itself, we refer the reader to the original publication.\n",
"\n",
"Given dataloaders containing samples of the *source* and the *target* distribution, `OTT`'s {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` finds the pair of optimal potentials $f$ and $g$ that solves the corresponding dual of the optimal transport problem. Once a solution has been found, these neural {class}`~ott.problems.linear.potentials.DualPotentials` can be used to transport unseen source samples to the target distribution (or vice versa), or to compute the corresponding distance between new source and target distributions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
"    !pip install -q git+https://github.com/ott-jax/ott@main"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"from torch.utils.data import DataLoader, IterableDataset\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.solvers.nn import icnn, neuraldual\n",
"from ott.tools import sinkhorn_divergence"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us define some helper functions which we use for the subsequent analysis."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def plot_ot_map(neural_dual, source, target, inverse=False):\n",
"    \"\"\"Plot data and learned optimal transport map.\"\"\"\n",
"\n",
"    def draw_arrows(a, b):\n",
"        plt.arrow(\n",
"            a[0], a[1], b[0] - a[0], b[1] - a[1], color=[0.5, 0.5, 1], alpha=0.3\n",
"        )\n",
"\n",
"    grad_state_s = neural_dual.transport(source, forward=not inverse)\n",
"\n",
"    fig = plt.figure()\n",
"    ax = fig.add_subplot(111)\n",
"\n",
"    if not inverse:\n",
"        ax.scatter(\n",
"            target[:, 0],\n",
"            target[:, 1],\n",
"            color=\"#A7BED3\",\n",
"            alpha=0.5,\n",
"            label=r\"$target$\",\n",
"        )\n",
"        ax.scatter(\n",
"            source[:, 0],\n",
"            source[:, 1],\n",
"            color=\"#1A254B\",\n",
"            alpha=0.5,\n",
"            label=r\"$source$\",\n",
"        )\n",
"        ax.scatter(\n",
"            grad_state_s[:, 0],\n",
"            grad_state_s[:, 1],\n",
"            color=\"#F2545B\",\n",
"            alpha=0.5,\n",
"            label=r\"$\\nabla g(source)$\",\n",
"        )\n",
"    else:\n",
"        ax.scatter(\n",
"            target[:, 0],\n",
"            target[:, 1],\n",
"            color=\"#A7BED3\",\n",
"            alpha=0.5,\n",
"            label=r\"$source$\",\n",
"        )\n",
"        ax.scatter(\n",
"            source[:, 0],\n",
"            source[:, 1],\n",
"            color=\"#1A254B\",\n",
"            alpha=0.5,\n",
"            label=r\"$target$\",\n",
"        )\n",
"        ax.scatter(\n",
"            grad_state_s[:, 0],\n",
"            grad_state_s[:, 1],\n",
"            color=\"#F2545B\",\n",
"            alpha=0.5,\n",
"            label=r\"$\\nabla f(target)$\",\n",
"        )\n",
"\n",
"    plt.legend()\n",
"\n",
"    for i in range(source.shape[0]):\n",
"        draw_arrows(source[i, :], grad_state_s[i, :])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_optimizer(optimizer, lr, b1, b2, eps):\n",
"    \"\"\"Return an optax optimizer built from the given hyperparameters.\"\"\"\n",
"\n",
"    if optimizer == \"Adam\":\n",
"        optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2, eps=eps)\n",
"    elif optimizer == \"SGD\":\n",
"        # b1, b2 and eps are Adam-specific and are ignored for plain SGD\n",
"        optimizer = optax.sgd(learning_rate=lr, momentum=None, nesterov=False)\n",
"    else:\n",
"        raise NotImplementedError(f\"Optimizer {optimizer} not supported yet!\")\n",
"\n",
"    return optimizer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def sinkhorn_loss(x, y, epsilon=0.1):\n",
"    \"\"\"Computes the Sinkhorn divergence between (x, a) and (y, b), with uniform weights a and b.\"\"\"\n",
"    a = jnp.ones(len(x)) / len(x)\n",
"    b = jnp.ones(len(y)) / len(y)\n",
"\n",
"    sdiv = sinkhorn_divergence.sinkhorn_divergence(\n",
"        pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
"    )\n",
"    return sdiv.divergence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Training and Validation Datasets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We apply the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` to compute the transport between toy datasets. In this tutorial, the user can choose between the datasets `simple` (data clustered around a single center), `circle` (two-dimensional Gaussians arranged on a circle), `square_five` (two-dimensional Gaussians at the corners of a square with one additional Gaussian in the center), and `square_four` (two-dimensional Gaussians at the corners of a square rotated by 45 degrees, i.e., centered on the coordinate axes)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class ToyDataset(IterableDataset):\n",
"    def __init__(self, name):\n",
"        self.name = name\n",
"\n",
"    def __iter__(self):\n",
"        return self.create_sample_generators()\n",
"\n",
"    def create_sample_generators(self, scale=5.0, variance=0.5):\n",
"        # given name of dataset, select centers\n",
"        if self.name == \"simple\":\n",
"            # a single center, stored with shape (1, 2) like the other datasets\n",
"            centers = np.array([[0, 0]])\n",
"\n",
"        elif self.name == \"circle\":\n",
"            centers = np.array(\n",
"                [\n",
"                    (1, 0),\n",
"                    (-1, 0),\n",
"                    (0, 1),\n",
"                    (0, -1),\n",
"                    (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
"                    (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
"                    (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
"                    (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
"                ]\n",
"            )\n",
"\n",
"        elif self.name == \"square_five\":\n",
"            centers = np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]])\n",
"\n",
"        elif self.name == \"square_four\":\n",
"            centers = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])\n",
"\n",
"        else:\n",
"            raise NotImplementedError(f\"Dataset {self.name} not supported yet!\")\n",
"\n",
"        # create generator which randomly picks center and adds noise\n",
"        centers = scale * centers\n",
"        while True:\n",
"            center = centers[np.random.choice(len(centers))]\n",
"            # note: the noise is scaled by `variance**2`\n",
"            point = center + variance**2 * np.random.randn(2)\n",
"\n",
"            yield point\n",
"\n",
"\n",
"def load_toy_data(\n",
"    name_source: str,\n",
"    name_target: str,\n",
"    batch_size: int = 1024,\n",
"    valid_batch_size: int = 1024,\n",
"):\n",
"    # training and validation iterators for source and target, in that order\n",
"    dataloaders = (\n",
"        iter(DataLoader(ToyDataset(name_source), batch_size=batch_size)),\n",
"        iter(DataLoader(ToyDataset(name_target), batch_size=batch_size)),\n",
"        iter(DataLoader(ToyDataset(name_source), batch_size=valid_batch_size)),\n",
"        iter(DataLoader(ToyDataset(name_target), batch_size=valid_batch_size)),\n",
"    )\n",
"    input_dim = 2\n",
"    return dataloaders, input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solve Neural Dual"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to solve the neural dual, we need to define our dataloaders. The only requirement is that the source and target training and validation datasets are provided as *iterators*."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"(dataloader_source, dataloader_target, _, _), input_dim = load_toy_data(\n",
"    \"square_five\", \"square_four\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we define the architectures parameterizing the dual potentials $f$ and $g$. These need to be parameterized by {class}`~ott.solvers.nn.icnn.ICNN`s. You can adapt the size of the ICNNs by passing a sequence containing hidden layer sizes. While ICNNs by default contain partially positive weights, we can run the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` using approximations to this positivity constraint (via weight clipping and a weight penalization). For this, set `positive_weights` to `True` in both the ICNN architecture and {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` configuration. For more details on how to customize the {class}`~ott.solvers.nn.icnn.ICNN` architectures, we refer you to the documentation."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# initialize models\n",
"neural_f = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n",
"neural_g = icnn.ICNN(dim_hidden=[64, 64, 64, 64], dim_data=2)\n",
"\n",
"# initialize optimizers\n",
"optimizer_f = get_optimizer(\"Adam\", lr=0.001, b1=0.5, b2=0.9, eps=1e-8)\n",
"optimizer_g = get_optimizer(\"Adam\", lr=0.001, b1=0.5, b2=0.9, eps=1e-8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then initialize the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` by passing the two {class}`~ott.solvers.nn.icnn.ICNN` models parameterizing $f$ and $g$, as well as by specifying the input dimension of the data and the number of training iterations to execute. Once the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver` is initialized, we can obtain the neural {class}`~ott.problems.linear.potentials.DualPotentials` by passing the corresponding dataloaders to it. Since our training and validation datasets do not differ here, we pass (`dataloader_source`, `dataloader_target`) for both training and validation steps. For more details on how to configure the {class}`~ott.solvers.nn.neuraldual.NeuralDualSolver`, we refer you to the documentation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Execution of the following cell might take up to 15 minutes per 5000 iterations (depending on your system and the number of training iterations)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31b3dff5bd2840b0b358f91fdb2b117b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/15000 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"neural_dual_solver = neuraldual.NeuralDualSolver(\n",
"    input_dim,\n",
"    neural_f,\n",
"    neural_g,\n",
"    optimizer_f,\n",
"    optimizer_g,\n",
"    num_train_iters=15000,\n",
")\n",
"neural_dual = neural_dual_solver(\n",
"    dataloader_source, dataloader_target, dataloader_source, dataloader_target\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate Neural Dual"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training has completed successfully, we can evaluate the neural {class}`~ott.problems.linear.potentials.DualPotentials` on unseen incoming data. We first sample a new batch from the source and target distributions."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"data_source = next(dataloader_source).numpy()\n",
"data_target = next(dataloader_target).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can plot the corresponding transport from source to target using the gradient of the learned potential $g$, i.e., $\\nabla g(\\text{source})$, or from target to source via the gradient of the learned potential $f$, i.e., $\\nabla f(\\text{target})$."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAACdxElEQVR4nOy9d5xkZ3Xm/31vqlydp3ty1gTNSEIRiSyQSLIAE9cGjGXCLvZ6bbzG3vXau/55scE2ZoONDTbYeGEXGzAgkpBMBqGMpJFGk3Pq6Vy56ob398fbb99b1dVppmc6TD2fT3VV37p169YNz3vec55zjpBS0kILLbTQwtKFsdA70EILLbTQwsWhReQttNBCC0scLSJvoYUWWljiaBF5Cy200MISR4vIW2ihhRaWOKyF+NLu7m65YcOGhfjqFlpooYUli8cff3xQStnTuHxBiHzDhg089thjC/HVLbTQQgtLFkKI482Wt1wrLbTQQgtLHC0ib6GFFlpY4pgXIhdCtAshviiE2CeEeE4Icet8bLeFFlpooYWZMV8+8v8J3CelfJMQwgGS87TdFlpoYZHAdV1OnTpFpVJZ6F1Z9ojH46xZswbbtme1/kUTuRCiDXgx8C4AKWUNqF3sdltooYXFhVOnTpHJZNiwYQNCiIXenWULKSVDQ0OcOnWKjRs3zuoz82GRbwQGgL8XQlwLPA78ByllMbqSEOK9wHsB1q1bNw9f20IL84ORfJlTgzlKFZdk3GZNd5aOTGKhd2vRoVKptEj8MkAIQVdXFwMDA7P+zHwQuQVcD/x7KeXDQoj/Cfwu8PvRlaSUnwQ+CXDjjTcuyZKLe/cd5Vv3P8jpMwOsXtXDq++8jZ3bZzditrA4MZIvs//kILZlkohZ1Fyf/ScH2ba2u0XmTdAi8cuDuR7n+SDyU8ApKeXD4/9/EUXkSwazIei9+47yiU9/mbZsipV9XYzlCnzi01/mffe8oUXmSxinBnPYloljmwATz6cGcy0iX0Yol6uM5Qq4rodtW7Rl0yQSsUW73bnioolcSnlOCHFSCLFNSrkfeDmw9+J3bf7RjLCBWRH0t+5/kLZsilRiE75nk4wDQYV7v36cdHIj8TjE45BIwCzjEy0sILQ75Xj/GJYBCAESbMskk3Ao+cFC72IL84RyucrA0AimYWLbFr4fcO78ELZlIaW8YAJutt2BoRF6ujouO5nPl2rl3wOfG1esHAF+eZ62O2+YyqJOxGO0ZVNkM1kCv41Mqh3kGN/89qPs2LYRPcM5fWaAlX1dCDGIlDGkjJFMWgwPj1EqQak0/fdrotdk7zjQmqUuDKLuFMsQFCs1hBAkYzZ+EDAwVmJFe0t4tVQRBCBl+Dw6VsU0EhiGAbhIKalWXTzPJ5NOzoqApazfZhDUb1fKANP0ABjLFZYmkUspnwRunI9tXSp86/4H8TyPI0dquC4kEnE62js5dOg0N92wgcAX+F47CJ94LM35/hz794efX9F9HflchWQyDkgM4VEqj9HTY7N2LcRiYI0fTc+DSgXK5frnmVRbjhMSfTyutmm0UrbmHY3uFCEECKi6HomYmk5JWkHQxQRNpFLCiROnePDBn/DGN76V2TQ4870A0zKAABDUagG2FQMESBPTNACTsbEKQsSQUkktDx46zMMPPcgbfv7NpJLxSVJAvV1BAMIHwDQNXNeb758/Ixak1spC4Jm9Rzhx+hyWmQE/w9CwwdlzNrZtUSyaJJNxhKFUk+Vylc7ONqRUpOx5sGPb1fzkp88CCWK2TalWoVyFW264jqNHwTSntrBtG5JJRcyxmCJsy4JarZ7sazX1yOWm/h2WVU/28bj67hZmj1LFJRGzqFQ9yjWPIFA3eCAgbQja2hJUql4rCHoZ0MzS1cumwr/+63fYt28vP//zb51Ypu89/bko8Zumo3zYlo2UjFvmagCX0gAkhiHxvIBazSWXL1Iql/nud77N0aOHed3r72Y0V5xksVu2j+
+74wOBgu8H2Pblp9Urhshz+SLVqstI6Tzt2RjgU8gVcb0KJ8/USCZtzp7rZ3h4FMu2eOPP3cHXv/1FzvXn6O3p5fk3P49X3bmdRx7bx/DwOTo723jRC3ezdnVYiMz3Fem7rnroQQAUAUcftt2c/IUIrXFtACgLIST7QkE9poIQiuijZN/y2yuM5MsMjBUoVTyC8ZveMtQx82VApebjB1Vcz8d2TWquhy/BMgxsy+Dg6SFu3r5mYX/EEoCUMJwLZzQJx6avI0tbavpBcCYL+yc/+TG/8zsfoK2tnQce+DZvetMv8LWv/QuVSpl0OsPnPvdlenp6EALe8Y4309HRyZ49T/LSl76C1951N7/7O7/J6OgId931er7whf/Hd773Y873n+UP/9sfcPbcWQwh+MiffYzh4WH++I//G5lMlh/84Lv87d/9Pc74TaSDm0IIXM/DwcY0DXw/wA98Ojuy83UYZ40rhsiz2RQHD58k5mRIJ3sIAhNkCc/3OX5cYhg1HKeXdLKPmuvx/77wDKlUkrZsihMnTvD9H+2nLZuloz3J+nU93Pb8q7hqS98EKWu3ShAo0tXWda2m/m9E1NpvRvyGMTvyt6zQ/SKlGkyCgFn77TXZ6yDtcvTbH+8f5dDpIQplF9fz8YJ6tvACUM4UqNRcfN+k6vk4fkAgJUEgqQYehhDkilWqNQ/Htq5Id0sz6zloEhceK5Y5dHYQ2zRxLItqzefgmUE29XbTlkogRL0VHX00g17/ttteyA033MSf/Mmfc/XVuxgZGeJ3f/eDGAb80R/9Id/85j/za7/2qwA899we3vzmt/C3f/swhUKJF77oVj784Y9x1bad/Nff/x2uumoH1arktz7wAf77hz7K5s2b+c537ueTf/M3fPgj/5Pd1zyP3/7g77Jly1ZMw2R0rEC5UsUwlPvE93wkEkMIgsDAti06O7JLU7WyGNFMnbJrxyYeeexZzvafolKxSKdWAGqaJQMAQaXq0dmxlhgmcQcQAsdahemYpJMmpgHSNxgeivGNb57l1DUJVvZ2TJCs46hnw5jet22aIfkLUT+9DILQso9CL6tU6kk/ehNNR/768/r7PA+qVUX20+2r49STfSy2+Mk+6tsulqsM5srqGM/is64viTsGwvco1/y694JxlhkYK9GTTTJWrHDs3Ci9HSm2ru5a0oTuuuraqlbD50YDRK8zFfFGyfnMUA7LMLEMc9y9YWJJODuSI5MIiVw/otegYYTLouvoe+TQof3s3Lkdw4B//Md/4Itf/CdqtSr9/ef4b//tjymVoFyuMDQ0zAc+8AcUi/DlL3+da3Zfz/Oe9wJ83+eqq3bR2XWef/3X+zl0aD+/9qv3AOD7PjfeeDOeX+XI4YOsXbsez/PxRUAQKNeLIUwCfCzDwDANKlWXdWt6F4TANZYdkWt1iuf5nD03yBNP7uNbD/yUVeNKFSlhePQYw6PHSCa66OnaSizeBYApTOJ2FyJmIoBYrI1kPIthWAjAMG1sO470BbGkw4F9kuHB0EqOXnwa+nX0Wa+nibbx0WxgME31iMUUKTfeTPoi1+vpAaFUmmztRxElf8MIt6W/wzAm71cjbLtekROPL1yQNqpIEcDAOInPBfny9BUm/EAymC+Rijs4tsHgaJGhXJlU3KEjE180VrqUakYYJedyebIF3Tgr1M9Rg0NfB1EfdiOJ6/+FgHLNJeFY4b0gVCCwUnNJJOrXb9ymvk4bl0sJg4ODpNNtVCoWn//8P/Lww4/wL//yXTKZNK997YvZvPlqqlV48smnue66myiWXAzD4+mnn2Tnzmuo1VyCIGDf/md50QtfyjPPPsWv/fsP8PNvfMv4MZOYpmBwsJ90Oq3ufWFiGFa9L7xWxJU+ppSY5sIoVaJYFkQetcCPnzxHMhHjbP8Q5XKVoeExSuUKz+0/BkAi3s6K7u2ApKdrG4lYO6lkF0IIhGFhmhYxO4UQhjqBwsAwbQRCrSMEIJGeTTUI6O8vYpoCy3
IwjOkPZzNLNmp1NFogjeSvL369nmnWWy+aaLV/XQ8E+rWWPEaJWt9M2vrSswJ9swtRr7ZpDEjph2WFwdzoAKUJ3TDqLft4PHRHXSyiFnihXMUQSlI4UqjMmcRnAz+QyHGXi5SScs0nGReUKlUqNfeSW+lS1pOzVkTp3+r79aSsHxr6/EQHW+0SjA722ijQhN5oOcPUQca4Y1NzfSzLhPHZkOsF2KZNPt/8NzVusxlOnDhGX98qhFCuk5tvvo10Os1Xv/olHnnkQbZt243rBjz99NPs2LEbgUXgSzKZDg4c2I9h2Dz33FPc+9Uv8P73/zaDg0M8/NCPeNu/eTeGYbB/37Ns2ryJ/rND9PatJpnoxDBMfL+G61Xw/RogME0HQxh4fgXTWBilShRLnsij+vDTZ/p59Im9+H5AJt1LMtGNbZn4/qmJ9TPpPjavfxGOnUYIgWXFiTlphDCJxVIIDEzTRkxU+BUTT0IIoteYZao3fF/i+xLwME0BTC8jmck10ez9RkKKWj/NPtc4E9A3YPRm1K8bBwZ98+rBIfp/lJj1oGIYobWnZydRy1777aPb0aTfSP6NQVrHmf5YRS1wkIzkK2z+wf1s3PN43bka7ejip//mPdNvbA4IJBQr7vhxEwgExYpLJqms9NFCZc4qF99Xx7GRpIOg3lLWxNxoWTfGUPTx1yQeHWD1eYV6oo9ayfm8+l+fz76+cFCPEncU+v/OZJbjg4PYAZiGgR8EuL7PyvaOCUOg8fNTbSuKDRu2Mzg4yK237uJP//QTfOAD9/DP//w5XvKSO1m/fhOOk8LzAvbt28t1192ERKlS3vjGd/BLv3Q3d95xA7fe+hLWrFnPhvVbeOvb3sVDD/2IV7/yVmLxOFu3XMWfffSv2XrVTkZGhnnNq2/jD//wI+y6ZjeWFSPmKCs9CHz8wEV6gBALolSJQshLYbLMgBtvvFHOV6u3j/6vzzGWK7D/4Akee+IQlaqq1ZWId9DbvZ1A+gyNHKFSzdHduRmBQSyWobdnB4awWdW3i1SiGwBDWONWtzGuLTaYzKkzO4jVDXLlaQKjA0R0RtE469BTdT1o6EFEiJBg9HvRWYYm90wG2tognYYj/ecJ8AhkwFCuTN99X2dTA4mDsgjzmSw/fsf75+33ivE/2aRD1VX+9EwihpQSLwjoaUvh2CY71vZOcm+Uy5OJuZnrC+oD2qBItlqtP7Z6YNUDp+uGcZFoAL1xkI3GZvRrHXvxvHpif/7zn2P9+h2T9k+TfSPy5TKDhRxV1yVm2XRnsqTjiUmfvXhIQNYNOnq5fl0sFkil0gD89d/8OfncGL/9238IaBmiJJDjB0RNuvEDD8+v4rrl8dm2JBFvRwiDaq2A65bwvCqWZbJ6VQ/tbZn5+DETeO6559ixo/54CyEel1JOytlZ8ha5zrg8fCRHb/fV5IrnMYTN0MhBAumzqvcaVvddy6Gj3ycea0PKAMMwcewkmVQfIHC9CpbpIAyhHsojDjQLkMnpyf3K4+8JNFMvXHp0jT+rm9mo3kNs8y+QdodZVdjLC0a/RSyoIIBMPkdm8Dz57hWTttJ76Dmuvf+rdaevmEjyw1/+9Sm/2fcFgWcyVAK3FsOUcfLSxvdABgYjVgzXDzh/ZHwP5WTXVVRp5I/HVhuVSZqYtQJKrx8lYL3tqd6byoLWiAYTp1ovut+NaLZ+KpYgFbsUxF23xfFtCqLErSAmjuPf/d3/5N57/wnLsrnxptv4r3/w5xiGyYQhK8AQ5sQN7wc1fL+KWysSi6k4WbWWRyCQMkAGPkGgRkghBP3nRwDmncxniyVP5KtX9TCWK1AojNHR3kN7Zg2b1r+QPc99mSDwiDlpkokudmx9FUMjxxgaPcpQ/2EsM0axOExP5xbiyXbS8S6kaWNgI0TUETg3iYaJuplaSTqXGwIQBFaaMmnKTjcDqat4qucufnP/r0ys0Xtk/yQibyRx14hTNlIUaGfLZ/+Vopnl6LYbybWtJAgEQS
AUYUiB9A0kEPjKADDHTWfLMsn5ABaVsXoXk7aUoyqlqWR8LUyPRuKeSpv0G7/xX/gP/+H36m5nOb6ulMF47EsgCfADj0JpACl9giAgaTZkdPquInMpsW0HxzEJ/ID+8yPEHKclP5wt9u47yj/+32/yxJP7KFdrBEFAuVqgAzWSGobNru1388iT/4hESQyz6VX4QUBXxyZGO7fhOAm2bLwd205iGSZCWA0EfmHwaZH4YkIsd7ju/2Qhx7YfPcC6fXswPQ/fsnAtC+Vhh7+66n/jmU0SOsaAaZKwpsPohX2shQtCM6eaWt60NKyQgCCQAZ5XpVIdw7FTJOLteH6NajWHadj4gav4QUAQeEgktp0g5sSRsoxhGviev2DqlSVH5Hv3HeUjH/tHjh47QyadxLIthodz+J6SjKWTK3DdEtnMShwrgWU6gCAWy9LethbPLZOIt9M/uI+xsVNk0n0QS2Nbygc2Vwu8DibjTH7xv/NKgRBh8E3XmslkIJtVfvCOjtAfrtU4Wn+vYDKSL/PwcycZLdZ49cc/PO0ZTA4NsGqwHymVi8xya1iuunYMoLPWz3A8hiea3IwXdG5bF8PCYrqrQSJlQM0tU63lcawkqUQXwjAIAmWNO04aIQwlSzRsGgcEz1MulmDcQs/lirRllS/+cpa3XXJE/q37H2RwaIxsJk0s5ihLukswMDhCPObQnl2N55Vw7CQ7t70W369h2wmk9MmkVpDLnyWV6CCd6mFk7ARH932JmJOhp3MrmXSvkiIaJta4flRZ6RExbAQ60Wbi//o/SxZTKVm0Xlw/RxUmyWT4v64no2vK6Melmql0ZBIkYw6jxRr9azbQe+pY09s3ALIjQ0gp8bHxBTiyPuvlF4/9dwA+tvUvwboYf+fSvgaWDqZzukevgtCNEgTqpg2kr4i8VkAGAbadmJiVG4aSFNsiroKhgY9p2hjCwLJiuF4FIQwC6VNzayrhyTIRhuDc+SEAHNu+bOVtlwyRa634v3zt+xQKJbZsvIFsqgdJQBDsJR53EEJg2zYSJQ/q7tzCgcPfIZtZjWnbxJw07W1rMA0L36/R3bmZ0dxpKpUxTp79GbncKeKJdqQM6OnazlUbX4Jh2CqZwbAwDQuJliAqn5omc3OeWSqq+mj2f5RoG8lWE2ejxC9KrlpHrOu6NMsI1a+XQiZnrqxkHE/c/Tauv/fz9J46VreOBH7ylnt4wRf+AQHcu/ZXiQdFXn36b5tu8zcP/lrD5wVV4VAzE1SNFFUzQdFO88QdbyHwTQLPUD5zaSIDwZZV3XXBRx2krNVCJUn0dTNt/nTSvMblS9G/3tTTMcWyZr9fBzmlVEXPGj6FsrglkgAZBBMb1/5wgcTzKth2UkmODROknAiC+ogJK9wwLSWIEMZEvoghTCVBlGCZNoZhkohblCvqWkzEE0AwkUh0Kd0uS4LIo1rx3p5OCoUSwyNFkvFeTCugWnPpXdFFNp3CMAwMw6E2Hm0OAg/HiuM4GRDg2Elk4OMGklgsQ0fbOirxMdKpHp7Zfx7PrdDVsZmu9vWh+N+wkQH4aPNbZTiYpomPj2kKHEdF56OJPVE5XpRsG7M6Nek2y+6MZl02yvKakW6UpBc7Ac8XTg3m6qa7T9z9tinX9S0L260RCJNjyd18bPuneNmpv+SawpMY01p3krisEveqSEbxRIyCl+HG+z7Ho3e8HQILr2pTq5kErsmRSj1JX4iSpBGNA3v0c/q6aMwv0JK86PUYfY4qZLRRoF83JqRp99dU+zzbZbP9jc2WNw5wytVhRL5LqiCmDML3MRCGkhhK6eP5LjWvjOuWiDlZYk5CBTyJ2PcTElrlSjENGyEixto44SvfORiGQyIew7J8ZKCtfwshXEBe8vK2S4LIdXeebLqTVX0J9h88zvDwOQRxujrSFAolNqxfyStvfwk//skQpaLA9ctknD62bXkF8VgWy4pBECAMBykCHCeJaVj0de+gVB0lmWhn3aqbOHTsexw9+S
Bj+dP0dG0hmezCNmP4gUfgewTSI5AubZkYfX0dVKslnn/zDtasTtRlwjWrt9JIuM1IuFV/fO4oVVy0A2wm3jixffe4zlxStZKAxffW/gadxz/MutJzMM02JPCpDX9EPtHQPPzQ5HUrMwRGG2dXjVm9elljElaz7MrpBoZmRDqd1ev7avu+X78vjZ+ZKsMz+no2JD5TRudUmcST1SoBQqhsWz1bHt8CgQzwvSrl6hi+XyMYd5P4gRcxAMat9MhOKGni+FXVsNyy4ghhEvjuhA6dcUNPGCJC+mpHL3V52yVB5ForPjic5/x5l94VneRzLq5bo39ghEBarOztZsO6ToqFND99+CxCSNxaCdtM4Po1LEslaggBBhYIiRAGsViaci2PYVh0dWxESp+aW2Zk7Bg/e/YLtGX6cOwUvl/D82sEQY2urjSZ7CqybWt45R03s2tn9xVj/S5GJOM2Y0WDmGNSaSh01Yj9L7pDEbmUdRkBA/F1E0ReTmf5/jvDxKFrvvsN1uzbgwRWVQ5zwspgShcrqGHJGoWVfQhDIgxwbIPOTBzHnkzOjYlSGtGZW+OyZhm80XUaSztElzUOEI2EqP+PSiOjiUGNj2aDSCOalZWYav8b19fP2iUVlWU2knd0gAuCAN8P1DpIpfGWHq6rUuolEtctI2UwnvlZIRbLUqvmMU3lOtGZ2xP7BRNlOdQMXI4rXEAGapQLpPoelcEApVKRWNzBskwMESMIgnHX66Uvb7skiFxrxQ8ePkrc2UpX5woEA1hmnI72doTh4DgW//yVb5CMbaK9rZvR0RLF0iDJZAfG+NRLGBZ6hBWGiedWsCwH36/iey4xR6Xox500mVQvq3uv4dipn9btS093O7/9H97OXa9+4eU/EC00xZruLCO5MpWqNyurfLSjC0GAkAFy/N4tGykY/+yxXdfXrV9KZyf0TK859w9170ngxPs/2tBFaPrbqpEoG5N45rLsckFK5eePxydb/lGSnUt5icZtNMs01e9FB6roQATjROm7BNKlVBrFD2oTpN7sYpiQGU88g1E3pxv3pRvaqhZ1Lng/8LBMB88rY1lWmFEaSFzXY/XKHlzXoFQuTqhWLnV52yVB5K++8zY+8ekvMzQ0yvq1Jqa5Dd/36WhPkk62M5YvK8mPLHPg4HGu3rYO24pTKll4fpVUshvGfV1S+oyPtepECoNEvB1DGMScDF5QJZvuYyR3glhMqRYcJ4VhmGTTJqtW9fDjnz7VIvJFhI5Mgt2bejl0eoij50ZnnNL/9N+8B+8+FzGREgIlO0MAVJwYOx76Pjse+j6gSP+ZO17Hpsd+MulmkYBnWXOudhgNUi80oqQ41WChl58/HxY6a0bms/0+/awToxpnA1FErXr9Gb1cW+W1WhUV9ATbSWD6Np5fw/ddEOPZnpER3hAmlukgLbUx3/exzMlUKMbtcuU/NyYs9ljMVoMOAZYZxzDUeqlMGtf1iMUUYfdm2i/bTH1JeGR3bt/I++55A11d7RQKZZLxdjra+kgnkwiRortL1SPIZGJI6eN6LrWaSzweI1c4jhhPw5fBOImPSwoty0EgSMSyCMPEtmI4VhLLikMAMSeNaRr0dm9h9/Zb2bhhFUjJDx/8GXv3HV3QY9JCPToyCW7avoa+jjRxe2qGjFsGMhAMrdlIYDlKJWjC4c03MtLRTaJWnfCwCqB9ZIhdD3yVn77lHjyYmEZLoGaYHPvDP2b/yUFG8uXL8CvnH9H6N7r2fCoV1rPp6ICuLlixQg08WvmklVDR8sWNj6PHj/Lxv/0cv/eH/4OP/+3nOHLs6EQpXK3a0codbV1Hi7ZFY01aEJDJwDe+8Rle85obePnLr+Huu1+IH1Q5ePAZfuHfvJa773oZv/zLb2V0ZBgh4J1vfytnz54mGU8wNjrGm974KgzD5Fd/9R38we//R37+DS/jr/7yI5w7d5b3vOdNvPKVN/CSl+zgZz97BCEMTpw4yj33vIHXvvZW7rrrVg4f3o
8QAsMwsS1bSQ6FQBhywg+ua9tcTnfrkrDIQZH5H/3++/i7f3iIdCJFpZqmVC6TTFTp7ekCzpEvlNi8cQ0116NSq5G2JJblIQwDmxjCMCLZm3I8QCFVAoDnYRgxujo2YtsJurrWUKtVScTS9HQn6ersZTR/BoSgsz3Lt+5/kJ3bNy7gEWmhGTqyCUzT4OxwDi/iLhdAImaTilv0tSc4tbf+c4FvEsvVEIAL2JHPtY8Mke9ewY9+4/eoegG2ZSADSVsqzorxQePUYG5R1CBfLNBKs2w2xYqeLgYHC/zlJ77M29/6BjZv3DgpyNksbqDlsdG+tPl8nr/4i4/w5JNP4jgOo6OjDA3ned+/fTsf+dOPsWPHTv7ub/+Gz3zmb/nAb32QM2dOs3nzZjLpNA/99Cjbt189Xq52L695zd18+MN/hcTkta95Pr/zO/+dV7ziLsrlEr7v4Xkev/3b7+HDH/44mzZt5TvfuY+/+qs/5aN/8dfKlQLjChWJbUn8IKCzIzsRML6cWBIWucbO7Rv5xbe8gEQSkskYNXeMrk6HeDzG2FiFsVyRl73kBrZt2UjcdvA8l3g8DkJJhYLAH4/OhzU7lVrUQAYeAkki0YbjOKzsXcemDStJp7swLQ/H7sL341SrNXZu38DpMwMLezBaaIo13VlMQ5CMOXRl43Rm4rSlYmxa2cGGvjbSiRhXb1hFMlZvw8hAULTbAfjkVX/Jl9b9Vt37AvD8cYlaoGpzZJNqCm1bBqVKk35+VyikhK9/60GS8RQxJ02tZhCLp0kkUjzwvQfr/OnRh2kq0u7ogO5uNRNIpepdUKZpUi6X+a3f+i0ee+wx2tvbefAnP+Dmm25m165dBEHA5s1bGBkZ5ty5M6xZuxbHsQHJvv3PcfXVu/A8l7GxYX71134bYRjc/+172bp1O694xV0AJBJJ0ukM3/72Vzlw4Fne9763cMcd1/OhD32QRCJBzHHUzMEQ47MFC8ex6OnqIB4fvyYuc4/cJWORa1z/vHWkkuvwfcjlT/H9Hx0mN1aio7OLt735dgz6KORg+/ZNnDyZJ5PtwqSG7yfx/CoSG5XjJzBsE3M8yBGLpZFBQDyWJOYECKMLLwjYtH4r1doZajWXVCLL9qu6cBx7Ig23hcWFjkyCbWu7eeLgGSo1n7hj0ZGOEY9Z1FyfZNzGdSGbiDMYiWBJ36RotQMn8A2HvN1dt907PvnnGJ4HwiDX08v517+RWKwTUA0TkvEru7u1lPU9ao8cG2BFT1ddad50Ksm5/oE6OaVlKTeM48zOik0mkzzzzDN87Wtf473vfS/vfve76e/v58Ybb6Qtm6ZcqXLkyCG2bdvO4UMHufbaaxFC+cGffuoJfuEX38XRo4e5/vpbyKST+IHg2Wef4vrrnz/pu/Y+9zQf/OAf8Qu/cA86U1c3Z8mkTSwrNiHbTIxPxqIlgy8nlpRFDuEUrFSCdWvX8Iafewkvf+nzefMbXs/O7RsnOplk0hm6uzuwrSSe72GaBslEBsuMEYul1Sjvgz9+MxsiRjAevIAKXuBSrVX4uVe/go0bV7B6VQ+7du7CcWzGckVefedtC3YMWpgeHZkE129dRVc2QXs6Tswxqbk+ruezpjuL5ykrKmrpOZZBpb17Ivhp+ar+iv7fHL9DhQxoO3+W9f/yT1hnTtdt90qDlKouei4Hw8MwMgKFglrWt6KHQkF1/xaoe7JUKrF6VQ+plLK4e3qU9T2X1oAHDx4klUrxtre9jbvuuotKpcLq1as5dOgAfb1dSL/CN79xL7/5G/8eGbj0ruihp6uDw4cP8q//ej9XX72b48cOs3PnbvxAqVpWrOhl//5nJ75jaEjNtvt6V/GDHzyAlqzs27cHxkvl6mtHyvp9b9Zo/XJgyRG5RqmkLhpQB29IlTeYqP1hmhBzEmy/6ip6V/RgWTZgRjLWTExTjJedVXrPRNzBNC1qnoFpxLnphu1cvWMn77vnDcQch9GRgLZsmvfd84aWf3yRQ1
vmjm1Srno4tjnRrada1ZK28UgnJo4Vp/1992CsWwuAJat1yrVRp5tn2l+MYanmI87wEM6TT9Vt90qAlKpBxtgYDA4qAi8UlCUeVZ28/CW3USgWKRQLGKaqZ1Jzi7zxdbeRTl+4YudDH/oQ27Zt4/rrr+fo0aO8//3v5x3veAdnzpxh9+7dvO1tb+PTn/40XV1dvPKVr+S+++7j3e++h+999wG6urrYumUz+/btZdu2XZimgWkavOUt72Jw8Dy3376LO+98Hk888RBCGLz1rb9CEAS8+MU7uOOO6/irv/oIhiEmlRxubFl4ud0qwNLsEHT8OJw+DZ1qZsuRI6pa3gtfCKOj8PTTcO4cDAyoyPvYWGgpGIaaHjW2rAJV+KlcVpX2sllYs0a9fuEL4cc/Vifs+ZNnYC0sMRw4AI88oq4jPRVua4NXvQq2bYP//iGfju4qt94+ypbf/12Qkh+seBtHM9dyz7H/Ms5YAdZLXkziV9+7sD/mMsB11T1WKDzHmjU7ppQJ6uQjy1L30uGjR7nvAdVLd/WqHl59520LagCVx4VFUsLISBVh+CAtlIdZTLyn0vLrtfFaK+446vQnEuFgFI+HiUzVavj/xWJBOgQJZd48BpyWUt41X9tthrY2ZQlEoVtfJRLqIoJw6hfuY5h63Gz8qlTCEwJqULBtdbNHyb+FpQ3dST56s+nCVuUyCEw6szaObRLYNka1SsFqx9C1dqQEYWD0dDX/gmWAahUOH4Znn4WTJ6FYhFe8or7aZ7Rwm5YgRi3tq3ds5Oodi3PmqnJKtFZcQTI9Aev3Gstv6OULITvUmM9g538AngMuubMwk1EXjh4ZU6nwIGqNq77IogkEtq0u0FgsbBQcDcYEgfqc5+lEA7W8UFDf16z7dwtLD7Xa5IFcNz4eGVH/p9M2uzf2UnzFSwm+eT+umVB+c99Td/zKPsxbbrrs+34pUa3Cc8/Bnj1w5kzoLtGIJuE4jrrvlkJ1zGawLJNazQUjkisUJnXWoTFrNfqbo799ITuDzQuRCyHWAK8FPgR8YD62OR20n7tUUtZ3Oq2sZxXECn1WUetaByU0uev1GhveGkZI6HrULRSgvV25avQUq4WlC50KHkUQqOtJE7mu8pd61zsoAtVnE1jSVTd7RzvO296Etb6heNYSRLmsiPvJJ9X1Hb0fhAh13J2d6l7r7V0e179pmiQSMarVsGnzRDftKaDLA0TJujHTtdFffrkwX1/7P4APAlNW4hdCvBd4L8C6dfNzAxSLimDj8fD/trb6AkHR1F9tUXieej1V1/IgUBd4Mqlej41BX596r1ZTFn0LSxfRTEIN7d/Us65owCr2spcQnHBI2jXsG18PpRLBY0/grVq5JMm8WITHH4ef/UwpTqLHwjBC42jHDrjhBmV5g7LWlwOJh2V7rabGXOO6jbVhov9rUtcz/4WqXnrRRC6EuAs4L6V8XAjx0qnWk1J+EvgkqGDnxX5vLBZ2F9cukpERReThd6rnqFTIstSzbYe1G6IXsnarlMtqnVotJHVQr1tEvrTReM6hPm0cQoscwHvoUXzjpWB6FLwEmZRiAv/hR5cMkefz8MMfKp93uaGagGUp4l63Dm65BVatWph9vByY7Yx6qpovEwWyIoW8YOFkhxrzYZG/ALhbCPEaIA5khRCflVK+fR62PSXa2kJLAdQNODwMGzaEU6DoczRQAepmzmTqb16NqLVeKCgFi7b6SyU1C2hh6cLzJt/Qmsh1cDxK5Pn+EoEwQRgU/RgZuwLJJMHA0OXd8TliZATuu08FLaNBSlDGSE8P3Hgj7Nq1OAp4LQQupPGF5pXGAWEhZIcaF03kUsr/BPwngHGL/D9eahIHRa7a9aH92iWVf0AqpQ6qttQbL9J0WrlL9AlphG7f5vtqm9Hg2NjY8rZYrgQ0s8ilrJ9iR2/KfHY1wZjEMTxS5jjTl0qLUrVy9izce6+S3zYik4Hdu+G22+qNoCsJUQKeC4lHywjo19ofrq
+lhRwMl1yKvoYOVpZKSrliGPUSRJ0UpP3h0cBnKqUI2baV5dU4LdK+dS1VBOVXNIwwCamFpQttkTeiVgsDodGgVWnjbjgeEA9KJMwKslhEFopYt7/0su1z/kN/Ck8+Xb9w3VoyH/0T9u2Dr30tNGQ0DEPNUF/6Uli79nLt6eJFs+JcM6FRSx71gWvi1vyxkPGDeSVyKeX3ge/P5zZngibyTCacPuqO7tMFHkxT3bjpdFgfQsN11Ump1dS2pFQE7jj1uvQWlh6iDQwatcC6eQLUE3kh1gMpl0TsPMbQEKKnC+v2l142/3gzEv9J26t4NPVGgj+o17wlEvCSlyiXyZXqLmmGcrnMnXe+im9+87tYljowc7XItTXfqGoLgrkf61qtxite8Qq++93vYs2D1GXJWuSgbrZqVVlYmYySIPq+8v8lk2Ez2kb/oJTKNaNPwFR9DXWt5FJJ+d8dR0X5W1i6iLYQa7SgdFKQDoZr5HIgLZvkdTtIvKg+0+6y4Kk9AHiY/K/tfw0iGm332bjR5PWvV9f0csfVV1/N8PAwqYhvqL+/n1//9V/nQx/60JSf+/SnP83rX//zEySucSE+8mYdjoTw0YW1ZgPHcXj5y1/OP/3TP/GLv/iLc9uJJliytVZABR1jMUXoOkGoXFYHO5mslyFG6yPoqbMuYj+VCkXf8MWiIvMr4UZZ7oh2pGlEdBodJfJaTS1bMLXS+I7pPu+2X+Llx/83H3jul/jAc/fwzncuzmvTO36C6j9/ifJffZLqP38J7/iJi97mu9/9bt7xjndw6NAhDh06xMGDB+nr6+NXfuVXeO6553jxi1/MNddcw5/92Z+xZcuWic997nOf4+67XwfAZz/7GW6//QZe9CLVmELj0KF9vPnNt3PHHdfx1re+guHhQQB+7udu5fhx1Ujm3LnTvOhFN2Ca8OY3v5n3vvd9vOQlz+cjH/kTzpw5wxvf+Eae97znsX37dh555BEAjh49yute9zpuvPFGbr75Zvbv3w/A61//ej73uc9d9DGBJW6Rt7WFskDtKhkZUe4SnX3VWOAmOpq2t6u0+3Q6tOyj6/m+el/LEdeuhVOnWklBSxmuGxJ5o+step1Ep8quq9ZNL1Tl4vEL2cHjA/vet0A7MTd4x0/g3vsNRDqF6O5EFou4934D7n7tRbmk3vnOd3L99dfzx3/8x1iWxfe//302bNjAunXruPnmm/nUpz7F8573PP7dv/t37Nq1C1BujCNHjrBx4waGh1Vjiu9+VzWmGBwcBaBarfKe97yR//2/P8euXapA1ic/+TE++ME/4vTp46xfvwHThGeeeZpdu67BNGHPnj284Q1v4Yc/fAjL8rjhhhv40Ic+xF133UWpVML3fVzX5d3vfjef/OQn2bx5M9/85jf58Ic/zN///d+za9cuHn300fk43EvbIrdt5ROsVMJpsa6CGA1wRi3xKLGvX1/v84pCr+e64bMeNFp+8qWLaCykcTDWbheYXJrUMOpzFC4rrt099XvrFmcU03/4UUXiqZQqQpVKIdIp/Icvjri6urq49dZb+frXvw7Apz71Kd797nfzL//yL1x77bU873nPA2Dnzp1ce+21AAwODtLe3j6uOlGNKf7gD36Ln/3sMdra2gG4776vcNNNL2TXrusA2Lp1J0ND5zl27DBr127EMMQ4eT/Nrl27cd0Kw8PD/Of//AdYFnzlK19hx44d3HWXKjOVTCbJZDJ85Stf4dlnn+WNb3wj1113HR/84AdVsxvUvjiOQ34ean8saSKHMFW/UlEWdVRV0tiFuxG2HQa1mk2bdbH8xs82qgNaWDqYjsiDIDzf0dyCalVZ6Avlvsj83gfhumsmvzGuWlmMCAaGQstHY5609+95z3v41Kc+xdjYGD/84Q95wxvewNNPP8111103sc4zzzwz8X8ikaBSqYy7XJM89tgz3HzzC/jAB97L3//9xwE4eHAvO3aEA+a+fXvYunUn+/btYceO3RMc8sQTj3
H11bt59tlnufnmWybq2j/55JM8v0lp1KeeeooPfehDPPnkkzz55JM888wz/PVf//XE+9VqdYLYLwZLnsgNQ5G4DlBpa1kX9Gn0jUefQelqHSdM7W+E9rFXq2FGXKt41tKFDmbqmhmNRY/09aOfdcBbF4lqxKXwAzdD5vc+SOYLn61/LFISB5TGvtHimSft/e23386BAwf46Ec/ypvf/GYcx6Grq4sDBw4AilQ/+9nPTljkHR0d+L5PtVrh0CHVmOLnf/5t3HnnXVSrqqRpX99qDhxQjVyPHz/Cl770f3jTm97J6OjwhNV+4MBzfPvb3+Caa65hz549XH21GlyFgL6+Pp59NmxOMTCgmlOsXLmSb3/72wTjJLRnzx506fChoSG6u7ux5yGTaMkTeXt7OA3ORCq9JBJMdAuK6sGjtZSFgK1b1etarbmESOvJR0fDgkpjY5fil7RwOaCJHOrdJzohSEtPo4O2LgPR6H7TfmBZLNb5gS8VmS8lmLfchCwUleZeygnt/XxUjBRC8K53vYsPfehDvPvd7wbgHe94B4899hi7d+/mU5/6FBs2bGDTpk0Tn7nzzjv5yU9+zJ/+6Ye49tptvOxl13PixFF+6ZfeD8Ab3/gO+vvP8PKX7+b9738bH/3op+ns7OIlL3kl3/veffzbf/uLfPWrX6Czs4tVq3rZs2fPhK8c4F3vehf9/f1cffXVXHfddfz0pz8F4J577iEIAnbs2MF1113HRz7yEcQ4GX3ve9/jta997UUfD1jiwU4IA55atTI2pp71MsMI1SeNdRIg7BmoLfrGAjo6Y9Q01U1tmkrF0sLShM7S1anWUejZW3R5LlcfH4lC+4Gr8XbOl9pYlxqcWL5UarBcKljr18Hdr8V/+FGCgSGMedbe//qv/zo/93M/N9F4IR6P8/DDDwPwZ3/2Z7zhDW+oW/9Xf/VX+djHPsYnP/l/kFIN1DpvQPXcTPDpT39l0vesXr2W7373SQxDzcj+83/+A0wT/vzPP0qlErpm0+k0995776TPJxIJvvjFLzb9Df/3//5fPvzhD1/EUQix5Ilca8Y9TxHy0JA6SYlEWOi+kZyjFnkQqCpvTz+t1q9WJ6tcdA3zUklNsRtrs7SwdBAd0BtdaVqpFHW55PNheeRGBANDiO5OzhXbcYPxFZZADZbLBWv9uks2oKVSqQlVCsDHPvYxPv/5z2PbNi94wQv4i7/4i7r1r7/+el72spfh+z6GMbfsHZ2aH702dKzlQqsd1mo1Xv/613PVVVdd2AYasOSJHJQlPTKiXCu1mrLKUyn1mM7q0sWxdM0WXfmwWR0OXf62rS1UxrSwNKFr7EzVFEFfM1Iqi9z3wy7pdev1dCl3SmCRscd9MYu0Bstyx+///u/z+7//+9Ouc88990x0h5otosX2giAc0HXpjwuF4zi8853vvPANNGDJ+8g1SqWwwFW0ETM0ryesl+mRtbdXEXrTKbQfjsLNbugWlh4aXSvRnAM9yFcqYbyu2XVh3nITXr6MrFbodMbm1Q/cwqVDM9HDTDDN8KGxUE0kmmFZELnW92p/l06j141SG4vB6xs2OkXasEG5T5rJELWOXPvI9bIWli4MY+obURN5oRDO0popVqz168i9+G5ELI41PIhIpbAvMuGlhcUJPfBbVphrsJhq2SyiMeXCoVP1dTVD3SR5qtrBUVLXRG6aYbOJZl1DSiXlstFqhkqluZXWwtJBM7dbFIVCqFiZqgZ9Lr6S2EtWkth62yXZx8UGKeWE6mKpolmtlKnQ+FOjvQou5WGQc5kusEws8mhgM5pGrZc1HpNmFjmoinEdHc2t8lpNkblWrLTK2S5d6IG90aKKqpvicRXk1slA0w3avb2Xdn8XC+LxOENDQ3MmmaUMbQxGC+9dSLXDuUBKydDQ0JwShZaFRQ6KtHWQM5fTkiJ1wDVpa0SVC1Ei37ABnnpqCoXCuCtFW/tjY7BixSX7OS1cIkQTxLQCIVqeVK+jm49oiV
mz2IgezKP5C8sZa9as4dSpUxPJLksRutaOrknfWBm1Edovbhhqxq5FD1MFyucL8XicNWvWzHr9ZUPkhhFWKKzVQtdHNGlK37RRn3nUheI40NWluolPpV7R0epWUtDSQ/S8R9v/aehrIQgUkQ8MhBZ5tPWbRn+/el7inoZZw7ZtNm7cuNC7cVE4ehSOHIEzZ9TsemBg6niXaarZ1vbtiku2bg0TA7dvv6y7PSOWhWsFQqtIN4kYHQ1dLo0ZfPq5mf/rttumdq/ogKeuhtjC0oIuiqWny8105Pr6SCTUOfZ9ReLNiNx1F7CQVgsXBF2+urGgXjPo60O71mIxReKLKcipsWyIXNdKqVbVDTY8HJ6Axt6c+uTpmzP6XiIxtZqhUlEDhC6u1MLSgusqq1u72vQA39gK0PfDGZ7rKn95YzkMbb339Fy+/W/h4qHdI7NJ5NGul0bNeF/fpdu/C8WyIfJUqt6Vol0f6XS99RUlbU3kjaPyzp3KRdOszGk+P/euIi0sDkSJHOrPb2PpBl1vRXecaiRynauwmLTELcwM3ZQdZpYQa5eaLtGgi+UtxpjIsiFy3RVI10zRro9MZuoGAo1dsDWuvlpZ+M1uUt1hpoWlh6kscggt8mgbOF0aOR6fPKgv1il2C9NDu1ZmY5HHYsqgi8WgszOMiSxGLBsih3D01EWwQL2OtnxrXB8ma8aFmLp5c7RNWIvQlxb0INyYIAaTLXK9PkxN2Itxit3C9NCuFZh5Zh2Pq/yBdFrN+D1v6nyChcayInKdfWfboawwSuQa+kbWRG6a6qRF8YIXKAVLI2q10D/eajCxtKAtcu0Dj5J246CtrxnLmhzovNJkh8sJmsgbOaEZtOwwHg+Ntu7uS7+PF4JlReRRBUGtpgKfzZKC9GvtOmnWgLmnR02rJtWg9pROPQjCUgAtLA1E+3U2Iuob152hprLIz527tPvZwqVDtOn6TNCqFsuCwcHw84sRy4rI0+mwvkq1qgKeyeTURK5PZjOLHGD16uays9FRRQotLfnSQq0WShCjpYyhXtkkhNIY65u4kfgX8xS7hekRLdsx03r60dWl7vl5aORzybCsiNwwwpK0oJQF2rUSRbOT2Ew3fvPNzadS1WpYe6WFpYNoVmfUAteI3rzDwyGRR2Mo+vVinWK3MDO0a2U6MtflPlKpsFfrYi7FcNFELoRYK4T4nhBirxDiWSHEf5iPHbtQ6ECnlEpZoAtpNevdGbW0mhG54zT3g+oaHLqvYwtLB9q10iwZLBoEHRoKk4eiOQOLfYrdwsyYjUXuOIrAU6mwvlK0jtNiw3xY5B7wW1LKncDzgV8VQuych+1eEHQzCSHCYGSj20RPrXXwC6b2mV1//eTP64BZS7Wy9BB1qzR2CoqSuy4nkkzWu9dGR1skvtSh417TBTt1YqBtL27ZocZFX5JSyrPA2fHXeSHEc8BqYO/FbvtC0NamSDlaN1jrgBsLyruuekx3Y27ZoqZUx4/XL29Z4wsD7/iJuj6Q5i03zan+dzMZKtTLDw0j7P0ai02+PhbzFLuFmTEbDbkQagDv7FRuts7OS79fF4N59ZELITYAzwMebvLee4UQjwkhHruU1dO0KyQ6JW7mNoGQyGdCszTsSqWV4Xm5MR9d66PnrJn8UFvohUJYQdOy1OvFnNnXwuwxGx257s+p3SnNpMiLCfNG5EKINPAl4DeklJOEeVLKT0opb5RS3thzCQtU6PoqphkW0Gpraz4Kz5bIX/jCyfWoy+XZf76F+YHuWl9yOsh5SUQqhUin8B9+dPbbiGRuRqGn2brVmzYCbFsReaWyNKbYLcwe07lWkklF4qOj6v/FnsU7L0QuhLBRJP45KeW/zMc2LwZaueL7SuudyTTvEOS6syt+1dY2ufZ4uayCIK0GE5cPwcAQJJN4WJwvt+MG5gV1rY/6xaOWuIbumK5J3DSVNd6SHS4fTOVig1D9pnsbTDWjX0yYD9WKAD4FPCel/IuL36WLhybuIF
AqAx0A1dDqBO1Dn40+dGeT8G2hEI7YLVx6GD1dUCrRZqso9rHCill3rW+8aRstrGj8RNeidxx1HVlWKDVtyQ6XPnRTiKmI3HGU8aZlh0uhgcx8WOQvAN4B3C6EeHL88Zp52O4FQ/fwdF0lQdQVEDWiMrNmWZ3NcMMNkxvw5vMtLfnlhHnLTchCEVkssj51DlmtMDxmzqprvR60NVE3KlWiN7Umck32yWRLdricYNszE7njhAZes8bbiw0XTeRSyh9LKYWU8hop5XXjj2/Ox85dKLLZULVSKIQ1yaPQN7BlNc/qbIRhwKpV9csqlVaa/uWEtX4d9t2vRaRSWMODZNOS0RvuxFg7s2pFxzKiHYJgckKQXkcnA9m2ss5KpcWd2dfC7DGTRS6EMu6W0mx7WdoXutCRYSiZYHR01Yh2/5itD+zFL4aDB8P/y+VWg4nLDWv9ugm54Xpg3z51TrZtm/5zUSLXZUyjJB4NfOmAqO+ra0krF1qyw+UBXZN8KiK3bWWFS7n41SoayypFPwpthWudeLOaKTB7ixxgzZr67C7dw7OFhcP69eo8zBR0jhJ5Yy1yvTy6rq7XY9thzsBU11ALSwszuVZSqVBiulSIfFla5BD6xatVdcKaWd3aDzoXv+fOnfDII+H/LdXKwiKRUDfmqVPKKp+qEXKUyJtJyRot8mgZ2/Pn1fJCYfEnhiw0yvd+He9bDyhJVyqFccuNWOPKogtJ4LoUmInIo+d4NslDiwFLZDfnjq4udZPr9myJRP37Qihr2jTn1gX9pS+t/394uJUYtNDYtEk9nzo19Tp65tTYzq1ZsFO7X3S3Kd9XVlpr0J4e5Xu/jvfPXw7rRxdLBN96gNpTz1xwAtelgC5l2+y+te2QN2Y7U18MWLZE3tamrPBaTZFtY8Eb3UB5rkL/RKJ+W6OjLT/5QkMIWLlSGYFTlU5otMibnfdoMTWtXNFob281EpkJ3rceAMfGj8UpmBn6nXUcyl7PoaEuJMYFJXBdCujmEtMRuRBLKyaybF0rjqNG3mpVEXljZqYQ6ua+kKnT7bfDvfeq1+WyIvOldNKXI9ra4OxZOHoUtm+f/L4ebDVJN6bkN97UWvWUz0NHhyLypaRimE/MVN+mVoNnnoE96V+m0pEGBCvcU2zwn8aQHh2lkwhWqpUvIIFrvqFn4VMRuR7kG2fxixnLlshBWc66nG2jRa67wFwIrrsuJHJYeCKv/vRhvPseQA6PIDo7sF51B7Fbb1m4HVogbNkChw6pyoWNVSCkDGum6BZejW6W6Lpa9VSpqJs7mq59JUHXtxHp1IR7ZPgr3+XQ1rs5MtxJpRKu69tpOqv97Kg+zjr/MKJcAs8Fyw6P7ywTuC41pjLgUqnFn47fDMuayHVGZ7U6OfqsFSdz8Y9rCAHJ0mlKMdV99/Rf/D9WDd8HQOYLn73Y3Z4Tqj99GPezn4dEHNrblB/ys58HuOLI3LJUoGpoSD033pCavDWRN/ZxjVppth2+39sbzuh0MaUrBd5Dj3LeXs1zhS2cPd+BF1jIwEc85WOsUMfmuuuUoqvytX3KR+7Yod7XD2DjGqSUUCohC0Ws21+6kD8JmJrIu7vVoL3UMniXNZH39irLqlKZrEzRN+yFBDTyb347b3BW8bnNfwLAycRWbuK+ifcuJ5l79z0AiTi51CqG6GJ16iQOo3j3PXDFETmodOrhYaUtb3SxaCWKdq1ELfJGUrdt5XOXsn42VyyGqdvLEdWqChofOgRnzkDuxEvBtBCAbfisTg5ydfYEfaWjpH7lPXWfTdx9F2WoV63c/uI61Yp1+0sXXLUCUw/GiYS6NpaaOmlZE3l7u7KkqtXw5m3sEnShkene2pmJ1ycTWy5uRy8CcngE2tsYoZ3TrOU0a5FJCZUKa59ROusrrezqxo3KV57L1ZOuLpQUbemm0ehaMc1QTx5FobB8iFxKVWLi5Ek4fVq5pMbGwnsjFoMVqSI744fY0j2KZag3ZLGIWNGc6RJ33w
V333W5fsIFo5nkuFlZ46WCZU3k8bgaYXM5dVM2y+S7mMpmu09/gT2r3oTtLZycQXR2IItFNqSOs4Hj1LA4XenhdGIz/f31pVfb22HdOjVtvBCX0lJBLKbO/Zkz9ZUvPa++Z2cjoTerv9IY8NI1yZciXFeR9YkTyt8/MqKIWwf943E1o9myReVLxOPgHQf33v2IcgqZTC4q98jFoJlFHospd2yjMGIpYFkTOaiTUqsp94ouTwoXb5ED3JH7Onfkvs5CysitV92B+9nPq31IJLDLY2wo97P1DbuI3aqI6+xZ1eFodLQ+YJdIwNq1qoaMvrAvtgPPYsH69bB/v7LMtc48SuQwvVWuESXydHrpaMmlVN6Ns2cVeefzyqAZHQ3vhVRKxY5WrYIdO5q7E6z16+Du19ZdE4vFPXIxaGZxZ7NKobQUqh02YtkTeTYbFs+KjsL6hp6PtOupjNviP/wfgu/+UI0kjoNx+4tJvesdF/+FEWg/eJ1q5Y2vn1huGLB6tXqAusEHB0Or7MAB9QAgP0rf4adY217BiSRwcPdrl9yNK4T6zadPM6Gs0DVUGqsfNkrRpPJMTbrZFzORe56ysE+fVqRdKoVllnW+RCKhiKqtTRH36tWzm5lF69ssFzQj8nR6biU7FhOWPZH39qqTlstNJnKtXpgrMl/4LPk3v73pe4k//2O84ycof+Lv4OCRsKC15xJ86wGKcEnIfLaBTSGUNC8qz8vnlcV+Zu9ZTol1nCo6UBxfV5xn3Y+fpmsJ3sg6NnDsmHp23ZDIp4MQoTol2mBbBz2bdRi63CgWVVD3/Hl1/mo1tWxsrJ6429rUfm/erGYprXoxCs1cK0v52Cx7Iu/uDhODoidvqpobs0XmC59t6oYAqPzVJ+HoMQBGrW4Op68hJYtk/RGcHx2i4/XqoonFFoeULZOBXbtg8w9+hFjdydlyO4fzqxiqpRiSGzl2Ik/6AUWAHR2KEDo7F57MZoOtW1WFxGi54ag17npVajWJ7weYpgHYSGkihCLCaCs/HSArly+vH9X3lWU9MKB+Rz6vlhWLyvKu1cJmwamUcg2sWqVcSsslMDvfaBYbW6puFbgCiDyRUBd4uTxZmXCxTQKaTTkLf/QnyKPHcEWMqpnAFRb72m5V7hddOveL4frxeNgfMJNRN55OIdcF7vUjFgsL/lwKGD1dyGKRVSlBTzzPaC3NSN4kaIuTTygiGRpSKodEIrT61q9XKfKLMdJvmirIu+fZEQ4fOUGl1MvAUI7Nm3uQUjA84mJaaUzTwEeCD+ATBCY1r8D3fniIL977fVav6uHVd96GwcaJGveXCqWScpMMDyv3SKmkiFuTtg7cG4a6LtrblcGyYYN6XgzGwWJHM+u7u1uR+VLEsidyIRRBNjbOjVrkI/kypwZzlCouybjNmu4sHZkLy8+V+w8BMBDr4xtrfg1f2AgCbFnDDGoIITFkOH0vFtVDd6DR+6b33TDChtLJpPot7e3N/XiNpK9fz/bGNm+5SfnEASuZpNs/S1dQxL7ztYg1yiocGlJuqbGxkNj7+8MqhJalkkPWrl08U9Wz/Sd46OGjyKAD23YZHB1gaLifdCqJaa1DHZ4AExMfied7HDtxlnzhLMmUS19vL2O5HB/52D+yuvcGPK9Gb1+ZV995Gzu3b7yoffN9dSyHh9V1kM+r2aPvq2fPC4lbu3va2xXh9PUpy3sp+nQXGo0WuTaolsIssxmWPZGDcgOcORNVLPgEATx7fIBhcliWRUc6TiJmUXN99p8cZNva7gsjc1+V2VtVOc6ukR9yNrWNopXGx8IXNkEyg1sM281F285pi1b77qMXVaWiiLRRQhkEYbBOT63TaXWzt7dP39Wm0eqPda3DePVdiMcfQQ5OVih0dYUZslKqaf7QkLISSyXlrx0aUo/9+8M0974+JXtsLJNwufClr/yQU2dLrO67DUNYFAtFBss5DNPiqk1r8fCp1aoYhoNlxQl8H69WJpA+Q8M5LLuGxO
PosTPUKt1cc/VOxnIDfOLTX+Z997xhTmReqSjSHhtTs8RCQRG176tzqbXrQaBIXEp1Ttva1Plcs0a9XqqEcznRzEADODWY4/gJHwjraiTTklRq6dLh0t3zGaBP4ki+wkDFplTtIvAMogUfnZhPuebhl13itokfmAznK5QqLmeH86zuzrJ1tWKuWVvsqRSMKYfsrUNfg6Gvhe91tJP58F8ipbK2BgeVVTs8HFpk2hrT7cagvlF0VF2hSVy/1jLLsTE1cOnBQA8Wev2o1Z7NKkvEsnTH+LVYO9eGgeAysC/8zkarf+XK0OovFhWJaxfA4KD6XUNDSgYYi6nv7+hQboCOjktPSHv3HeXBh/biOJ2kkwO0ZzNIKbFsi2rVY3QsTzJuEwQ+QgQgA/wgoFgtYZlQkUUOHTlMIGtk0klKpWGEIWjLqlHpW/c/OCWRB4Ea7IaH1XkpldT/zZpYaOKu1dT50ANxb28Y52lhajSSdjYZ4+xQHj+QlGsuQ7kSR84MgwTDMMiVIErkNXI8euQ8z54zEQgqrgcSYo5JKmbj2NZFz9YvJZbl5TGSL7P/5CCVmsdwvkxgmXheB74nAIl2VleDCqar2PLUYB7LEHhBeJed6M8xNFYiGbdJJ2Kzstitu1+D97l/gsh2AEilcH5ZqVWEUFO5NWvUYypowi8UFBkODCjfaamkHtVqmK0YJf1ogE6/F31fZ7maptqmJvyoj1sTtU6qitYemfitE+Q/+dlxFAElk2HCSaEQumROnQrToZNJ5Wfv7Z1f/+7efUf5xKe/jBBqWnL85JPEN/dgGhmKpTGCIMAPVBeJQAaYUgCCIPCoVlzGqiOMjB3n+JkDOJZFMhknkEV++OCPGRw+rG7uRHzCxVKtqvMzOqpIuVgM5Yp6ZqTromviLpXU79dWtw64LYWGv4sBI/kyB08P0T9SJGabtKfi1FyfPUf7CfwA15eYpkBIKLtaguTzos/8FYeu+suJ7Vy751uYjkP/lh112y/XPEYLVSwDuttSFMs1VnZlyJWq8+KKnS8sOyIfyZd5eN8pCmUXzx9nHitQ3B1IhFtDGjFA8rIv/DVttXPc//7fBagjcQA/CMiVa0gpScRsRkfLuH6AIQQHTw9x8/bJLDxRb+LLX4NiSTFkdxfO298259onmvDjcUWKU/WljBK+tgC1hV8uq/dqtZDQoy4dTfrRAUFLM7VVbhhhw9qo7tqywvrsmoy0e8Dz1CMq39O/p1JR+6MDeZ6nlCW69HA8riz9NWvUNqfy9e/dd5Rv3f8gp88MTAQjd27fyNe/9WP+/rNf59nnjhAEAW2Z9YCHZcUolUdACIrlMWJ2GikDhBC4XhnDsLFFgiBwMQyDUrlAvlilUCghEAwOj9HZkaFU2osQBkGQpKdrE5/+zFO8/KVJ2jK9lMtq3/Sx0MdSz7R0E+dEQp1TrTLp6FicweLFDG2w5UpVHFsdvKF8mUzCoVL1kEgs08TzAzw/vLfv+PiHkaLeSd5R62fz/U/wM5hE5gBeAOdGigAc7x8j5ph0ZRJYpjFh2MEcZu7zjGVF5HuPnee5E4O4fr3puP3HDzBcTDAUW42Jh0cMEMT9EiZw58c/PEHmUQihDOtCxaXYPwoIDCEwTShVXEby5aYn6nLXm2gkfJ3J2Ajdh1JP8XWmnw641mrhI0rq+nPaxRN11WiS16QftdI16ScSirBSKUVY0X6qY2Pq4XmhW+j8eZXY8rOfhdtNpZQbSEs2h0aGeOiRI8Rja+jpXMfYaI6/+4efsGnDAb7wlfuo1UoUiiVAUCwdQAjBiu5tVN0Clhkn5mQgkEgZ4PseYCCRCMDzPYQwqHklTNPBj1xPtVqSrvYNZDOd9HbvJJPJkIyneOzxEa7d1TtRoG14WB0rLQ207Xp1UkdH+Ft0E289S9K/earnFhRODeawLZNASizDQAgBBAzny5imwPVASllH4jBOejJawzpgReUkBnDt/V/l/iZEHoUEPC/g/GiJ3g5BIm
Zz6PQQfiCxLXN+Ym1zxLIh8uP9ozx3fBBP1pP4th89wKY9j/PkupcyGFuLQWgixoIyAphqNq8N9ECizh4SHzluZRlTWuWLFYahSDWRmL6prFZMlMtKRVEoqId2FVSrioS1FE6Tve8rMo4OAFEFTiNJaVeMZSlSy2RC90o8rj4/MBBmJ46MhIPI0eM54rFuYo6J54FjdZBOrOaRR/vZvO7nKFeqxJzn8L0yp889Q7WW5/zgftasvJHRsWdIJjtJxjsplYbJpFZgCAPDsNT+Sp9SZRTfqzI4pFRIqWQ37dnVpJI9WGacfHEM0zyClD309q6iWPAmfk8QqBmEdktZVtixKhrvmKqb0aVEdNBt9jzVe43B98WAUsUlEbOwTQM/kJhCYBiCmutjGIKtP7yfdfv2YHoevmVxYvtu9r/oDgBEpLBGsnaerKtkY7P17HnjpHBmKK9mAK7PivYUKVttwRl/PjWYaxH5XHDo9BAYIBu62q/btweAlaXDnEhdDYBdPU/cK2BJt3Ezs4IvIW4YDOfKF7XPixVRueNMhF+pqEexqCx87bvXySraraOldL4fDhRaqRENxGo09tNsLHDlVrsxzXa8mguGi2GUsa0cg0PHSCWTlCsBhUI/ufw5qjVV6co0bEBSKJ3n/OA+dl71GiThxk3DAgKK5WFqtTxDo0fJF/sxDIudV91Fb/d2Eol2Aj/ANE0MYWIYBpWyTdwxOXIk7DCjiTyfVyTY3x+SoZ6tRAkyGl/QyxuJVLu3GhFtURf9v1EZpY9xNB7SrHDYVEXELga6T+ZMs41mg8lU35+M29Rcn0wyRvnQUTbfdy+Z/rN16+iomO3W2LTncdr6w6qlNx3935yIbeS1xQfqiH0uCKR6SCnJlarYlkEipgIitmVQqlwYx8wVS5rIo5Hq0WIV2RhgBEzPQwA7Rn/CY113EvPLvPvIBy/qewVQdj2qnsd3njiMbVl0ZOKLIugxF2jr0PfrSXaq11ompwlYv25cBorQ2trUa70d1w2JXW9LL9efj1r3UUKK+tqDwEdi43omggRCmGCCT8B1V1+F73vUvDKrVlzHqXNPcfL0oxRK53G9KoPDhxgdO0W1lsfzXXy/igx8gsDD88oEQcBY7iT5wnlct8rqvmsxDJuezi3EnTYsI4awDCzDQcoAKSReDWIpZ+K3TXe8pyPi2aJxG7Ml2sYiYc0+30jeUcKP/j/ddnUMZaqCZJqc5zpQNH6PL7uonh8gPjiMqMY5ab8JsUZZygYSISWG9NmUf5yr8o8jgK7zIdG/sPIYVB6b3ZdPg5rnkYzZSCT5UnWCyF0vIBmfRv87j5gXIhdCvAr4n6iZyd9JKT88H9udDsf7R9l7fAAp5fg0RuIFcmIE1vAtC8Ot0eEN8Ov734/XUOJKopL5tjzyI5KFHKV0lv5N28h3T52rKxlPKDIMap5P1fWxDMH+8vz7xDSJzYZoG6v7aSusGfFG/49aco0B0ahPPNpFJ6p+abQudWKQtq6auVSaEYImdU3sOiiqFR7VKgyP5Dl2/DxIg3LFRwgby0hg2XGEMLFMG8NwcOwUqUQXPd1Xcd3ONyKRBIFLoTTE6NgJntr7LwwNH6FcGSMZ78KJpaFaoFA6ykjuNIHvsn3zHdh2kkQ8Sya1AmEYGIa61iQSKaUiFgM1nTc9bMuqG4Ci56LxvE53zqd6f6bPzbTO8oIAfwU4K6Ah+czwc1jj6dRnEhtIeHnWlg9M8EOoXQshgeoFBCGqboBjSfwgGD93EtcLcD2fTSsvT6roRRO5EMIE/gq4AzgFPCqEuFdKufditz0VRvJl9h4fANSNVCjXQjJqWPfE9t1s2vP4xEmzImtoEu/ftov2c6dJjwzR69ZY/8wTHL7uFo5d//xp90MIECi/XKHsk03EOXQyz1WrE5OINkq4jTK+KDR5NhJtlISjFuxUZAv13XAayVb/r9eNTn0bFSvRazvqCmlmWUXdIY3kH933qLKl8RE9XtGBKAigv7
+GZbdjYiFEmXLFww98ZK0ChgApMAwL07AIpDLjDWFhGjamYZNOChwrwfOv/xWOnfwpY7mTDCa7cKwkxeIQJ889SibVx5aNL6G7YysxJ41hGKjbXp3regoYf22q9pTe5ZlJtwBEwl2TEJgprh66D0v6mHjEfKUDlYBnOwi3NskfXjUMvvtvL2y2XnU9hBA4FpSrHsm4zaaVHUtKtXIzcEhKeQRACPF54HXAJSPyU4M5gkBiGoJCpTYxCjbD/hfdwcr9z5CoVetuPwkMrVjJ8LpNdJ84QmZgkP7EZnzHwhMWzrNjbH722zz56jdNuR+2ZZADbNNABga1dIJKNYBiSFSNUryprKVGwm30ZUK9pRz9TGP2ZiPBRxOMoq6L6AARdbNE14tuC+q/vzEzVS9r9jq6vhCTB7Nmx6XZlNv3s+MyRIEhHFIJCJD4E6MZIMB1y3i+i+uVMU0L04yNf2+Aado4dpIVPdvIpFcwNHKElb3XMDhygO7OrWxe/2I6OzZgm3GEMBr2pZHIFUyuvH6eC4/IRRPUyPhjxIIydlBhVfEALx744qRP+ELgxuJ8/z0fmLe9EGhjRdKVTXLTAggg5oPIVwMnI/+fAi5ps8hSxcUQMFaszipE8f13/yY3f/Ezdf6xoRUreeRNv8QN3/wiHefP4hpxqmaSc/ENVM0UgTDxMcn85Ahntl0LQpOERCCQqCi5lJIgECANBjAwhcPYudDNoZ+jZBoNPjVL1Z/LNHw5oZnvt/G3m2aAjyJONXgbyCDAEONWsxEAAmnGsA0Hw3AIghpB4AIC33cpVAfJF89x/NRDABjC5MDhf2Ukd4Ju36VazeN5VUxhAgJhGBOcIYQxTuPaeaxuZDXQLTJZxyJHdMYXla7qvAHdrUcXvovF1EMrr2r/9f/D9soY0iPljmLPIF7wxktKH9t1/cQyrWqLIgCevPN1TfXkzaAvUSWFnMsRmD9ctmCnEOK9wHsB1q27uNrWQoDry4kD2OgXb4ZH3vRLTZdbVWWpO0GFI85mnuluov8+P7f9G5mgmhbmgtkNYBb44JtgGOryNYzGgJLEMpUnVMo4fuDjuiUCGWAIk0p1lIGh/YyMHidf7EcIgzUrr2fDmudz9vxz5Avn6OrciO/XQCgSNzCVdT7B6cGEJFXztyHE+Eyq/rZqrLrZOEA1BhSbzcYaVSYajYlc+jumMgoat9uoDIp+R/Sz0X2y7TB5K0quOiksmQzftyy1frPHxWri8+1lOHtOHYdp1pOAb5jUkimO7bqeY9c/nzs//uGJO7Rx+DWB6+//Kk/QPDmoEZYBQghitnnZgpuT9mEetnEaWBv5f834sjpIKT8JfBLgxhtvvKhxSwCmITDEuPTnIrblxcLScWVzCTbru0SYSknQTIkwnUJhKuLQr5upHaKvJ/vfTVy3RqVSoSY9PNfHH9fVST9AEoTbQCIl1Nwi54cO4PtVgkBSLg9x+uyT5IvnSSY68bwqZ/v3cOsN72btqht5cu+XGMmdJJXqwTJtZBCQSHSQSnQhDBMZePiBhww8hGEQt9PYjoEkoKM9xTW7tk5IDDWiszCtwdexgsa4SKOMMIqpXFk6vhE9H1HZZpRQtXWrLdt0OrR6o64hHbie6tH4Gy83nLe+idrf/QOUK+C64zwgJhK7AKqxOM++5JV1hKxJXK8zZnby/9b9NtcNPsDN+e8pxQuwYc8TsyLyAJUgFE9bE4W5Ljfmg8gfBbYKITaiCPxtwC/Mw3anRCChpy1JpVYj8GdnkQNkBs/zgn/+9JQT4LvPfwbOf6b+u4DAMCEA37Q4s3Ebz9x+FzIQSrkhBKZhYmDg+xKkoDOTZuf63knyvKj/uTGQ2UzzC8192o0+7KhVFvV3T+WHjj5Hv2eqdaKfb/y/sQfmbF5PJ1Vr5j/X36XXsW2HRELJFKq1KmNjOUqlat26AH7g4bolXK9CzMmQy+dxvTIYBqXKCB1t67jlefdw4sxjnD77BKYVxzBMdm+7m9Pnnm
Jo+BAnzyh5mmnYZNIr6evZgW0niTkpEvEOOtr7iMc6MAwT27JxXcHJk3NzhzVzKUV15NrVEI8rAtaWr7aGdaC6GaaziC9lbfvLgWibw2BwiFIizdittzGwdQf9IwW8Kcx0TeIVEaMsA86ltlKOr+Kna95B2+kSO3LK5RYvza6vn5SQTjns3tS7YPLjiyZyKaUnhPg14NuoY/RpKeWzF71n0yAZt8kVq8QcG1nzCAI5442jSbzZ9d5MihRdbgUq0mf6AWsOqxjuMy/7OfWeYWAYBjIQCCFwfUk8Zk2QajTQ2Ui4zUi4MVDZrD5KNGiok0mickAIyTA6TTaMydZa1AKu++0Nfvvo1L1Rrhgl32bW91SkPhNmI6fLihg93T0MDo1w9twghaLSggOMjp2i5pWQgU+hPEjgu1Rrec72P4Pv10glO0klOknE21nRvQ3bSjA0fJTO9nWs6b0OwzRx7DQnTj/Miq6rWL/m+XR2bNS/hJiTxrZMDMMikAFSBgSBP1FN0nEUWUat30QibBAym4qGQoTbmcoqvpIRbXMY5MsMDuaQuTLpZIxK1aMyUShrMv566/8CM45V0bEzg/OJ9RNEXknOru6yJQx2rOtZ0BySefGRSym/CXxzPrY1G6zpzvLQwClMQ5COO1RcD99X7bqm8pX1Htk/QeIfW/0HkFzDe478Nml/bFqLXgCPdrycZzrvmFgmJRSe6EJMYiVF/UNxm6dnSJxolmAxXVadfq+Z1C+6nej7UfVLNJtwum01ez0f683HNqbzoUvZAXTwhX95gKHhMYqlAucGTuONy21On3tS+byBeKwN1ytj2yni8Xa6O7fQ230VpukQT7TRll1FqTxG4NfYsfVVeH6Z46ceZnD4MK5foa/naiwrNr4th3hcvU7EHbZftYGXvWx13X5OZxFPla3ZwtzRkUmQK1U5NZDD9QICKUnHbQpTZVeayp/txVdOLCqZWTxhY0qXY7uvb/65BnRm4+QiM8KFwJLM7OzIJEjFHWquh+cHZBIxsskYg2NFfCkplCefuGRB1Qh3hQPZzQA80vVqbj//+Yl1vtVQOOvVH/+wShKwssT9PAYBQkqEDDCynXiBhxCB8skJQEraszE6s+YEeTZquBvTrqdaPl3qcrNH9Ptg5gFhrus0vl6M2LvvKIeOf4fTZwcolyuUK1WKxcqk9VZ0b+PE6UexrTiWaZOKdwIS1y8Td9KYpo1p2gwMH6Qts5KNa19IqThEItnJidOPMDhyCNOwsawA16vi2Car+npY0buCt//Ca9i+/fL/9hbq80t0slZjAb06yACEDiz4IEzyVhdn4+s5/6KrZ61aqdQ8Tg/mFjSze0kSOUBHRtUd1sVpQBWyMgEh3EnWWymtghAFq31i2dn4prp1Xv3xMCFVZ1oL4ObBb3Ld8HeQwgBh4GNw+pf+C6mYw/FzY1RqAY5lsqannb6O5CSXSaOLZC6INh1YLGi07ptp4Gd6b7r1L2TA0LXH+3o7Gc3lsSyT0bHQx2kYNmtWPo8Tpx8BwLbimJYDQmCaJgiDmlvANGx83yeXP0Nfzw7O9T9L74rtbFj/QkbHTtLZvpHh0aNkM0nashkCGePqHZvYtWPTvLR+a+HCcWowh5QSxzKp1Hzl8pQS21QuT42KZWF7cpzI9VIXMCnaab5/x3vpWjM8q++M2yaGKRCIy1rtsBFLlsjXdGfZf1JVLLMtA9cLiNvWRMS6kS/7N21jy2M/YSQWdgUp2O0QWTfKHw5MiAgdWcXxqxPrVleswLArXLOtm2u2ZSM1X8YYrpUWdGSOBkunG0xmWjbVezpIu9hw79eP05beSSppkUknOXj4FCdPqUatvT3bacusouaqsrYAu7e/HstU5Yz9IMA0DWpeBcdRHR36B/cTj2fIl/oxhx3SyS4s0yCZyODLAV72ohuwLHPOrd5auHQoVVwc25xIFgykyvkQQpBJ2BTKNSTwvff+R277u4+P61vGEQRgwmiql45afpISqx
lsU8Xr/EDSlY1jGsZlq3bYiCVL5B2ZBNvWdtcVct+0UpH09586StWtn1Llu1fwk7fcg/OdsPpZzUjVkfj3219PwWnjrvOfmShvW0pmSJbyE59xs1n63/ILE3pRXdx+oeoQN0K7RZZy3eoLGWgGh07Q092BYVTo6e6gp7uDp545iBDQnuklkWhDiNDlFo/FMIw0amhWG4nZSeKJdkBQreVUk+We1ThOio5sL8OjeWKxHmznZWzakG5Z4IsMybiN5wfkSlUs06BS8wDVOKQ9HccwBKaAkUKVH7zj15EPRwq0mOP3qlTCBbdi4yTqXbRCQCpm0ZaKM1aqYRoC01DbTsRspJSXrdphI5YskYMi82Zkef3WVTyy7zTBeMq2Jut89wqO77gWRtX/npniW+//XV798Q/jI/jZytch8OskiEff9nay+/YSGxsltnIFY9uvptjTy7ZII1fbMidcPJe7DvFyhPb3zwUrV5qM5c5M9NIEsG2LwA/oaF9JzOmkUDyvVCZCkM2midltGONRVGEIEokEyUQWIXw2rOul5g5gO0XWrFpHzM7S3bOJWq1CzNnEa+/czVVXzfMPb+GisKY7y/7yINlkjHLNxfclXuCrGFoqhh9I2tMxTMPg9DmfiWxtYGI+Lk2SdgK/lMbOjEIAwjBwLIPdm3pZ39sOwJ6j/ZNcu5ez2mEjljSRT4X1ve0UyzX2nxrED1R9DZTEG7cYr1s3cNWJGLH7ABANjSky27cwtnYtlZpLOhEjGbfZFnGd6OL2UVzOOsQtKLz6ztv4xKe/DEAmnSRfKNHV2UapWMayYjhOklQyTl9fDzu3bcSybXzfRBgGPT0d49ULbUxsfGBlXxc7rtrNvgOHMU0DcKnVDKq1gFuf34PrwoEDsHGjkhO2sPCIztIt06C3I13n5txztJ9csUrNDzCFNYVUTdCeStPeEaNvjTVl27Zmrt3LWe2wEcuSyAF2bljByq4MpwZzjOQrFCs1YpbJ/iA6YgriIoOLQX9KBT6NcSJXab0GNdfHNATXb13V1MLWxe0Xy8h8pWLn9o2875431PXw/N3ffCdf+Mp3sEQcx47T2dFOT+963vKmV2DQx5kzcOwYdHfHKZWqFAo5fF9imvDKVzyf7s4+Nm5I8tOHBikUBkmlkmy/aiuB1zGhAT96FFatUm3oWlh4TDVLh1C2DJK4HUOYSqxSD6U4M7DZ0tdLYopJdXPX7uWrdtiIZUvkUH9SR/JlDp/KI/2GObub5NHf/n2qn9kHgMDDRyANwcH/9sckbXPaE7TYRuYrGTu3b5zks960cTVf+eowhUKCdeu28JpXvZSY3cfp06ESSAhIJGJksz0MDamWc6v6suOJPCt50QtWTnT7GRtT8YdqFXbvVp1/zpxR3ZD6+hbgR7cwa0Rly25NZWXX87jiBp3c1t8PGzZMv73F4j5d1kQeRUcmwdpOfdDD0+dWHdZkE/x010vgvIGMp/E+8Ul6OhOzKuG42EbmFuqxc/tGcqMbOXkStm2DrVvg+HH1nu6ZqRUKuhlGPK6IubcXBgehp0eRfjar1i0W4dAh1UB5/XqVMn/unGp11/KbL25o2bJfSGCaAr+JrLdYhBUrVI/Y6Yh8MeGKIXJQJ6gRnqs6/UhfHQrHtsgk53ZYFtPI3MJk6BIG0TIDEFrkUtYnbyWTqsl0JgNnz6oMzK4uRdS6fV2hAKdPw7p10N6uyP/YMdi3D7ZubaXOXwp4x09Q/rP/Af2RcqTXXUPm92bfDELPoD0/AKN5cY5iUZ2/XG7p1JhfwiK1uaMykeQXnplKNcCxTRzTQQjlH1uMOukWLhzOuMosWmQMwvOsn7Vss6NDrROLQbmsCH/leBZ3MhmWCzh5EgYG1PJ4XBE4wMGDobXfwvzAO36C8oc+Uk/iAE8+Tf5Dfzrr7egZtGkaDcHOkBNKJVWd0vNgaOji9vty4Yoi8vrGuCZg0pnOsn1NL4FvTdzIbktwsqygOyhFm0ODeh0t+OV5yv
pqbw/XKZfVQwjo7lbv9/WpqXe1Co9HehKYpnLfCKGCoGNjl/ynXTHwH34URsaQwJfW/Ht+0hnpG/DUnjltqyOTYF1Px3iTZJPG3gG1WjioD88uwXPBcUUReWVy2Q0qFTVFdt0wk2sxpcO3cPHQRN7YmanxuVpVN282q6x4bZXpgb27Wz1XKqqM7Jo16kYfHAy/SwhF5tmscsucPUsL84BgQJnGAig43TzV/UpG7fETcgEttHTVzqlQLqsBvZk7djHiiiFy11U+r0ZUKsq36XnhiW1Z5MsLja4V/TrahzQIwhmbYSi3ythY2OVGY+VKReK2rVww5TJ897uTv3PVKrXu2Bjs339BXNNCBEZP18Tr7aMPEwiDp9tfqhZcQHGe6VLwpVS8YBjq/DXjjcWGK4bIi0UlH2tEuazkY3qEFqJlkS83RF0rjfXd9XLfr3e9OY5SLSSTyv+toYOduZyqLb57t3p94sTk721rU6oHKRWZt2IvFw7zlpugQx38q8d+jB1UOZx5HiUzA9funvP2prLIdfyjUglfnzt3sXt/6XHFEHmhoIIYjfB99V60EUPLIl9e0BZ5Y0emxk5E2mcOSqVSrYYt0KJYu1ZZ456n5IeOAz/4QXOibgyCNnPvtTAzrPXrSPze70DvCtJ+jo7qOcpWhiM7XjMn1QqE18F0RF4qKR5Ip5eGYXdFEXmzKZKUoSWmZUYtIl9e0K6RqCUeLSmsA55BEF4DnZ2KdPVnowScUgUSGRmB0VG45RZF+k880fz7dRDUMJQbb3R0nn/gFQJr/Toyf/kXZL7wWXa+6RqM9k6Obv25BhHDzKjVmrcThPCaqNXUOc1kFHfM9TsuN64YIgflRmmGRh3xUhiBW5g9NBlHe6dG2+VpSBm6UZLJUHoIk91yGzao4Ge1qhKH0ml49tnm7jtQ33PVVcrdcu6ccue1cOHYvl2do6EhpRCaC3TD62aIEnk+rzghn4fz55uvv1hwRRF5M0tb39SmGd7wLYt8eUFb2Z43ubl1tMWcYShy0MtdN9SDNxJ0PK7WGR1VxPyiF6ntPvjg9IHNlSvVI5dTyUOtIOiFob1dZdy6rlKdzSX+MJ1FrqFjJqmUcp0VZteHecFwRRC5lGED5EZEk0FsW93MraDU8oIm8miwUw/gGrprUbqh367nKdJuNrXevFmRSbmspuDr1inLbd++6fenrU1VTYRWEPRCIQRsGm/wdfSoSs6aLVx3eiLXSYHlsrpG2trUdTAT+S8krggiL5fVjdjshtHLpVQjb7Mmvy0sbUQt8kYSb7TI29pCqzyVUsHOTKb5dm07VLecOaPI2TQVkTcLrEcRi9UHQady+7UwNdatU0HnWk3NjGZ732qLfCr5oTbmgkDJDx1HzcgWc3LQFUHkhcLUKdPaZ6oDXS0iX37QsY9okFP7vrUlrt1rmYwi8lpNPdt2SOTNrouNG5WPvFBQPvOtWxUpP/jgzPtlmsrXa5qqkFcrCDo3dHeHRH7ggKpWOBtM5yOH+vhJqaRcrqVSfeLXYsMVUTSrWJw6QyuaIGLbs/OftbC0oC1ubYn7fn0JW8NQ/xuGsr4SCUXMWnao161UmFSf2jSVO0YHMDMZRS7Dw4pcZlMNcetW9flz59R12uudwH/4UYKBIYyeLsxbbsJav27+DsgygWEo+eeZM+q4jYzMrpRwtFjaVAgCdb6LRTXYLvZyC1eERV6tTl/ESMvPLKtlkS9XGEZ9MlA0NqKVStpyTyTqg5vaUp5KkbJunbLKx8YUkaxapYJxR47M3mWiPzd2ZIC9X3iGoFBEdHcii0Xce7+Bd7xJxlELrFih3GG1miqHMFv3x3QVDaOZvlp2qmdpizVl/4ogcpjeytbErafZLYt8+UFb5I1JQaD+j063DUMRsG2rz+n8g6lStYVQCUSFgiKTWEx9NpebnYtFI5uFNaceRCTiHA624EsTkUoh0ilVNKqFSdBE7nlKvTJbmeB0zcmjEsRaTZ1HfR0s1izPK4
bIZ1IGaD9pyyJfnoi6VqIWuX4t5eSbW7tJQFnp0+UX6IqIw8MqmWjFCpUEVKkoF8tsYQ/3s7lrBICSP94MNJmcKBrVQj1sW82I0ulwAJ6NVFDPxKZ7XxdNy+WURW6ai1eafFFELoT4MyHEPiHE00KILwsh2udpv+YN+uabKTNL3+ia0FtYXohmb2ofeTQhKJrVqZHJhCVtG2WJzba/apUi7sHB8KaPxZQ0brYuFqOnC6NcZGv2LFl7/EOlUl3RqBbqkckw0YrvyBE4dWrmz8zULELHTmxbnbvu7rAv62Ik84u1yB8AdkkprwEOAP/p4ndpfqF9WtPJwbR15rot18pyhZaUaf9nY11yLT+F8EaNx8Nl2lqfbpDv6lK68oEBJV3UpW4LhZkThTTMW25CForIYhEppXouFFXRqBaaoq9PuVcsKxQrzNTYQ8fDpoOuiFoqqWvBNBWp62YiiwkXReRSyvullHrC+RCw5uJ3aX5RKKgbc7piRUEQNpRoWeTLE1FtsJSTrSrDUJZdKhUGNaM3ujYIZrKs169XBJDPh31A169X37t378z7aa1fh333axGpFHJwGJFKYd/92pZqZRokEmrAbGtT5/jcuZnT9rXUeDpfuU4irFTCa6JYXJxlbedTfngP8E9TvSmEeC/wXoB16y7fRZnPq1F6Oos8CJTlpV0rLYt8+aExu7MxZmIYyret+3U2flYvy+fDhKFmaGtTlvnZs6pKos4MNU1FMGvXhlP0qWCtX9ci7jnCccKqhaOjys3VrN+mPu+ayKeyyoNAnTutIx8dVeetsWH3YsGMFrkQ4l+FEM80ebwuss7vAR7wuam2I6X8pJTyRinljT09PfOz97OEaU4/1YoSuQ52tqzy5QXdgFlb5VEtsT7X2awi6VKpngB0wBNmZ43pjE1t9Xd2qvoqpglPPdUyFC4FVq5UxzmVCgn4+PHJ60VjZdNZ4xBWRg0Ctb3ubhXzkHLxJW/NaJFLKV8x3ftCiHcBdwEvl3Jx0l80AaQZtH806lppFvxqYelCEznUdwrS/5umspx19cOo1dzWpgKYicTsgpaplCL/c+cUwQwPKwMhm1VJK88+qxpStDB/yGSUFV4uq2M8PKwG5UbLudGlNp1VrRtMmKZ61vGSYlFlkUYH+IXGxapWXgV8ELhbSjlDdYnLD33jaitsuvViscmfaWH5wLJC67txxqVnYlo3DvX1VXSHoblMpa++Wj1r4li7VhFLe7vSOi/mdO+lCi0/9H2VnOX7k3umRi3ymSSI2qBzHGWRL2ZOuFjVyl8CGeABIcSTQoi/mYd9mjfoAOdMEWyo7+sIrZrkyw3RvptRH3k0b0Cn6sPkVHyYXWq3RiwWWvJSKgvOMJTFn07DoUOLU8a2lNHTo/T7nZ1h569GV9hcLHKd6RsNeGr3DSyuQmcXq1rZIqVcK6W8bvzxb+drx+YDOkA1mwMebQcWfW5heUDPuKDeItdWl/aXNlOsgLo+tDU32xv4uuvUc6mkPrtpk5qmd3Wp6flcEoVamBmdnapUgo536XT9oUgu1Wwtcj3AawGE6yq/eGdnmJMw2yJdlwPLOrNTE/lMJUUhzOxrEfnyhC5R3Nh4WUOrGKaqp6ITg2DqdRphWcqPqkusDg6GPSBXrlQ+9Nkkr7QwO2hiLhZDCWitVq/7jlrk08kP9TVSq4X1y6Ot//T/iwXLmsirVXXgZ+Mm0USuXSstIl9eaIyBNNYjj8fVlLlRsaLR1hZ+bi464muvVc+6v+fq1er/ZFJZd/39szM0Wpgd2ttVgpAWxmn9vx58PS8kYyFmLp7leeE2tIs2asUvFp5Y1kQO6qTNxhdpGOqktnzkyxM6kBmtgAih5ZVKhfrwZo0k9A0frdMyG2hXSqkUDgI9Pcqa6+hQ/x8/3pK7zhd6epRMUN/P58+r83X6dLhONHg9kwTR89S5su0w4NnTEw4GiyVofUUQ+Vws8mh1vBaWD6LB7KiCSROobiihX0+FC7kurr
lGkcaZM0pF0TVeNqVaDV0sBw/OfbstTIYecPUsp1oNz7F2hUTLLsykRNLlbHWq/thY6GaTUs20FgOWLZFHq9vNxpclZThSu26LyJcbGi1yDd3yq60tJIGpMjfj8dnVXGmEYSgrTtd0qVSUHFFKFZzr6lIB1MWWZLJUkUqpATKbVS41LUE8dkw96/scZkfmmht06zd9DejtLIbZ1LLtEKQDnYXC7IhcNxbQEeqWa2V5oVFeGoUQiry1L3SqG7u9PbyWSqVQhjYb7N6tgm4nT6rtb9+ulvf3K6njwACc2TOAffrHGEMDrc5AF4He3rAmimEoq3nHDsUFWhcOs1Ot6MFfJwpqP7lthy7bfH7msguXGsvWItdEXq3O7CPXtYe1RVattizy5Qbt02xMDtOvbXtmNYq+WYNg9soVDSGUxjla60V3gV+xAtr9Ac4/eIAjQ22tzkAXiWgVS22c6YYQ58/XBzubWeSatPWzNux0JylQ50xjMTSbWPZEXqnMrhZ5tRqmakdrcbSwPKDlhZ7X3LVi28qKmy74pd+70Ap4V1+tvufECSU7dBz1/9mz0HX8MdpTNXw7yUC1rdUZ6CJh28oy1wlYZ8+q4LLO4p0J0Tr1rlufk+J5YX36VGpx1M5ZtkSulQil0uxIWY+6OpW7heUFTeRRxUp0+qzlidMFOjVmKvkwFQxDkYsmBilh40b1Xm2ogBl3qBRrcPgA3qOP4x89jnekSeWnFmbEypXq/teWdbkcJgZFE3m0xT4dfD+07otF5SfXRK+5ZTbZ45cSy5bIQY2a5fLsiFm3dZqtyqWFpQXtNtN9Oxvf0/7umYg8laoPls0VO3Yof/zx48ovbhjK9z6WWs26/NPEzx1Duh4ykUCWSsjBgZZ75QKgA9aacJNJJUFMpWY/m4rmlJRKYcaoVqokkyGBz7ZX6KXCsiRyTdzxuCLo2RC5rpCoXSstH/nygra6Gl0roKxxbZHPFMBsbw8rJF7IzM00lX+1Wg0tw74+oKeL0wdLZEeOMVRKEhSKIMFYu6blXrkIdHaqY97eHlahjCZ1NQt4RmvvQChBjMXqDT2ddBQNlC8UlqVqRY+SusPHXIhct4taTOm3LVw89A3bzCJPJEJyn8l/mk4rYvB9dfPO1MuzGXbsUCR+8rkcK5/+EclT++kcrNFvdbAyOA0lCIIC5o6rMPp6W42XLxArV4bSQ80FQ0PKnTWbRJ7GDE59DelrRRdWSySUxR6t2XO5sSwt8mhgYja+zGhpU603bhH58oMOdjYSeSo1ex+nvrmLxbkrVzQsC7pyBxj92WEO/fQc/tPPkBk4Cq7HQHojvSsN7J5OZTa2Gi9fMLTKKJFQ5z2RUG6R9vYwQ3cqizyauq+NvFpNvS6X65VwY2PqWRfpWggseyKf7fRXn0xdXKlF5MsTnjc5BpLNKmKeS73xZiVSZ70Px0+w7vt/T6bcz/n4emoiBrUaG3NP0jv4LLJWQ1oWwWiu1Xj5IqDPZzarXuvGIFrLP93grYOg2sJ23fpuQZq829pCaeJCpusvSyKP1kaYSXoI9VpS264vedrC8kGjayXa4i36PBMymbBH5IXAf/hRnMIIK/0TFK0sJ9I7QQhM4WMIEI4DhQJGe7bVePki0d0d6sZLJUXmQ0Ohf3uq3puNWnJtABiG4hQd8NTlFqK68oXAsiRyUL5LXX5yJkRPpGnOvtBWC0sPja4VIcIEktlID0FNzZs1npgtgoEhEAYbagfYVH0OxxrfoXEnq7lhHeamjcTueWeLxC8SnZ3qub1dnftsVpVC0G6xqe7zaMMR7XbV2byxWHgNNfYxaGzcfbmwrIlc1xKeDaKaYt1EoKUnX17QZYobfeR6BjZVjZVGJJOh3/WC9OQ9XZDNEHML7Cw+yjrvCDjjwZlUApFKtSzxeYJ2jURnXUGgLOrpBuPGuvW641Ai0TzXRLtVFirLc9mpVvRNmkwqF8tsg1jRKZRthzVaLsbyamFxQRN5lHx1Vqd+fz
bQqd25nPKtz9aS1zBvuQnjyDGCSgU8Pyyvt2kDiXf/covA5xltbaFPO59X8lH9PzQfjBuJOgjCfqC+r3ilVlNGX1eXcte0ty9c4bNlZ5FH9Zyz9ZHrkVen44J6XmhtaAvzC9MMz7PGbDL7psKFKles9euI/Zs3Yz3vWoz2LKKzC+v5N7VI/BJB+8O1WqW9XZFyuRy6Rhqh3a1avaLLdlQq6lEqhaSt3TcdHep5Idyyy84i1z4q151dnRUIrXEt/O/qUlKifF4FS1pYHtDJXlFry7YvzIXW1qYyBXM5WLVq7p+31q/Deu89c/9gC3OGDnYmk4qMHUdZ5SMjalmzYKeedUWTgnSaPoRlh1esCMleE/vAwIVdExeDZWeRawtJF7eZrQ9TW+XVqvKZGkarBddyQzOL3HHUjT7XMqTRDM8WFj90Or3uGhSPK2J23ebt3rRiRZO8vm6CIIyRNM7sRkbUdi9UlnoxWHYWuS6W1Ujk0Sj0VJ/To65OxW2l6V8Yiv/wfwi+8e1JyzNf+OwC7E0IHaSKXgez6QrUDImEesxmxtfCwqO3V3Vh2rQJfvYzNaPSzZWnInLdTEIbAFIqXunsnJyL0NOjsnV7e8PWfXPJS7hYLDuLHMJiWTA3i1y7VloSxAvHVCQOkH/z2y/z3tSjmfY7Wo50rojHlZ98MZQxbWF6aPeKnnnpmZiOMzdCE3W0KXsQKPeJ44RGnxZT6PZv2vVyuVvALUsi1ynXjVrQmdCoJ59tin8LIYL7HgCghsnnVvw7xowFbp0SQbOqhdH+jXOFdr+1guKLH7Va6P4A5V6JxUI5arM0fU3g0fZ+uVyYph8NeOrPaxni5a6GuKxcK3p0dJywANZsEPWbRslfR6cvpDDSlYajR+FznwP/qr+vW/50bT8vGv3uAu1VPZpZ5BdT5Ch14mncn4zR//XHMLJFrFfdQezWWy5uJ1u4JHDdUF2yapUKVOtY2HTGWpTIPS807nT/gpER5U4Btb18XgVAlySRCyF+C/hzoEdKuWAVB6JZVTqrczaW+FTrVCpqmy0ir0e5DI8+Ct/73jQrBT63H/ozrvOfu2z7NROaSc2mkp/NhOpPHyZ175eR2TvJZ1bSV3wK97OfB2iR+SKELkMLIYFrY0+7YaPQPnEhQhcMhGn6sdhk33pPj+r81NGhiLxcvnx5KBdN5EKItcCdwIJXv48S+Vz7bjaSuT55V/q0WUrVffzRR+G5GTh5/Xq4+5v3YMvFGSWOxSaf50xm7oFOAO++B3DigoTt4+EgUink+PIWkS8+RONduilzqaTcbVMVyGvs56kNQ99XQfLGGb82+LTarb8fNmyYt58wLebDIv8Y8EHgq/OwrYtCuRxGmKOlJmeaPsHkE2bbYYunKwm5HDz1FOzdO3O6cSwGb3oTbN4cKe/KK6YMdi60aqVZMbRk8sKIXA6PQHsbu3iSCmFhajl8maNcLcwK2qLu7VVulY0blfUciymuaJYBbhiKT6I+ct2MWWvStb88ankPDKig6uWUIV4UkQshXgecllI+JWbQ2ggh3gu8F2DdukuXvZbJhL5tPWWaicijPRxBnSjLmr2PfanC8+DQIXj2WXVxzybSft118JKXhFH6RqTe9Q6KQPDdH07kMBu3v5jUu94xj3t+YWi0yIVQypMLUayIzg5ksUgiZZBgnAXKZURnx/zsbAvzCm2Q6es2Hg9L2cZizWfeOikI1LOe4Y+MKOu7VFIcMTYWEnkqpba1enUYGG0mb5xvzEhVQoh/BfqavPV7wH9GuVVmhJTyk8AnAW688cZLVo5KSw89b3bWdFTwD2G9YU3ky6VwlpQqor5nj9K5jo2pC22m37diBdxyi+oAP9vyvql3vQMWAXE3QqsUNLTM9EJuNOtVd+B+9vNICAtdlytYb3z9PO1tC/MJ7Q/X9/vIiDrv1erUyWBBELZ20z5131f30erVirBtWylX+sYZsqdHLddG4OBgGAy9lJiRyKWUr2i2XAixG9
gIaGt8DfCEEOJmKeVlrwEWLZY1NhZ29GhEY2JQ4/9SKos+lYKh4WEefWIff/uZx1m9qpNX33kbO7dvvLQ/ZJ5QqcCBA7B/vyroUyyqx0zEnUzC1q1wzTXKv7dQrasuBeLxeiKP3thzhfaDe/c9gBweQXR2YL3x9S3/+CKGDmzrEteZjCLyRKJ5wmA0QKoJXVdB1EZAY7BcZ/vmcur9qKrlUuKCnQdSyj3ARDl1IcQx4MaFUq3oAINhKBLT/u3o9Ejt5+QTppfpmivVKhSLwzzx1AFijk/vitWM5Qb4xKe/zPvuecOiI3PfV/7sZ56BM2eUcairN85E3KYJa9aoQOXu3arOzOXMSLucaLzpmunK54LYrbe0iHsJQZ/vTCa0yEHdL1HXiYbrhrGyaPGsSkWtr2uTw+RMzvPnVc/QU6cuT5bnsvECN0oPo37xWk2dIXWiPNTPrp9P6wPt++pEHTp6jJhjE4vFERRoy6qQ9Lfuf3BBiVxKpVU9cACOHFGzDzXwzL49XXe3usjWr4ft2y/MR7yUMJIvc2owR77gUamuQOXBmSQScwt06u2UKi7JuM2a7iwdmVad46WCaPKX9o9rdZpEpXL6voi42syJ5u3aGIzyRNSSL5fDcg+6nK1WseRyqiTApcS8EbmUcsN8betCEC0nqlUnngdS+vi+B6alup/jg29hmj5gTvKR65rkuVyBTDqLEGUMQ40SmXSS02cGLuvvqlaVlb13rxrla7UwmDvbOh+plPLh9fSoaP2GDReun15qGMmX2X9yENsySSZMwgmKTzJp4osKe46OTZBzNhkjV6pOImu9HT+QlGsuQ7kSpwZy7Fzfw/re9oX7gS3MGtEZWCqlyDiVglyujFsLwHQwTYEP4IMQPkGgWF3zhBBQc6t8+WuPMjIyREd7L9fs2kRb24oJIu/qUkSuDclz55YQkS80dFUyUAQXBip9MI0J+9sE/PF/Gqc7UVLPZlNUKiWSqRrCUIyZL5RYvarnkv0G31cXwP79qkFsPh/WQNZyytkEX21bBSm7upT1vXmzIvLl5O+eLU4N5vADSbFQxvUDAj/sSD9cyPHk0TN0tyVJJx1yxSrHzo7SlU2QTjrUXJ/HD57F8zxKVQ+JxDIMEjEbxzJxfZ+9xwfIJmMty3yRYu++o3zr/gfpP5dmxYoqr3nVTRhsnOi1aZowPDICsgPp6Vx9RQSVaplK1cc0TEolF0gyOlqlUh1DGlW6u7KUy/CDH+/FiQWsWqUinnrAGB5W99/laMq8bIgcwqmMLk2pxPsBYKLdX77vYZoxfCSm6VKrBeN1qgXVao3RsSKDw1VWrW5n/8GDICBjBeQLJcZyRd72JiXS0RfI6TMDrF7VM+dAqJTKHXTmjKrKdv68IvJaTVnh1ersy/Aahsom6+xUroKVKxV5t7cvX3/3bDGSK1OsunieT7nm1ZcetWtUPZ+hfAnbMqnUXCxLUK55ZFIxylWX0XwZ01TzZ19KPM/HMg3ijo1tmVRrPqcGcy0iX4TYu+8on/j0l2nLpuhoX8XRE0/w67/9Q67a9EI2rE8Rs14BJcHA4Dm6OzvwAxdhmkSl1K5XwzdMZODj2D4SC98XjAwb1Nx++lbsBFwe/9k+7nh5vbhvcBCuuko9a5njpcKyI3LdhgnGBfwoItenpubVSJgJJFAb1yd6vkQGPp5fxrYDPN/jwMFjrOpr56ln9nPw6I/pW9HFL7/9LnZu31h3gazs62IsV5gxEFqtKvXIwYNw9qwi8SAIW0bVarMnbiHUlLC7Wz13dIQ+7+Xu754rXN+nXHWpeeMHNjKjMSwfKSWuF3B+NE+xOu4QFS4xy2S4UAYkrieRhLOhfNnFDySWaRCzTUqVKyxrbBGimWH1rfsfpC2bIpvJMjxUYv/BowghGB7J0T9wlNUrtmDbKQaGj9LduQMpJVLKceNHqjK31RKmaWMaNn6gjEDDtBAixoOPPMSLb+2gpyfO8PBYXVBTW+J6Fnz+PKxde+
l+/7Ig8mixrGIxzNQKAg/XdbEtE4Q6ojII7+QgkAghkEGAEAae71EsFSiVhxgcOcJTz+TYuX0Dr3vNyykUx3jge4+waePqyAWSBozx5zAQ6vsqCHn6tEpvP38+TO3Vlra2vhvrGk+FREIRdiajXEhdXUrLumrVlePvnitG8mWKlZDENzzxEPv9zYADSLY/+h26T/Wz/0V3hEQPICWnhqbv4VZzfVwvIBmzScYvUv7SwkUhalhZlsV3f/g4X/rq9zBNg1tv3kU2E/DMc48yPJKjUnGplgeJxUt0txWRgUO5UsALfAxhABJJgACEELh+BV96pOIJXLeKacQQmFhmguGRs+zdv4+dYi29KzoplUJDqqNDEbnrhklClxLLgsijihVN4seOncfz2iLEHQAGMlIHRAb++JApAYNKNY/nlgmCgGKpTLVaZM+zRxgYkgwOneb0mQF+8OMnyGRS3HrT1WTSPXhuJ/m8x9BQB+fOdPCH/98AqWQWx4lNNGrVsibXnX39F8cJidtx1Gyjp0dJBXt6Lk+22FKGDk4G42b0hiceYvtD3+fbO3SikqSrcpZNe34GwP4X3TGn7XuBJGYrN8ya7sVTqne5QVca1NUG9evosnu/fpy29E68ms++A08hgPb2DOf6h3j4sb1s27qOfQefwfd8qjWXfOERsuk+VvaUse0OHDuNWysSi4XV8QIkAoMg8CiWhkjFO/H8KoH0qdUKCMOk5hYpFPIcPznIHbe/gNHRkMj1/Tk0pOJVR4/WV1KcbywLIo+OduUyDA7leO7ASXo621AjLIBEAL4M1EmSIPGRvo9EjmduVQlkgBQSw1AKh8PHHmP/4TKpZIxEIo70TQKvncOH0jjGVRh0UC6WSSdNLDOGLy1Gc0VSCQiC2IR8aSYYhrK602kVqLVtdVFkMqELxbJU4PPcuTDz1LbVQ/9/pfvENU4N5rAtE8tQB2T7Q98fL76vD1BAZ+0MAti45/E5EzlAzQvIJIyWf3waaBVYIwFHHxcKXcRqeDhPR0ea/QeOk4hniDlduO4YhYKgUvb56cMnicdW4ZoVKrUzyMAjX+inVBkhm15JLJamWB7CcVJq4i4lAkFAgFur4PlVECCEATKgUBzAC2oYhgVIYnYat2byf/7f9zk/+OSEa8c2N9ZlfQ4NhY2g5xvLgshLJTUC7t13lM/+32cZHOxCSDU0ekENU8YwhZ7+BgS+h2lYuG5FnZzxk3Tm3NOAxLRj6Bu+5qpRwjJXcPWWn6ctuxLLssb923EMIRDCRAgbx2kH38LAoFwWExLHZhBCEXdbW1inIZFQ6cJtbeqE65Ryz6t/zFYv3gymGRL/VM/LYTAoVVwSMQvLNAGPxp/UVjlDm6vkBBf6cy93O6+Fgk6um+ox2+YrusuObuagr2ddjEq/F619pJdH5cF6PT3DTSd3US4FZFPddLU7CGFRc6tkU1XULFwghJqND48eY2DoILn8WUbHTpNN9pHPnyMRy+ClejFNG6SPEBaGMHFiKSpuHi9wqbllbCtO1S0j8cmkV2CYBoI4X/jKv1IujyDFWaq1Gp/49Jd5+1vfTCa1duK3tIh8Fjg/eJp//vKXGRxoI+60EbNUiDjwfXXWJ1RFYty9YiEJ8AOXYnGI9uwaVGjUxy2P4nllsumV3HDN2xFC4NgpUqleECY1zyVmJ7HNDIYRI+bYRL5gAiojLNSr27ayuNPpMMkgHldWdzarXCm6RoMOgM4F2g+vO57oGyT6gLA7ePQxF0LS6cnTDQgLRXAj+TKHTg8xtu8gzsF9rM/n6E6Hro+O4lFGkhvYXHgScx7K7ZarLnuO9i/qJKHZuCem+6y+pqKf1a91gD6qs44Se/Sz0Wd9W+rnxsGg0QUR3b7eL/2ciK1iLFcgmcgihIEQBqYBybihZuNSImVArVbE6k7S0baOk2ce5+z5Zzhw5Dus7ruGZLKdRKIT24oTSF+p3gIfU5gYwiTwPWzTIZAeyUQ7o7mTdGTXUyrnKJ
TyJJM+mXSSgRGP/QdPsG3rOn7w4x9y1yt/kUpFWeVnz17oGZwZy4bIH3z4IdqyKQ7sHySdXEFvz0rkuJtEQQJKneJ7NUzToVwdwxAm5wf3YVkOmojPnn+GdKqHtatuwDRskokukok2bDuDaaqRWl1QxniSETCJEySYPuDS0ZGeSFDSJKer7qVSYRaqThuOEmzUFz6d7h3CG0oPEtFtNN4IUyE6GEz10DdRswGh8dH4ezQMY+aZwVz9iSP5Mo8fOMP1H/6vmNQPq/oq+KUTf8So1aVmYvOAQsXj6Jlh+joz1Fyf/ScH2ba2e17JfC7uiajF27iOXlarhTM9vbxxwI9ax9HXzdaLXn+NRkN0uX4d/V3zAwtDtOPYeoMCq24irIjcsVMEgUcgfbLplazsvYZHn/wMAIXSIL7vkU2vpOYViYssUgZYdgLHTuB5lXGrXpJItFOujNCeXcOREz+kWByhVK5hWi6mIYnFHM6eGyI2rkLQRbYeeewcX7vvJ5w9d/KCJMvTH4EljL37jvLNb/+U8/0pnnzmh1x3zWaOHN/Dti3rkUhct4KBwA9cTFMpFarVPDE7iZQ+g0OHSMbbqblF8vlzDI8dw3XLbFr7QtavvZlErAPHSSBEIy00YILEJVL6BNLDNAxMDHzfplAIrV7LCuvCNBbrimIm4hWi/r1mr5sR/1TvRWsuz4bsGz8X/V/Xt4l+V2M9G71sqtlB9LX2/+tKhfrZNNWA6DjKDfXssWF2fOIvqZoZbFnDCmqM22QTZ08AHd5Q3T7Xxm+4bT96gE17Hp/0e33g/vf/7pTHo1TzOXpulPZ0jEwi1lRXPpV7QucN6Nfayo3mE2iJqnZFaOKNWstTkXGzGdlygbqmPDzfRwaokytVVEyOz8LHFxH4Pq5XpFotUiwPIQMfKSRXbXoF/YP7OHD4X4nH21mz8vpx0rawDVvlDwQew6PHyab78LwaiXgbfuDi+TUKpSH8oEzg1iiOHcbzfVb2djNcLPOC51/DufNneejr+9l/6EcUCytYvaqTVatmJ1meC5YskU9IjjLddLRnsR2Thx/bixDqpnTdMmVjDJDU3BJ2NonrVal5JTzfA7fMWP4M5eoY54cOUHPLdLZtYOO6W0mn+3CsGLPtTW2a4PkujPviTOGMW+pBXbEdqFetTEfk0910U7233G7UucEH+nh0y5/XLb376H9nc+UgoG7oxjHKB5556asnSLzZGGYCr/j4R7j/vf+JIBD4noHvmrg1C7dq4lcdAtfktG/heSb4Fg+IkLz1LOfKPj/zD3U8BQKLiQmWPoGNt64JjpMkleyhs2PDxGI/qLF25Q30r7yeXOEse/Z/lZ7OraQqo2Qzq1SgU0qqtTyB7MF3S6RSXVSreXKFswSBh2kaDAwfH2//ZnK2f5Ce7g62bV3PV77+VbLpDYzlCvi+xekzEIt79HS3A/NXu2nJErnWcqeT3QSBYPfOzXzvR4/jOCrImS+cw/dc4oksnlsFJEHgUqnmcZMVhIDhkaOM5U6zc9tddLVvZGXvbixzNulX424adJEdC8sE35dgSlUGAAm+ANOctVa8daNfKKb2dXfX6rvg5to7SeXHQEIx08bBW15E/5YdXPO9b01Yb/9j69+A1cQ18uD87nULF4+5tHNsBsOwSCW7WLf6ZtxaicHRIwwNH+b80EHWr7qBRLwdx0lz/uRPKFfGCAKPRKIDy4pPbGNo9BSOIzFNg3K5ShAE/Mb738b+g8fJZgXJeJxSySOVzGPb3Rw8fJKe7vZ5rd20ZIn89JkBVvZ1ARUMAnq62+loSyN9FdgqV0YRwiQez+J6JTy/yqFjP2I0d5K4k8a1U1SqY1y15VVs23gniUSGycO4IuwgCAikh5IwmuNKF+UrEOPyNsMwQagQ/ETZS1NV2Yu6K1pkPT+wLOVWSadrBHYeIzVGoq3C3Z/6kwl3SjP8+Bfe23S5OT7aSgCjlWG1VGCaF0fmAgPLimNZMRw7ieMkacuspF
Qe5tn936CzcwOmsPC9KsOjx7Ash6MnH8SxExiGRSyWpVw+x8b1qwiCgN6eTnbt3MRdr34h3/3BY6zs60KIM6RSNpVqBckhCgXlW53P2k1LlshXr+phLFegLWuAoXq6jY4VyKTUNMX3XQqlQTrbN1CujHLyzM84dfZRTNNhcOQwucJ5erp2cPXWO0nEs+g5mfJx+4gJ+0ypUUzDGveVRzBR0tKPLHCaJuu0CLw5tD/ctsMAcEeHSmdeuVKpebQfvNF3r5N+hvNliuUavtSJHJNdKJLp7HbwLQvDrWEAv7n/Ho45WzjYdhODiXWMOd1UjBTSTDbZcgvLAwLDMDBNG8OMUanmsawYI6NHsawkiUQHUgZUKmOMjB6nq3Mz6eQKUokuKrWzXLNrM/F4jLFckXf+wmuBKEel2bp5DY/9bB+1mksmnWQsV6ir3XSxWLJE/uo7b+MTn/4yoMrL5gslRkbzZFbF8PwqI6PHScSVTChX6Gff4fup1Qq0Z9fQ2b6R7ZteQSrVBVgTfrCAACEFpmkjhEAIi9neuFq5Mp12fLlDiFClorunZDKqeFdbmyLojg5F1vG4ClA2I+jZQif9xGyT4nh/1vvf/7vc+fEPTzoDMwUsT2zfXecj31A7xIaBQxPvS+BbU3zeEBB3bPras5QrAas6ujACk5ERVQFvdFRlH+us42jQcrYa7Bamgolp+hftYlFCBYlhWCQTbazqvYaujo2c63+Gaq3IaOE0heIAxdIAne0bMIRBZ8cGxvKnKJRO8Z0fPMYdL7u5LngZ5aiuzja2bV3H/oMnaGtL05ZN87Y33dlSrezcvpH33fOGiUI5hWIZ1/WwrBieV6Fay9OWXc35wf3ki/1UKiP09VxNPN5OMvH/t3eusXFc1x3/nXnse/kWJVrU05VkW/FLlV0ZqYO4SQPXCRwUbQAnaZM2QNMGjpsCLgInQb/0U9B3ivaLkbhAWyFtkaZp0LRubMQtUKS2o7h1bElWkCiyRZmiSInUcrncnZ2Z0w+zQw5XpPjYlZbL3h9ALHd2ZvbcmZ3/3HvmnHMHsCyb0uwlRIR8doBMthdbouI4S4mvtOXcLosKlCyPezNJRogkBTEZpRCLxUqjgmR0SXPUSizMcYRIoRD1kovFxbovmcziazodvd6MWPI46aeYS3NldnE6pOsJ9kqcefDn6Z14m8FL48v25mf6B5fbDNcRsq6DZVtksxYHdvXRX8wsu+5yxLOyxxEr8/OLf5VK9Dc3F71Wq4vRLMkIl60enbI6dtMoeK3C3sj5VgVCat48dX+OSmWaMKxTrlzGC6qI7VCrlZmYPAnAXGUKS2xGR47gOA7ptMvI9iG2DfUvEeZmjdq/dyePf/JDN2Rimq4VcogOVHxQfumjT1Es5lD1uVoaZ2LyNKMjP81M5TzTM28CUCpfxHHSlCtTFHJDFAvbsRrukmr1anRHFoeQKMTAtTM4bhrbcpGFePRFl0u0bDGwzUYJAsVO/Kqap5lLvjZ/1pycEydXJKMekhdtklisk7745Pckp7yLQ/jS6UiAc7nFRKV8Puolx6n/sTDHIt3JZJ9mchkXrx5E9cFtixohYbhxFXv5lz/O3lde5NBCOn/ETP8g//3h31h2G8eyCEI4vG9jE0yIRMc7LnzW37+27ZJJPnHymOdFN4BqdWkCTjwRSTwZyXIJQcuFL8bfE/+WktmV7WQ94a5rI+ql+0HY0OqooqEiUXVDooJ60SeNK7oh5o6VIpeL3ChB6BOGHtVaiXyun907f4ZY/C9OniKdyhOGMDc3z/kLl3j99NlrLElq1I2kq4U8ycVLl9kxPMDZcy9Q9aKHCTWvzJWr5yhXptgzeozBvn2IZbN96DZy2QFELILAIwx9ql6JWJRTTjZK0U3bqK94Og+a+EUjUaU0icKNIt+a06hxDsWiuyROOv4/mayRTLSAlRMvkiRTl5sTbeI462x2sYccJxxls0trscQPCpt70N3G6FAPr52d4HLJJ1DFls
T4SRau4es8+ryWc0eOce7IsTWvLyIM9mQpVWrr+JbWsazo3K23xnV8A4iFPzkaqNWWz8KM49rj2amS28YJRXGC0UpJQ/F3w/XDbmOSuQfx+3hfzaPPa9sY4Pt1xAaCqCx1GISEgKgSqE8Y+FF1Q79KzSsjYlGrzYJYaOjjeRUCfEINsSyXQn6Yer2KV5/D92sEgcf58RMMDRxs1FyB82MTnHrjJx2ZCnLLCPmO4UFKs3PR02EPbNvi5JlvohpiWy6WOLhuDq9e4fzbr+A66ehJdaqAZdlY4pDN9EbrNk6MH9Qg8KIMUQ1RgqianiqplIsgqLpkszkCBPWVdD63bK8GlvZ6YPE1mQGZrIUS/x/3jLPZRXGOlyezNlfqQW9l4pJoKdfGqwdYKJYlWLYQBlFSiH8D/NC2FUUwpxybQi7VNTXJN3oDSGaFNt8I4vDaldLxk88D4hFCc4JTsvZKc+fmetmk8XrJa0o1AFui5yS2RsmBXo164KOhUvfnCcMaWBaC1bj+LfK5fsAiDKN0/vn5K4ShT3lukppXbhw/h1TKYWjgVsYvvY4lQjpVpFBIc+jA7o7N6dv1Qh4XlAcYe/sSxUIuqnrn5CnmR3HdHLadYrBvL9VaidnyxUiggZHtd2GJhV+vEhJiWfbCiQ3DYCEmPdQ6rp0hm+nDdVOI2Di2k0jrF2wsrJRD0HCUJ0U5dlXE2YjJTMV8ftG/HKenN0e9xBEdSZF2uv7Mtc7YVIlCNs1AT3TA5mt1pmYqzNW8aPBkWY10kaAtYh471SASjEzawXXsqC75Fq9JnsykXQ9x3f3m3r/nLQ0bjEej17sJwKJoN2fKxjeBeh3efGsax14MIXVssJ2o6mlAgOM4gBuJs4BtOViWTd33AB8Ri2JxB24qS61WplSeYK4yRWl2nLo/j+vYBGGIY9u89fb36CnmePCB+xkc6L3pc/outLEj39omkgXlj913mGJ+iPGLSm9BqNfrqCrluUlmSmO8PfEyYRhSry9m57x14WVSqQL7d/8sPYUdeF6ZQmE7KTcfVT5zo5mEIn+akkqlQHzqXolUyuXgwX1YViTUsUCLRMKczS4tMbtcSGIs0MketBHotRM/7IzJpl1Gh3u4ND0Xzc8ZKmnXJpd2uXB5lmCD/vOUY2GJ4DVSwS2BtOtgW0I25VD3A/aPrNG5/f8M246uhew6Ss/EQr1c779WW94dk+zVT145x/y8TyaTQ9UCtQiqs1yZucz01Rnqfo26V6dajUIMM6k8biqPhgGWncZ10lS9clSXJfTx/SqF/BCKz/TMeYrFPHOVKsVCjkIhy4MP3MO2oT6ulso3dE7f69HVspGcqcev7+S2g7vYPVqlWPR45OE7+Zuvfovn/+MM+ZyLH9jUatGszHXfZ9vgQbKZPkaG30E+N7Swz3hIBQF1v0YY+Fi2YlkBPT0ZPG+eUOsMj+zm6NHlh1Cp1LU9aDMRRPuJH3am3MWDW/dDRGC4L79kuT1dJggVS6BZz2N3a7M+iBCJddploJjl4OggY1Mlpkvz1IMA13Eo5tObsuJhN5Osr7Me4uJiYu3gr7/6bbB6yWULVOYDFJ+fe+gQx//+WQI/IJ1J4bgVzr55As/zSaezDA1sY26+Tt2DUnkC368u2DPQ34tqyE/tHyWfz1IqzZFOuxw6sIfBgZ62x4Wvl64W8ji7UwQc9yIiAUUnZPziZQ7f/ihf/P1PL5nL7/VTP+aeuw7x8onXmZnxqFavcubHz5FOFSjmt+H5Uep+Jp1i29AA9aBGreZTq3kM9BeZLSvluats317g4fe9iwMHjEB3ktGhHs6cj2qKu45F3Q+p+5HAus7ScNFsysEPvMWCXTSG6UQ9biS6CcQib1uysA/bEnIZl/5i1gj2JiYuFX3vPXtIZ97buO5/yM5btvHYh6JKgyffeIHvvvQD0pInk3W5ZWSIC+OTjOzoo683jU5VmK7NkkkJNXGI5/AUgd7eAodv34/j2PzmJ34RYMk8oe
2MC18vXS3kycwpkcjh1pz2mgz/+eM/P87VUplj99/Jf734KhcujIEIYdjT2N6jfyBNaXaG/nAb/f15tg/3Mz1d5SdvngbgyD2H+NhH3s8dt+256e01LKW/mOXQriHGpkoL9cD3j/QzNlW6pqfem8/gh1E4WrUeTbps27Bvez/1ICQIlSuzFapeQBDGIWqKY0cJR2Y6t+5ipbC/j33kEcYnppi6fJXZ2Tmy2QyHb7+Vg7eOMjE5TWl2jnc/eC97d49w7q3xhQQeFHp68uzfe8uS8rOdEu5mRDuQPXD06FE9ceJEy/tJ+sjj7M6rpbkVS0Mm169W6zz7/HeZna2Qy2bYvWsHd995gFTKobenwJO//dGW7TN0hjh133XsJT31kcEipUrtmkkgpmfnF1wmlZpHzY/8665tMdyf58DOQdMT30IkR+nNdcGv99lmQES+r6pHr1nezUIO6z/wyfVTKZfxiSn27NqxphuBoXuIxXkzz9xjMKyXGybkIvIE8DhROYtvqepnV9umnULeKpv9DmwwGAwxKwl5Sz5yEXkI+CBwt6rWRGS4lf11gpuVQmswGAw3ilYnLvwU8EVVrQGo6qVV1jcYDAZDm2lVyA8CD4rISyLynyJy30orisgnReSEiJyYnOxM9pPBYDBsRVZ1rYjI88COZT76QmP7AeAYcB/wDyKyX5dxvKvq08DTEPnIWzHaYDAYDIusKuSq+t6VPhORTwFfbwj3yyISAkOA6XIbDAbDTaJV18o3gIcAROQgkAKmWtynwWAwGNZBS+GHIpICngHuATzgd1X1O2vYbhJ4cw1fMcTWujFspfZspbaAac9mZyu1p5W27FHVaypzdSQhaK2IyInlYia7la3Unq3UFjDt2exspfbciLa06loxGAwGQ4cxQm4wGAxdzmYX8qc7bUCb2Urt2UptAdOezc5Wak/b27KpfeQGg8FgWJ3N3iM3GAwGwyoYITcYDIYupyuEXESeEJE3ROSkiPxBp+1pFRF5UkRURIZWX3vzIiJ/2DgvPxCRfxKRvk7btBFE5GEROSMiPxKRpzptz0YRkV0i8oKInGpcK5/ptE3tQERsEfkfEfmXTtvSKiLSJyJfa1w3p0XkgXbsd9MLeVOp3MPAH3XYpJYQkV3A+4C3Om1LG3gOeIeq3gX8EPhch+1ZNyJiA38J/AJwB/BhEbmjs1ZtGB94UlXvIKp/9HgXtyXJZ4DTnTaiTXwJeFZVbwPupk3t2vRCztYrlfunwGe5dtL2rkNVv62qfuPti8BoJ+3ZIPcDP1LVs6rqAX9H1HHoOlR1XFVfafw/SyQSOztrVWuIyCjwfuDLnbalVUSkF3gX8BUAVfVUdaYd++4GIV9zqdzNjoh8ELigqq922pYbwCeAf+u0ERtgJ3A+8X6MLhc/ABHZC9wLvNRhU1rlz4g6PmGH7WgH+4gKCv5Vw1X0ZRHJt2PHLc0Q1C7aVSp3M7BKWz5P5FbpGq7XHlX958Y6XyAa1h+/mbYZlkdECsA/Ar+jqqVO27NRROQDwCVV/b6IvLvD5rQDBzgCPKGqL4nIl4CngN9rx447zlYqlbtSW0TkTqI78qsiApEb4hURuV9VL95EE9fF9c4NgIj8GvAB4D2b9ea6CheAXYn3o41lXYmIuEQiflxVv95pe1rkncCjIvIIkAF6RORvVfVXOmzXRhkDxlQ1HiV9jUjIW6YbXCvfYAuUylXV11R1WFX3qupeopN6ZDOL+GqIyMNEw95HVbXSaXs2yPeAAyKyr1HN8zHgmx22aUNI1EP4CnBaVf+k0/a0iqp+TlVHG9fLY8B3uljEaVzr50XkUGPRe4BT7dj3puiRr8IzwDMi8jpRqdyPd2nPbyvyF0AaeK4xynhRVX+rsyatD1X1ReTTwL8DNvCMqp7ssFkb5Z3ArwKvicj/NpZ9XlX/tXMmGZp4Ajje6DScBX69HTs1KfoGg8HQ5XSDa8VgMBgM18EIucFgMHQ5RsgNBoOhyzFCbjAYDF
2OEXKDwWDocoyQGwwGQ5djhNxgMBi6nP8DtZMrLpbuHz8AAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_ot_map(neural_dual, data_source, data_target, inverse=False)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAACQVElEQVR4nOz9eZwcV3X3j79vVfXes89oZrSM9sWyZMmrbGMb4w0bOyZgDA5gIIRAgCwkPD9Cki/ZeWKyQEgCD3aCQ8AQjAMGB++sxpss2dZiy9rXGWk0+9Z7Vd3fH7fvdE1P9yyakWZGqvfrVdM91dXV1dVVn3vuueeeI6SU+Pj4+Pic/RgzfQA+Pj4+PmcGX/B9fHx8zhF8wffx8fE5R/AF38fHx+ccwRd8Hx8fn3MEa6YPYCzq6+vlkiVLZvowfHx8fOYML7/8cpeUsqHUa7Na8JcsWcLWrVtn+jB8fHx85gxCiCPlXvNdOj4+Pj7nCL7g+/j4+Jwj+ILv4+Pjc47gC76Pj4/POYIv+D4+Pj7nCLM6SsfHpxyD7/8tyGQKK0IhKh74+swdkI/PHMAXfJ85xyixB8hkGLzz/YX/a2upuPdfzuyB+fjMcnzB95l75MX+KC20cLT0Nj09DH7s96GnZ9RLFQ89cDqPzsdn1uL78H3mLP9z3t/wxfP+c/j/h+f/Pl86798LG5QQe2BkT8DH5xzCF3yfOcvSvi2AwbcWfw6ANb2/QmLxHyv+fsR2D8/7be5b+nczcIQ+PrMLX/B95h6hEADvOPFvQI7O6Ar2Bs/nvNSrVGdOMBBo4Lna24Y3b61cz1B4Pq8H1k5o95kXNpP4i79l6Pc+TeIv/pbMC5tPx7fw8Tnj+ILvM+eoeODrw6L/+2/8NgA/Xv4Zslh86OCfIqTN5sZ3cDy0FIDf2//7ADy54o/H3Xfmhc3kHvguMpGA6ipkIkHuge/6ou9zVuALvs+cpOKBr1Px0APUPPQtbrzZBODeVf+CAby17d8BwY9afp8hsxIA0xkC4Ce17xhzv/YTT0MkjIzF2G+sQsRiEAmr9T4+cxxf8H3mPFdeCbFKk1ywksfvfIBN9/8+9Y1BUqEGnpj/EXIixO/v/SQAOxp/neh3y0fpyJ5eiEQAaGcBg8QhElHrfXzmOL7g+5wVfPrT6vGNN2D/fvjN34RgEFprLmLLb32d+IMPEI6ZYJp8fYz5WaK2BlKp4RvjGIshlVLrJ4g/BuAzW/EF3+esQAh4fz7a8jvfAdeFK65Q63fuhJdfhj/8Q/X6iRPQ3196P9bNN0IqjUwksGSWrlwVpNJq/QTwxwB8ZjNTFnwhxGohxDbPMiCE+FTRNtcKIfo92/z5VD/Xx6eY5cthwQKQEu6/H66+GmprIZmEV16BffsgFlPb/sd/qO2KCV2xicD770LEYswf3IsIBAi8/y5CV2ya0DHoMYC+2EISRqU/BuAzq5iy4Esp90gpN0opNwIXA0ng4RKb/kpvJ6X866l+ro9PKT7yEfXY2wv/+79wxx3KJd/drSz9d+THbIeG4LXXSu8jdMUmYn/1/7HiCx8ncM3V5NZPTOyhMAbwGhdwnIVqpT8G4DNLmG6XzvXAASll2RJbPj6nm49/XD1u3w7t7bBqlXLtnDwJu3cPj8nyv/87OiWPl2BQPR47NvHPFrU1yFQagHo61MpJjgH4+Jwuplvw7wL+u8xrVwghtgshHhdCnF9uB0KIjwohtgohtnZ2dk7z4fmcC8ybB+vWqedPPgmXXAJVVZBOK//9RRep13I5ePbZsfcVCkFX18Q/27r5RvozYWQuS7XbpXz5kxgD8PE5nUyb4AshgsDtwEMlXn4FWCyl3AD8K/DDcvuRUt4npbxESnlJQ0PJwus+PuNyxx3KQk+l4Ikn4E1vgkAA+vqU1R8IqO2ee05Z/uWor4dstrS/vxShKzbRce17EYEgoq8fEYtNagzAx+d0Mp0W/i3AK1LKUbePlHJASjmUf/
4YEBBC1E/jZ/v4jOL3fk89trbC4cPQ0qKEe3AQFi1Sr0kJv/wlOE7pfTQ1qcehoYl/bmLeCmI3XE38X/+J2F/9f77Y+8waplPwf4My7hwhRJMQQuSfX5b/3O5p/Gwfn1HE4ypSB9SAbWWlitJJpZTAG/mr/4031FKKSjVRd8xeQDHZrIoO8vGZbUyL4AshYsCNwA88635HCPE7+X/fBbwmhNgO/Atwl5QT7ST7+Jw6112nRNtxYO9eaG5W671WPsALL8DAwOj3G4Ya5C2TaXkUqZR6bGyc2nH7+JwOpkXwpZQJKWWdlLLfs+5rUsqv5Z//m5TyfCnlBinl5VLK56fjc318JsLHP66idLq7lWumsVE1AF43zvHjsGdPaV99ZaV630RMlO7uwnt8fGYb/kxbn7OecBhuvVU9P3RIzcINBJQ17rXyn31WvV5MXZ0S+3Kzc7309IBpFkI6fXxmE36JQ59zgosvhq1bVYROe7sK00wkVGimZmBAxenPn68aCU11tXrs7Cw8L0cyCRUV0330U8M+chRn8xbczm6MhjrMTZdiLW6Z6cPymQF8C9/nnOHDHwbLUqIMyjdv2wW/PsC2bcrX7yUSUYO9icTY+8/l1L71xK7ZgH3kKLlHHkUmEoj6WpXb55FHsY+UqQXsc1bjC77POUMgAO9+t3p+4oTy4dv2yG1yOdiyBdraRq6vqFCC77rl99/To1w/s8l/72zegojHGArWMWRHELEYIh7D2bxlpg/NZwbwBd/nnGLlSlixQj0fHFQWfyoF3jl+J06otAxed08opGbq9o6REke/NlsEf2gI9rbG2Ocspz1VQ38uql6IRnE7/ajocxFf8H3OOd7zHuV2yWSUhe+6hZh8UJb//v3wzDOFdTqufqxsH66regLx+Ok57omQSCiX1O7dasKZqKykQZ5kRcUJFkbzsaXJJEZD3cwdpM+M4Qu+zzmHZcFddymR7+tTj4ODaiBX09+vXtu1S/1fVaWEvJxLx3VVTyEWG9l4nAlSKdVA7d6tEr25ruqxrF4Na25cSFX6JCQTSCmRiQRyKIG56dIze5A+swI/SsfnnKSlBdavV66bUgOtrqss5KoqJZ4NDQU/vuOo0EsvfX0FwT8TZDJq7oA322ddncr9o+a0K6zFLXD7rSOidKzrrvWjdM5RfMH3OWf5tV+DI0eUWJumssy90Tj9/Ur0TVNl3DQM5Rfv7lYZOb10dyvx1UnZTgfZrBpf0LN5AWpqVGM0Vq/CWtziC7wP4Au+zzmMaSp//te/rgZkKypGDtRKqVIj19WpCVkVFcry7+kZLfiOo3oJ0z1gm8upeQPekNCqKvX5xb0MH5/x8AXf55ymqQkuvRRefLHg2gkECsKfTKrEaaapJmSVGpDVKRfi8ekRfNtWIu/N0FlZqUTe8u9YnyngXz4+5zw33AAHDyphz2aVuGvBl1JF5jQ0KAFube3l4NHX6frGNhbMr+GWm65k0YKlZLMqnUIodGrH4Djq870J3OJx1SD5Iu8zXfiX0llO72CK1q4Bkukc0XCAhfWV1FTMoqmgswDDgDvvhPvuU374ykol+Dq5mm2r6JdkqpdtO/cRCtbS2LCE/oE27r3/Yd55228QCjZPOn+O60JHhxpD0ESjSuTPxVw8U71W/Wt9fHzBn2OUuqiBkhd672CKPce6CFgmkZBFNuew51gXqxfV+zdCEXV1cNVV8ItfqBDNioqCte260Ntrc+DwCTIZg0wwg+PUUFuvsqm98NIerr6iufzOPbiuGhfwplsOh1V6h1PtHZwNFF+rg4kML3a2EgsFqKmMjCve/rU+MXzBn0OUuqh3HGxHIIhFgqMu9NauAQKWSTCgRvf0Y2vXgH8TlOCqq+DAARW5k0opy1/F3TvkclmEjAL99A8McuDwAWx3J+vPX0k62U80Wn7ClZQqisdbGzcYHJ2k7Wwl9ciPsR9/Wo08x2JYt9xI5PbbRmzjvVbTGZv+ZAaQZB1nQuLtX+sTwx
f8OYS+qJPpLD1DaRzHxXEl4YBJTaW6qL0XejKdIxIa+RMHLINkOjdq3z4qfv0d74B//3elTdFoIdEaZpBoZB5DiV6k65DNZujsTZB55QDrzmvkiZ/8gq6e/cxvjnDLTVdy3uql9PSMnJkbCCiRn03J1U43qUd+jP29hyEYGJ7ebH/vYVIwQvS91+pAMoNpCAxhYDvu8DW9r62bSChQ0mXjX+sTwxf8OUQynSOTzdHZn0IYIKXEcSWJjM2hE70ELR2nJ7FdQLrYriRgmoSDJhXREKZhEA2fxmDxOU5VFbz5zfDUU5BMOkAOxwkAAtvOEQjGSaYHaKg7n5Nd+0kMBdjx2kF27zlCR88O5jcuYf9+g3e+PcKiBU0YBixYcOYmZM027MefxgkGyIYqiJAd9lvZjz8NHsGPhgNkcw7BgEkqkyPnOORsV7XCXQNYhqA/kSUSsggGTGzHZU+qi+a6CgaSGXoHU/QnBDXxCOG88Ods17/Wi/AFfw4RDQc40TMISHI5ibcAUzJjk8yMTP0oAAlkcHBdi3TWIR4JcEFz0xk86rnFrt2H+OXzz9PZfT7x6Gocx8ayTIQwCASixKMNZLNDDCa6qKpYQDaXxJUm2RzEoytVTv29R/jbf/h7QkFlXV60cTUfeO+trF2zdIa/3fRg22qgua9PTU4bHFQ9olLpo3Ox90GlCQJu7/+mWhkIjNp4YX0le451kcnaJNK5wrUtJX1DhenElu2Qydr0DKSQEg6fVOMolgESwVAqS0NVlGg4SM52WNZcM+3ffy7jC/4sp3cwxf62broHU9iOJJ11xn9THul5TGRsYqEA0VDA92mWYdfuQ9x7/8PYtsOvnv8Z113+58QrmvKiH0QIA4kkFm0gkeqhqnIhnd17GUp0EgxECYfiZLIJnEwWIappqK1HAi+/kmDHzh+ycf1ybrr+UpYtbcYwCrN7vc/LPXrTJUw3uZzK9NnXp8R7YEDp8bA7axJEo6o3U1mplsAzT1GZPEHA627J5Ya7PN4gBNt2aOseYqxKkqky17/tgpJ86OhP0hK0/AHbEviCP0vpHUyxedcx+pLZadmfAFLZHMd7hth56KQfsoYaTHUceG3XEb77Pz/ll7/ajnRMbGmSTtns2P0jLt34ASwriOu6GIZB0Irgug6WGWQo2UHAinCi4zWqKxdhGCb1tSsJWGGikWpCwTjBQJjGWoHruvT2wMM/6mPR/BDz5qn0m1rwhSg8LxZ8b8NQvF1xo6AHmlMpNXtYL9lsYW6BECOXcut0qol4XEUtVVWpVA7V1ROf5Zu6aSP29w4BgcKMtmwO69dvGxGEkM3atPeOFvubvnoP3o9ygKc+8dmyn2e7kljAIuwbNiURciKVmWeISy65RG7dunWmD+OM0zuY4oVdRxlITu+AkwCCAYNYOEgm59BYE2Plgro5c2NIqdwJtp3XjayKm89k1HPv4rpKzF137KIlnd29vLh5B30DQwwOJsnZyoLsH2wjm01QV7OM1ctvIByqxjBMkC79AydIpXvp7jtEItVDOtNPc8P5VFUuxDQCmGYAIUwCVgjDtBAItQ6BYQhAEAqFMIyJ21uuq77/dN2uxQLv7UEUPy/1eqnF2+iMaKT6OjA6T2LZaSxTYjbXk2qcR8pOYQoHLMi5WUwzh2FJDMvGtFxu+uY/EJU2hqcZkIwv+oaAppo412xYMj0na44hhHhZSnlJqdemzcIXQhwGBlG/h138gUIIAXwZeBuQBD4kpXxluj7/bKK1a4Ch1NhiP1nLR2Orvi/BgEHfUPq0xiq7rhJnLcJamIsFOpcrCPREBU2Ln7bStSDq50IUhF7KgvDrRkN/1pGjGQxjITWVBpbZi21nQAoiwSp6+g6RTg9wtO0VFjZvIBapBUyCwShv7H+CZKqHilgji5ovwjDDZO0U8UiMUCCGYQZQl7wAYSCEiWGYap0cvyE63Uxn4zE+8yCqkw85MIBaGPua+/c1/1H4x3X4wz0fRgDjdS5cCQ
MpNZA7V4yZM8V0u3TeIqXsKvPaLcDK/LIJ+H/5R5882l9/rHMAd4ybUYu9161r5tePJfoSQAikBNMwsF2XgGUOxyo7zkhR9j7XFnU2W7CyvQLrOGrRr+dyhe20wOqlWICLxUc/LyVIQkyvWDlOA/FIPQAVsYWez3bJ5pJI6eJKG+m6ZLNJkuk+EIJAIML8ynW0LNxEOFSJEAaGKP5VyiBKp1g++5no+JNUS84GO6vCOcXEe7uGAEMIPwa/BGfSh/924JtS+ZBeFEJUCyGapZQnzuAxzFp6B1PsPHiSwVR23AE6LSuHqeRXTb9N2gzjWBEcI4jzVAAnXJnfQu+o+HE0P5jyNzgzTLdVaprgOIXzU/BpG4RDFUgpcd0s2VyK4x07SCS7cVyHBY0bqK1ZTDAQQ4jJK/fZJvaWpZZwWA3cVlaqKmF1dcrvX1UFhzq6ae/tIxQ0SWRypIqiyoq55av3IIB/WXMvljs592YqY3PkZD+pTA4p8VMt5JlOwZfAU0IICdwrpbyv6PUFwDHP/635dSMEXwjxUeCjAC0t50YO797BFK/sO05/IothQNAyyeTG7+8/vOzvIVTiAnYo9Hu9RtVZJjKnH4EQAldKHCfHYKILKW0OH32enr7DrF5+E00N52GeguDPJN6BX9NUY6mBgJr9qwU7ElHPYzE1YKvXRaNqu2Bw8o3WvpNZQiET15W4Y3Vh8xQuY0HIVSFD+l23fPWesm5MV4KJJJuz6exLUl8V8VMt5JlOwb9KStkmhJgHPC2E2C2lfGbcdxWRbyjuAzVoO43HNyvRkQrprAPSRWKQtSfm3P3Dg78DwJcqPgULNo6O3SvuQXsbAp88JqNPlMR2MthOlmw2SSLZSf/ACRKJDgwjyIqlb6H95Bv09h4iFqnGMALEYw15l87I/egeiRj+bSSqp1X61ivVu9PFTUpF5lhWQYC1aOvHUEiJtv4/EiksoVBh2zOVjTMaDmA7LgPJDBMJFnnqE5/lpq/eA0DEHhg+c5qx3JhSQiRkYZkGQ6ks82pU3otz3c0zbT+1lLIt/9ghhHgYuAzwCn4bsMjz/8L8unManS7BMgVJCW7OQcrCpKliLnrku8PPc0AAJib23vWjRP9cbwVcpFSNrJQutp0llelnaKgTKxBiMHGSgaHjHG59kSsu/m3isQYWLbiEoBWh7eQO2k++zqL5lxIOVRIKVWKZFoZh4koXpIthWPnf08WVDlUVEerra0cIthZewxhpcQeDar1pFtwmxZa23tayTm+8/lRZWF/JnlQX4YDJQCIz/huAJz72pzjPBYnnehHAl9Z8A7JZ/vDgR8cZwJUqJYOEnKN+Wz/VwjQJvhAiBhhSysH885uAvy7a7BHgd4UQ30UN1vb7/nuVLkGgpoGDZ8CyxLYXPfJdGlsPD1s5Dyz9G6Juavgu/8M3Pjj8XikM3Lw5+GrVdSok004ScQZp/+D7WLO4inl1EQKBkRN7vAOq3ggb72CsXvQgresW/tcDtN6oG70/PUjr/b948LaYUoO53v/Her14u3Khh23HuxlMJBFS4Lg5bCeHEIKsnWAgcYL9h39By4LLuGbT7xEJ1xAwQ4TD1QhMYuE6mhrPp7t3H/sOP0NlRSM1VYuJRetxnCxSOjhOjnDIJGdnCEcsbrjhFloWFixtr7hrsZ/Nwn2q1FREaK6rYOehk2NOrvKSyyiJqsh2qxVCjJs7WqDcOtmcg2UaBEzVRfJTLUyfhd8IPJzvtlrAd6SUTwghfgdASvk14DFUSOZ+VFjmb07TZ89pouEAJ7oHCQZMco6jBgnzd4NlgOMRtXl5sX+55gZ21L6FpFVDTo60WHScsrQsDMdGYnBh/88AEI6NCARYdlHVnOjWaqH2NgzeRsQbZz/eoqOKiiOLXBee/tkxdu/bTSqTo7JiPlUVCxhKdNHe+TrZbALTCCCEQTzWiGUFMY0gATOs3BICKqINxKMNHDj6LMfbt3O8fTuWFc4LvotlmVTGo1RWxf
mTj3yQt93cONOndsYYSGYIWiamYQMSZxzvZTalBLrS7sjPo2VCraFpCNJZm1DAoirvw/dTLUyT4EspDwIbSqz/mue5BD45HZ93NrGwvpLD7X0EAwZSQiCfAC0cNBGoizadU/4Z7eY5GV5G1lBT04Us+G68k1JW/+pplr72Cjg2UoIhXQRgtCwk0tMJFbN/QNw7oed0EquoZ9feIWy3hlS6j2x2kLb27bS1b6NlwWUIYSGli2GYhAIxhKl8J4YQRMI1mIbFUKKDXC7JgqaNdPXsJ5NV9QktyyIWDXHrzW86q/LpnCrJdA7HcTGEwHElpqEey5EZUsnWYrle+szqEa/p610TCpg4roMrBaGgaigaqqNIqbLILmuumROGzunET60ww9RURGisidE3lEYIFRZY29tF/f7dhAcGSFdWcnjRcgZq5yEBAwhI7f+UGJ5L/nHP4NWeq28kmEoyf/8bw37OZF0DNSuXk3vkUbj9VqzFs1/0TydSwtGjcKJtKbe99S7+54ePsfmV/2Fh84UAxGNqslAgGEJKiSlMTDMAGMNutVAwBhJCoThVFQtoajifnr7DhEKV1FUvJJE6wluuuYSG+ppzXuxB9Wj7EwamKRDCwB1n8DabVu6bnb/2DjqfeGF4vRb7pz/5WWJBC9uV6jcyTJqqokRCAYIBk/VLz93eVCl8wZ8FrFxQx55jXcTCQezDR5i/5QXsUAi3pppQKsXK7Vs4cuEmOhYuobH1MAEnPfxeqbvERaNXFV0dGKbJUGMzhhDkCCCkgczZiHgMZ/OWc1bwHUcVOTl0SP0fjcI7rqrhgguu5D0fepSAZZFIdVFbvQSAUCAGqEGA4kgREBiGgStdQqE4UrpkMgM0zjuPeKyG+voMS1qaaDveiY/q0fYOpEhnbBypLH1DKGdNJBQgGg7QO5hGSontSuyMstRNy+HVFb8GKbWfxz/xWYSA9UvmMZTOcrI3QShgUh0LY5qG774pgy/4s4CaishwharQob3IaAQicaxAgJrqONn+EJnD+9jx67/BBT/8b6y0FnyJHTCHxd4yDSxTELJMFhzdjxsOY7guhEJ0BZcjpUV921HMtWtwO7tn7PvOFLmcKlbe0aEGmysrYe3aQq76tWuWcutNV3PsWJB43xDzGhYzv6key6ygp9cBoaJ4hDAoONgkwjSRrkPAimI7aSSSyngt8XiQteetYnAoyYL5DTP4zWcPNRUR1i9rZH9bNx19SbK2TcgyMC2T2niEeDSIIQQd/Sru3rUNMCTCdHFShRqQ8XCA9csaWdxYDYzMuum7b8rjC/4soaYiQk1FhJThIJY0056uZcgO0RxqJ9RQQ+BEB7KlnqPv/02OvREldaQS6QoCwSjk9b82HqEiGiCVdYgMDSBqa2hNXEhD+hg2ISqDOWQiCckkRkPdzH7hM0g6DYcPq9ztQ0NqBuiSJWoGaDHvuO3XeOjhrSxfZhEJtZBKZ0ilc7zp8os4eCBEOmPoHBVIJIYh8g6eLJUV1aTTaSzLoK62imVL4wSDCfoHEtz1rpvO6HeezdRURLh0zcIR67yCXVcVxRCCoGVyRJoIAeEgKIemYtOaFhpqIiP26Qv8+PiCP8swGuqQiQSmWQUy7zxIJonMb2T9UrVUOvCL4+A4DtIz4LVqfhMLm9RF37d9EYnuXrLxeZjJQ4QMiDn9ELCQQwms666dgW93ZhkchLY2lde9txcaGqClRRUML0dtzULeeoPF9p3P0tExQG1tFRvXryMUrCEeH530TAQDYJo0N1fQMG89yfRJfv+T/z9e32XT0bWTeZW13PWum3z//TgUC/bOQycZGpIIrHyaZpOC4JsEDV/cTwVf8GcZ5qZLyT3yKNFIL31EkYnEKIHWUStCmCP8yQFRuAni11xB4EePUV0do7ZuAX3HIDZwEnP9CgJvveGs9t/39BTcNp2d0NgIy5bB4sXjR/w4Dpx/XhNvuuJdHD2q1nV2FgqQjxxjNAmFI9i5HOl0glgswl
uuvYRLLmpm5TJYs+ba0/Dtzg0W1lfy8sk+XFellJ5XEx3x+smTKj+Pz+TwBX+WYS1ugdtvJfriVmTbIHJBjMB1144QaMMYGVOuGRxU4qb3k7jpNsSzR5Fpibl0KVU3XE1gydkp9FIqEejrU+ekuxvq61U92eXL1WSmiVJTM1LYXbcwA7YYQwQIhQJceeUFGAasXaNcRz5To6YiwrwKA+kKhHAxhJYqNWDV0QGrVs3c8c1VfMGfhViLW7AWtxDcDe4isIoKYOuJRzDSvVAsNOmahVhXLcRpvBLrJASWnNbDnhFcV7ltEgl1XpJJNQg7b55y30Sj4+9Do0v6RSLquRb6XK684Eup0h6EQoXKfa2tZ182zJlA2iEsU/0eSxvmjXito2OGDmqO4wv+acY+chRn8xbczm6MhjrMTZdOyp2SSBSiSLyUyiM/MDBym6GhkY9nE7atBmLtfIZdPfkyFoOmJlWGb7L09o78X5/bbJbhFBTF6JQN+vN1A3Mqn+8zkmy2kEmhuEB68bXuMzFO8xzGcxv7yFFyjzyKTCQQ9bXIRILcI49iHzk64X0UX+hQWuxhtLDncsryTCTOntwsmQzs3g379yux142hlMoVs2bNqYvt4OBIy1yf33S6IPjF51xX19I9AH2eff/y1MnmyzmHQqMFvtR94TM+voV/GnE2b0HEYwitSvnHiU56CgSUwI3ab5lMmKnU6HXxuNpHZeVEj3p2MjSkXCWahgY1mJpIqBTAixdPT6NWW1t4rt1lumEptX8pC5Z/KFRwC42T38tnAujpJsGgGpPxUupa9xkfX/BPA5kXNmM/8TTuocMQizGwbAODdStZGu+AaHTCk55iMTUIWUy5Mn/p9OhtYzF1s5RyC80FenvVYKxm0SI4dkyJPcCKFdOTz12Lu9cy957f+vryET6GUWgU/AHb6UFnajUMJfjFPnt77GJZPmXwBX+aybywmdwD34VIvnRQJkNo1za6VteRiVgE0/0TnvQUjZYW/HJ4BV/fEKH85MS5JPhSKkHv6VH/h0Iq2qatTYk9KIs+Mo2h2FqovY2HlMotBsryLzUQK6V6jxb8gYGzx302kyQShXMfDKrf3ovOduoPjk8O34c/zdhPPA2RMK/FLudEzfkgDEIii9t+kiM9FcihBOamSye0L69/uhht5Y/4bI/Vo32cuus7F24M11Vumz17lNjH4yr0LhJRKREyGZg/X/npp1PsYfSALajzOzionusCI6XQ67XbzB+wnTrJZMGdKcRon73r+m6dU8G38KcZ2dML1VUI4GBwLY31bRj9fSzo20VHaB285TasxYvG3Q8URDqdLi1wxYLv9e3PpQgd21ZZK/UgXW2t8tH398PevYV18+aV38dUyWZH94JctzBYqKtNFSNEIVwzFlMNhz9gO3WSyUKUjuOMdldKqdbF4zNzfHMVX/CnGVFbg0wkWBfbya94M1vDb2aT8xNijTGsq66k1YWVk9xnIjFS8L0DhV7mmuBnMoWMlaAmjdXUqJt9zx61LhpVfvsz4SYpzq2j5zt4q1IVo+PwoXCM4fDpPc5zAV1hTf8GxT571y09ZuUzNr5LZ5qxbr4RUmlkIsFqdyfZnKAvE8W6+Ubmz1eiXC7KphzF3Vl9IxTjvSmkVBanbU+/+2OqJJMqtFKL/aJFyk0Tj6v1OqXBihVq8tTpFnvtOvBa+EIU/MThsBL1UoLvrbin3T8+UyeXK4S8al9+MZMZ3/JR+Bb+NBO6YhOgfPkNPQfYV3s+uy75TW66YiEh4PhxNfC4ZMnE9mcYo32Vtl0+RNBLLKYaizPd7S032ay/H054qhgvXaqsZ9dVPnrt0lmy5Mxaydp/X3xO02kl+NqVU+qYhFDfIRz2BWg68Qp+uegoPyJq8viCfxoIXbFpWPivzcKvfgWvvw7nn6/80B0do4tqlyMeHz3pRAtjMd40C/q9HR1nVvD1ZDMRjw1PNmt/+DkGLo5gNDRg9HSw4OhzGN2dyPo6jq
24kkRMpa9csAAqKs7csWrKCbU3UgRKW/iGoRoEHQ3l+++njpTqGnecse8Rf7bt5PFdOqeZYBAWLoT2duU60BN7jh+f2PtLhVOW6+JqwS+erBUKjd72dOGdbNaRqWa/s4LeQCOhA6+xLHSURS9/HzM1SF/lIvZ11dD781epzhxnzZqZEXuNd8KVJpMZ2Wsq59KxrMLv5Av+1EmnleBrC7+c6M/m8anZii/4Z4DVq9Xjc8+px9pa5e8dp5wnUDo0s5yFr9E+/5m4IdzO7uGEMlnHojKYZEVdD/MT+5BbVGOQCVfTnakiFhOsrO2iZs8L4+z19KHPa6lQSh1jrynl0tETg3QOndk2XjIXSSaV29J11dhJOQNHz2r2mThTFnwhxCIhxM+FELuEEK8LIf6gxDbXCiH6hRDb8sufT/Vz5xoXX6wE5PhxFXIIE8v4pwXHa7WXE3w9+3YmI3SMhrrhO3FRrJvGcD8ipSps6cYgbOZYUXGCBdEeRGziM49PB944ey/FSdlKbaNf9+bY8SddTZ1kshCJJkR5YU+lRrsxfcZmOix8G/i0lHItcDnwSSHE2hLb/UpKuTG//PU0fO6corpaWYFvvKEu5oqK0pN9yuGN1NFpe0uRzaobxDAKj2cSc9OlyKGEKtwi5XABF3PTpSMag2FhnOFyi+V+g4GBgojrsMtyPnwonfPI59RIJEZGspULv8zl/NDMyTJlOZBSnpBSvpJ/Pgi8ASyY6n7PRi67TD1u3apmjEIhfcB4eK2csQaz9A2gB2rPtF/cWtxC4PZbEbEYsqsHEYsRuP1WrMUtYzYGM0UqVdpV09s7evp+qbEQ/TvoRtxnerCsQu+p3Ixaf7bt5JnWKB0hxBLgQmBziZevEEJsB44D/0dK+XqZfXwU+ChAS8vZVZ3JNGHlSti3T7kSIhHl1ik1YOileGp5uTh+KQuWpo7umYmZiLqAS6n13H7riJBNq6ia10xQqpi54yiLPp0uCH45H75+3R+wnT4Mo+AqK+fDLzUD12dspk3whRBx4PvAp6SUxQFTrwCLpZRDQoi3AT+kzIRTKeV9wH0Al1xyyQSGNU+NqRYmOVVaWuDAAdiyBa69Von/wMDY6YtjsZH+eMcp76rxVm3S751NlGsMZgItJOUs83BYCYp25ZSy8L2VsGbbuZ6reMdPLKu8geNb+JNnWjy8QogASuy/LaX8QfHrUsoBKeVQ/vljQEAIUT8dn30qTEdhkqmwSYXoc/Cgsg7HC9EsleOlXPUl3TBoy+dM+/DnEjr+vvgcJZNKbIpdOaV8+OFwwa3jD9hOHcdRDbFujE2zfDSblH6a5MkyHVE6Avg68IaU8otltmnKb4cQ4rL8557x0Az7yFEy3/s+mX/7Gu7xE7jZHL3ZCojGEPEYzuYtZ+Q4olGoq1MpBBbkRzvGquDjFXxduHy82Yd+jPL4lBs/8Q7kClFw5ZQ659Ho3MhEOlfQSdP0HAjDGDsSZ7JpSs51psOl8ybgbmCnEGJbft2fAi0AUsqvAe8CPi6EsIEUcJeUE4lCnzqD7/+t0SEULYsQUpLbd4jORYvojjUTDySZ13HwTBwSABdcAD//OWzbpnz4x46pfDKl0JalN3tgOcEfHFSNiZ/XZXykLO1395Y61KkT9PNiIhH1+/junOlBC342O/5sdF203mfiTFnwpZTPAmN2ZqWU/wb821Q/a7KUEntbBMi29RNfWo8VECzveYne2svoHbAYCq+jqUtVNzrdGAasWwevvaYGVoeG1KGONSs2kVDiUjwhyEsqpbbJZguTgXxGo82NUgO2oMJo29tHWviliEbVb+YP2E4PWvC9mTLHcun4PdnJcXZ7eEsERyetSo5Hl7PPXU4PdchEknq7neXOHiovXEZXl8rYeCYSMzU2KsHYs0dd2N5UwaVIJEZ2dUuRzRYic/xc4eXRlmGxmGv3gR5ELxb8YovTstRv4Z/r6SGTKfjtS6VF9uK6qjfmT76aOOdM8rQvnvdfAD
Qm9nDX0S/QceG76Wm36DFrqTYWsOD2FbQsacJ1lW/9xAm1LFp0ervrF10Ezz6rLJXKSjVYpaM+ikkk1M3guuWLcTiOOl5dMcqnNOUmXHlLHepcLmP1uvTgoj84Pn2EQgXf/ETSiGQyfkqLiXLOXKbndz4FwMnYar583v18r+dGglGThrtuIL3pOg6kWzhwQN3kS5aoXOyGoXzru3efvnhfXa91cFC5Y8pZ+bFYoQvruuWtGp1/BEpHlfgohoZKN5reVMnaleDdrtjC963L6Uc3sIYxtuC7ruqp+aGZE+fsFnyPafbWrm/zR298kA+88QcE3CQOJg/Zd/IfjzSzebNKW5zLwf79SuBtW9VTXbZMvf/wYbW+3CSQqbB6tRLp3t7yBVJ0L6O3t5AzpxjdDR4r4senQKkJb95Sh7ruwFhROJblW5fTTTBYaEjHu98yGX/y1WQ461w6mRc2Yz/xNLKnF2P5Utzde0eYYfWhFH/6t2qmzbe+pWLhDx+Gf/1X5Up517vU5ocPq+110exUCo4cUZOmAgHVC5iucDwhVNTOq6+qmbfhsCoO4kWLkB6kKiX4OhWAP5A1NrpBLTfQqgdyyxWa0ei0yP6A7fSQyRR6r7qnOpYPX8/E9S38iXNWCX7mhc1kv/ivw//L9pMABP/o94YLkni5+271uHcvPPSQsib++7/VulWr4MIL1aSo48dV5M6aNcr10tamZshOZ73V2tpCZIh223j9wto9o633cpELjqNugHJRPD4FP31xg+0tdZjJjMyjUwrTVOMufg6d6UFH6Ohe7ljlDb2cjl732cpZ5dLRYp8lyP6Ki0gb4RHry7FqFfzZn8Gf/ik0Nal1e/fCgw/CD3+oBMIbvbNqlYqw0cW2T5yYWG778Vi/Xln3R4+Onn2rG5VUqlBvdSz8AdvyjDfhSp/r8Sz8UEgt/sSr6SGZLMyyzeUmdp37TI6z0g782sp/xraUDySWbuPCgedYd1KJ9FgEAvCxj6nnL70ETzyhLL0nnlDrli1TUTV6wG/FCiUePT2qIaivn1oMfzCoPmPnTtWLWLBgtOBkMhMvjehTGtsufX6KSx2WykrqnfkZCPjROdPJ0FCh0pVtq8Z0LENK58v3mThnpeC/+cB/8NNVvw9CkAgv4Nnwu3n2a8oSW79eWejLl48dxXLZZWoZGID771eCfvCgWiIRuOSSgn+xpUVZh11damluPnW/7pIlyq3T1qZ6DjqNsiadLj/dXA/agj/pajzKTbjyDuRO1ML3mR70zFot+mPl0dH4gj85zkrB3+C+wobdHwLgjdD5PL70DwETx1GpDLZtU9stWqQiZNasUekISlFZCZ/6lHr+6KMql30qpQqTC6Hq1erY+aYmJfxTieEXAtauhZMn1Sxcr+BHIsrqtKzy4Wp6QNK3PEujB/iKG8RSpQ7Hy9MSDvsDttNNIKAs/YkIuXb5pNPjp2HwUZyVgu/lvMzrnLf7I1Q89AA7d8KPflS4kY8dU8tPfqKEfflyOO88FSFTatDz1lvVcuQIfPe76kLT+4hElLtn4UIlAv39aj2o/U3GEqyqUq6d3buVP1+XBdCTrgKB0he31zryKY122xSfv1KlDku5dLyRI5WVY6e19pk8oVChcPlEjBbXLaQlGSsFho/irBL8ioceYPDO95dcD8qds369uki2bYMnnyxYygMDKizy1VeV2C9cqFw/5503usD14sXwx3+sbvwHH1Sx+6mUKlKurf4LL1TCnU4XJlMtX15+Fm0xa9eq9+3cWYgEsu1CYY5y2LafMrYUuv5BZ2sMs6oCO7JoOC+/feQoJ39+iGxPisyOfAWuxpbhxtWLtwGorPSjoaYL77iIjlDzzmIey4U5NKTuM1/wx+esu1y1uI+FYShr/KKLlGXw8ssqvYHu7tu2isM/fBiefloJfkuLipVfsqRgeVgWvO996vkrr8Djj6v3aqs/GlXvmT9fCcVkYvgtSzUaL76o6uCuXTt+11V3cf2okZ
Ho+gciHkNUNlHldJF7ZAfcfisAuUceJW2sJlwTRiY6yD3yKM7NtwGLSvTMCn6egfQQvYNBair8mVdTRd97Qqh70lvasNz1rgVfz7YtNsx8RnPWCX4pvJOxRG0N1s03Dsflh0Jw5ZVqGRxUwr1ly8h4995etWzfrrZvalICfN55hRhs3YD096tY/pMn1YX44ovq4p0/X40XVFRMPIa/pQV27VLL8uXKktG+5omkVvBROJu3IOIxZDSGGBTUVLqIVKH+gYjHEE6YmmAfIqAGXZytrwCLcGWGnYf6SKZzqpflNqBvm0DYZs+xAVYvqvdFf4okkwVjRodk6t53uWtd14ZIJPzZthPlrBf8zAubyT3wXYiEobpKVbd64LsAoyZjVVTAm98M11wDnZ1q0HTbtpG55TMZ5cM/cgSeekq5bVpalDWu/fe/8ztq26eeUuGdjgOtrWqprFRC39KiLvLq6kLsfykuv1ztZ+fOguUz+gZQVmcul8V1oe14H2vWNE/pvJ1NuJ3diHoVflMVSGAZLjIaxe1UNXhEfS2xVIoKqzCiKzt6sIMOJwf6yHYP4riSTM5Gegq1OWaS7oEUm984Rl1lFFeCIUCixCgaDrCwvtJvDCZAIqGEXtcR9rp0vHmNitF1nEskxvUpwVkv+PYTT0MkjIzFOMJSorEk9RzGfuLpkrNvQV1g8+bBddepBuDYMeWn37lT+fo1jlOIw9++XfkQGxvh/PNVD+Cmm9Ry5Aj8z/8oC31gAF5/XQ3INjSoAd2eHvW+UpFC1dXguL386JFOOrpOUFmxEstyQEYJWFV4SxGYJjhIHn9qM83zN7B2zdLROzwHMRrqkIkERizGvEj+B0wmMRrUCZeJBPNjnq5WMkmmsorB3gyhXAorlcEQgqztIjytbSI3gCUlEoHjJqmMBhlIKrO0vipCNuew51iX3wMYh127D/HIj4/Q25MhlUlgcj65XI7K+CIsK4oyaAKA7roqAyeVymCakmDQBGaHA3/wy1+BF7cUpmlffikVf/DJmT6sYc56wZc9vVCtYudaVREu9kRXQzpN8KdKUJubleCWcoWYpvK5L1kCV12l4vB1Bk3vRB2d0+PwYSXwTz+tJvcsXqxi9j/9adVFffhh5dJxHBVv396uQjcbG9WA8tKlI0P9du0+xJO/+B4BcQXxWDPSdRgazBIMBghYI/1B6YyNlJKt2w7y5Ace5JMfvZNEMkXb8U4WzG/glpuuPCcbAXPTpeQeeVT9E41CMokcSmBddy3AqNcyvQMcWXc+TpeLNGxytkTZ7SCFFnwJhsR21fOhVJahdJaAaSCl5FhHjnDQJBiw2N/WzaVrFp7Bbzx72bX7EN/8zqO8sm0PAHW1VRxtO0ld9QVk0g6ZrMvq5StJpweoiLcgpSSdcQhYAbTQu24OYVjksjmkyDJ0PME3v/M66ey+Gb3OB7/8FXj2BWWDCQMcG559gUGYNaJ/1gu+qK0Ztu6u5pcAJJMu7bEV9AWVaPf1qYFRTXW18rnPmzeyEQiHleW+dq1ytRw+rCZI7d+vfPfa+JNSiXtPj/L979ihfP8NDUrU3/525Zf/yU9UI5FIqIbkyBH12eefrxqJXz77LJ/723vp7u7nvJXVnLdqGTlHYhgWwhj900nXRRgCywzT0zfI5//xP6muqiASDmGYBo8//QKf+sRd3HbLVafpbM9OrMUtcPutOJu34HZ2YzTUYV13LdbiFnoHU5y86HIC27YR6WjDralh19Lz6RQRVWjGGuk/c5HDz7zI/J+sXVifztkIASd7E/QOps55K3/X7kN84Uvf5PCRE8RiEVKpDC9seQ3TNAgYQziuSTAQQwiTXC6JwEDiYpQa6JISCWQzWVKZJK+/vI1Y/CSZbI4vfOmbNDfWkc3aZ6wBcF1IvLSL9uh6bCvM6uSrQF70X9wCf3BaP37CnPWCb918I7kHvqtuyEgEUikiqTSr7lhO6IrCdslkYXarbgR27Sq8XtwIVFXBhg1KwD
s7Ve6btjYl3MXi7zhq/0ePqt7BU08p63/NGjWQ+6tfqc91HOjuhmeegWefS3HoiCSTDuO4vew//EtWLLmOUCgOCKTromTGezMIFbkgtSUkyWUj1FQtwxAGoUCEb377NULBFSxd3IRhqO8y2ce5OMHFWtwyHIap6R1MsedYF4GGRgK33MLR/iSd/Ulu+Mo9ZI0oP2r5I27a+i+EnQGe+sRnuemr93Ci5XMMxFYQTR3npq/ew1Of+GzZz7QdSGVsKmNhWrsG5pTg64yVjlN49D4fa5332veWKfzx4yfAXc3ihauwTJe2E8cJWIO4rsux468RCVdTW72YTHaQPYd+xuKWK5GuVNayBzW4K5HSwXVVr9Z2HF574xB79h8lFAoyv6mBW268nP6BIe69/2E+9uF3jCn6jjMyj49tq7GEZFI96qW/X7lmU6lC1NzwOVv0J3nrHlYfelWtFMasqrR+1gu+9tOPiNK549dH+e+jUVi5Ui2aiTYCDQ3KJbNundq+q0u5ag4dUheI/r31DZDLFSJ/du5U1v/8+WrAqrVV9Q5y2SALm6/g7W+9kO6+I2x++d852bWb+fPWEQhEyWQHiYSrMIYtfYkjbQYH2mk9/ornmwlc2yASq0G6JsgwL209RmW8acQEF+8jFCa/6KV43aliGKMbkIk2NtOdO6W1a4CAZRIMqG7cQCrLDV+5BxOQQq2LOAkM4KavqvWOq8Rn+eB2zPz6MUXfhXDAJJk+tZSOWjDLiaued+EVK+9SLNzFz71L8XV6qnivE/2bCQGDAzHCgWYCAZP+waOk0i4NdaswhEVH124y2SG6+44QObGdoUQ7jpNBCgvDMFE+fIXjZpHSQroOWTuFlDbtHa8jpSSbtZES2k/20NNjUV21HGSOR358BNylw2KuH/XcFtdVA7+pVCHMU9fV1ehwbH3NRqPK8Kuvz6dT+cv/S9QeANMjq9Id+f8MM3uO5DQSumJT2QHasSjVCCQSyppvby80Al6qq1XSs5UrVSjl4KAK0Tx6VIl/Lle4iHT0gb7IVAROllS6HyFCCMMiYIWZV7uKm6/9Cw4c/iXtXW8wr24Vqcwgla49LPgSyGYTdPXup6ZyEauW3UDAClMRm0dlRfOIhqHzpMGjj3ZRVRUnFCoMdp2qmOrv430sjqzQIqJf8zYi5Sjeprjh8T4v1RjpG9T7qAXItjN09QdxcJHSBQSZdA0n5v0+rhEkbUZIGBWYeb+xiTLeEkKdr9psG3vil5IzQ/S0VuNKAxyBlALXEUgE0jXAFXQYJpFQkMM7xi7MPV2Nmff6msiiG1Vdn7e4MS7+33ucY/3Wev/6tXhFlvbOHnKuQTpjUFWxABA4Ti5/jZokU720d71BJjPE4GA7kXA1oWBlYbwWkNLFJUcy3Uci2Uki2UM2lwQE8+pWEQ5XUVlZQV9fA/FIIwFD0t+X5vDhgtWeyxWEvtioCAQKeZIqKlR+pXnzlGE31oz5wU3nKx++YyvLXrrqxrz80un5YaeBaRF8IcTNwJdRP8t/SCnvKXo9BHwTuBjoBt4jpTw8HZ99ponFRjcCQ0OqF1CuEaisVO4fUFZ9W5vaRuf+hnyEjZMhnc1imTFc10W6Do7MIiVYgTDLFl9DOjuIABwnSyrTR4XVBEhsO0Vn1z6klFxxye8QDVdjGCZG2fnpgv5+gWk6jLibzgkc1KVfOWr1ybqx053O69xNR3w+2+fdPrxusLVMIiaNADtkkc3P7fBavlpAtdiUcp/p/4u381rP4zUW2or39hi8lr1evA2zds3o95V6rdR23hKc3v+Fu5H5DatwXJeGKheJQCDIORlc18a2U0gpOdHxOp3de9l/5Fma562lvnY5phVEILGdHD29h6msbOJEx+v09B3CNCwq481UVjSzoHED0VgtoWAEgzCOXYPt5oiEq4YLAwWDajwuElGu1fp6JebV1VMrC1rxB59kEDxROtasi9IRcoqJ3IUQJrAXuBFoBbYAvyGl3OXZ5h
PABVLK3xFC3AW8Q0r5nvH2fckll8itW7dO6fhmAikLPYGOjkKMsJ4GrieUSKmsfi3+2VwWJd4ZstkUpmlhiADCUHezEAbSdUllB2g/uRMpYVnLVeTsNPsP/ZymeWuprlyIYQUwRQAh8g7FcTDPuZlaXp+qBGwMO4uJJOjYCBwM6dAyuJ0bTz4wvFWpMynze/OeQQdGuHkCpqA6HuailfPL+vFL3YbaZZPNFtwQOuY8m1WLLtTidevo17XbQr/uFXbbLgi0Fn+voI91XFPDmaBLW+I4GVKpAU507KKhYQWhYAXSdTh45FcsaL6QvoFWAAJWCNMMks0lcF0nP4blYIgMS5fUkEid4M53XsaFGxYRDM7NMajJIIR4WUp5ScnXpkHwrwD+Ukr51vz/fwIgpfw7zzZP5rd5QQhhAe1Agxznw+eq4JdCFyBvb1eNwOCgWvTgUCYD+w8exzQimEYAVzpI18F2s2TSgwSDcWUp2jnS2X76+4+y7+DPsQJxVi69nvmNawiHqxHDsjTxq1rp/fSKvu5YeN0G5RbtKrCswmMgoKwt/b92N5RLGjeZ561d/QQsA+fwUSqOHyM00I8UIIMGl7f9qKyww8iz6hV7vT6JxYnAAo4EVvDaBTfjpMNgW/mtLFTNoXPCkzoGExV9FYLZ03eYI20vETAj1NYuxTKDVMTmkbPTZHNJhhIdZLKDVFUsxDBMunr209Wzh2jU4Lo3b+DqN63jvNULCQbVNXW25z8aS/Cn46svAI55/m8Fih3mw9tIKW0hRD9QB3SVONiPAh8FaNFpIs8ChFD+wIqKgjtISiX6uq7uth3HgRDCMAlYMaR0CQbCZO00XT0HyDkpjh7fSiLRTThcRVPTegaHOpS4SsOTaGd6TJhiF0MgoBbLUjeO9nVaVmkhLrYUS/nex2oIyrkxvC6OcrMwxzIlKjJgHTlE7MB2XCtAwB1SYfU5B0cIrKI3S6Cvpo6K3u5RlrwW+2+Ff4eupVeMeB9+beGSTFTspZS4EiriTYRDNQwl2hkcPEEsUkfSCFIZbyQSrqG2ajG2k6WzZx/pdB/1tctZvHA9dXUut928GsNQhtZE0I2CNji8jcSp9gx04j4dEmxuunRUxNiZYta1dVLK+4D7QFn4M3w4pxUhlH9/40a1PPA/P2DP3gFqqhbjIomGq6mINeJK5bc83LqZpYuupLZ6GQubNlBbs5ShoZNUxpvo7j9GIBUkGqklGIgTsIIMx4iNQ7nardr3qrNvTrVYtFe8tTWv/al6kCwaLWQ9PF1pnk07Su5QN32hJoTpckX7U5iuqxoPK4DMZREUrHoHeOE3fpubvnpP2X3enf4a/7OnhxhHqaOXWtJsvuN9w72XQACqK0JUREJc4pmEVcoHX6rKVjmxGSuFsHbXFIdRlltyuZE1Zb2uoLEysNp2wdWUyahB0UxG9Vx1pIvXr6/Gq6Bwhoe/uT7y/BjVAMl0LwJJy/yLyOZS9A8dp7ZmGYZhYAZC4IJpBrGsMAubNuA4ajDXkR3Mm7dquKbEZZcV/PPe71pumQy6cShuJAIBcI56EvfV16rULo88CrffOiOiPx2C3wYs8vy/ML+u1DateZdOFWrw1sfDimXz2bZjDz19h0esNwyL6spFXHHhb+FKSWVFE9FoLeFgnGC18glHgpUcaXuR3v6jNDeuoyI6j1isnki4BsMwEBgIYSCEyEfsFBqDM+XC9wrQdBWe1q4e7fbRg3HRqBqEq6pSA+2G4RWvAEPZXrLBKKabIRuPE+3vRyIw7ewop5gOv/S6bvR6L+9yv1f4rsDeBpeKrg6aD+whOjRIrrqawKZLicUWlhVdrzBrP7v+fzy8oZql/PYTaUD1pEE9PpBIqMU7VuD1908N1awWPLsOEnAdm0SqF9fN8eQv/opfu+ELxKL1xBAEAiFsO83J7j3Mq19FJFSpwjaFiURgWWEq403APLo7g2RSylBpa1ORNldeWahDPJl0yq5bvpHQYyelsJ9tRR
qrEU6Y5ZzAyFdEcjZvmbOCvwVYKYRYihL2u4D3Fm3zCPBB4AXgXcDPxvPfn5sI6uqqGBrKErCq6ew+DEB9zXJqa5ZTXd2CaQaJhCoJBCIIYeUHXCXRWC31dStpO7mdY8dfBlQkj2kWwg4Gh06SSHbSWL+AmpoFWKKKeEUj119zLZFI5XComp4roAcMiwXDe7NPz41/6mjL01vXYELE3j48Evtc821QnMBOSsKZDj5+6DMICq6bNuCnzZ/mAyf+aUQvQCI4WLFx+O0uYB90CB3uodOajxsJEEw61Dzbxm53IUZDw6hD0g2iV6i9ceJSlndveZdgsGD5ZzKFHE59fYVJQ5nMSPGebopDP70hnYGAycmODlJpFaNsOzYSm2wuhXRssrkk8VgjjptlYfNFRKM1+Vh8QTRcg+u6dHXvI5Hsoio+n5aFl2IJY3hGrspuZAzPdu/vVyJ/4oQKj25uhquvVobBRDGMyZe0lBIGfrKbXE0DLpmCseBJ3HemmbLg533yvws8ibov7pdSvi6E+Gtgq5TyEeDrwLeEEPuBHlSj4FNENpvjzW+6iDd2HyeXrae3/xi27RAIRKitWkQ4VImdyyCEiSG0falkJ2CFiIYqOW/FLby+5xEyuQRVFfMJWGFydppYpIZYtA6kA0aOA4dfJWcn+NB738ZHPjIyPFF3y/v61CzioaHRscsaPXlFC5Oe7FMqNl8/el/XlnkwqG7AQKDwWnGYnxZ3r/h58TY++pi0FTbKwjVQvppyYicE6fDoqvc7qt9Jd/UF5E4YBPLpFVQUvySe61VvlTbdS5fR0rGToJVEhgIkjXqi8Sos4eDu2VdS8IUonI9SOI4S654e9bv094+0vk8XpSbgeWP3LUsJYSSirOZAYGT8vvf3TOUt7sNH+7GsCOl0glS6g0QqRSrVx2DiJB1de7nq0k9gmAGqKhZgGQF030qlFRFEo7UIYbD30E94eecDbDz/LhbNv5BwuMoTuKBGWfRMd51j/+RJlUplwQJ461sLKc5Px3kLz6silOhFeGudehL3nWmmxYcvpXwMeKxo3Z97nqeBO6fjs85mFsxvoH9giDdfvQ47N5+NG8N867uPYTtZotF6ouFqBu2TWGYgb/FoVDxzOFxDlTCIxxvpOPIMg0PtLF10JY31q1SGTaOfgT5BVeViGjeuoK62isSQw1M/aadloTJxvZZMVZXqBususO7WJhLKYuztVY2C9tV6Z37q/71Wqe45eMMJdXc4lVICpo/Bi/b1e6N5vPHf3u20MGlLzDvpSgtqNArxuIn19CPEj+8llusmnusm5iaH95U0IvzHqi/TY9ZT63QNS8jS5Ovs5u28WvVmLuv/udp3/rXG9OHh99fubqV9zXqyFZVEwyaxUAWGaWEaYA10Eomrc9faWpjJnUpNn6urFMU9AS3MxfMCvAPi3vfqsQjtny6OrCq1mKb6Xn196jeXUp3/oeQRhDBIJAfoGzhKOFTPya79nOzaTSYzgMQlHq5hYeMGhCe1ghDqWg8F42RzSRKpbgYTJ/nVS1/GMAJcuuFdLF18NbHwPLyzc0FdL9ms+i7aXXXokGqobrpJBVNMd8jmeIn7zjSzbtD2XOaWm67k3vsfBiASlAQDMSriEWw7gyuVeSuEmffFj/QgqygYgZQu8WgDlhXGttP09L1CVVWSyy++lp2vpwjHDnH55bXkshGcXA2ZrGDzSwOkEk3E45M/ZtNUA886gkeIQu4gPedA14b1ConO668FOBpVIp1IqMilgYGC+8HbkBTjdRnE48pvX1urngeD6v26h6LDX5NJ6N+2F5INUF2wtPdXXAbZLH944LcJukp5j8XO44Gmu6jOdvGBw39Bk6PCPV6rfjOX5gVf881Fn2UgtICcGQYRAEtACrUU8/XJn2v9fXXYqjdaSg8UeoXca2F7UydoIdZjH1rkdC8rGlWPxTNvS7mQSrmUQP1+x48XxF4IZVGHQsqIeGVHkt37WgkF4jiOw1DiOG
0nXsFxbQJWhIAVQQiTULiKQpMKqi8lSaX7MM0gyWTBNVJdGSFekSAU3kEoVklisInK+KpRufK9hoIes/j+99V5W7xY+fl1lbqpMlbivpnAF/xZxNo1S/nYh9/B4089z8n2ARoaavnUJ36Db/3380g3L/jDA64jEXmHspQu6Uw/rmsTDFpU11QSCKZIpvYylEhz/uorcHIhDCOFCHZiBmy6ezLMm7dq+IatrVUzD7Uw6EFd72CentxTLnojHlc3thdvRaNkUom79lXreQmgPlcfQywGNTVKyEMhhisc6fkMfX2F3Cc6P5GuIVwK000Ty/QSP3mY+nQbi9K7qU+rGIMvnVcI5dApFTqjS3HMON2RKABBqRqC/nDziAlXT9XdRnf8vPIfrM4AytsPwZBJLKYay+pq9aijlYLBQi9LC/pYgmvbhfOZSIztk9eCHoup59Np0SaTqseSTCq3k5Tqt1uwQOWaqq4u/G5vefNl9PQOMjiUxjQznOjYRThiYZpBcjk1kCuEKJEpU4B0yeYSVIYjXHThclYtX0gwGKCqUlss/fT391NZ7fLrt62iu1vVq+juHin22lWoz2Uup1KXHzigalM0NqqstfPnT+28lErcN1P4gj/LWLtmKWvXLGX3bnVjtrRAQ/0inv5JGgSoSbej71IJICQBK4wrM9TXxjn/vGW0d/RQVRWnqirGeWtyBAK9uHIRyACm1cXgUDdNTXGuuEIJR1ubEs/jx9VNsHhxwaesrfHxKHbd6Eddhi4QUO4ib95/UGKeTiuLfGhIvae3Vx2Lt/HRVqSuHlZbq25Q7cbRnz84qATo5Mn84N3BDtJDWfok9MXW0hpbyzZu5DcPjE5+JvLO/d5gI8pLrz7U1GmRAyGSb7qKyueeBcAI1FKdPTG8Lwm49/87FT2dU47B1lWdEgnl9komx94+FiuIeih0emeWZjLq9xkcVOfYcdTvqpMK1taqzx8YgD0qBT41NQBN3PHr17Nt5/OcaB/kumvfyuqVi/m///QNaqoqCAaU+BdnytS92Ia6GpYtXsJ736vmF+mecUU8yuBQkv7BBHfdeSWrV6tkhnV1Stz371e1LLwhxtrVA4Ve0sCAet/Ro8rHX1UFl16qGq+5jC/4sxQhCjf2zTdeSiysiq1XVgTy60fO+TSEgWUKCFo0NlSxbOlq1q1dNiIX+K7dh/I3RpJIeBmDA3GSKcF77rgIUKK6eLHaY19fIeMnqBumvn5i4iFEwcUwETeRvuG8jUMmM9qfnc0WGoNEQllsbZ4A4OKBxJoaJTpLlkDgFz/G2fGQajFNk3TWJOzx2UtQya7yAiPzS8KqxpRZnHzSNI83mQ7mETAiRNwUpnRG7MsFlUahIi/uedEnX0e3WPR1Ej1tpY9Vo1UIdV61sE8l/8upkMupcYf+ftUg53LqeBob1TXivU6yWTWxEFQPpr5eNcIAb3lzEzde/84R+372he3s3HUAwzQwTYNwMDi8L9XTtAgELD7x23eRycDaNeo13TPWxX7uetdNw9d9Q4NaHEf1MlatUvfWli2qt+G1+rXw6/EKy1INmW4ALEsZGpddpkqazjV8wZ+lRKOFQuraRwtgWSECAXDd/PR0E3CUP8cFIuEg77/rDm6/bXQUiNdl1Hb8deY3LeTqN12FKZoYGFAXsqa6Wi22rW7Q7m61mKbqdUwmPG08DEOJwUTjonX0jde1lEqpmzKRUA2CdhEdParek3ulGRZ8HITBtd3foyMwn5bELoKuUlYBkO6ESCObw5dwSXorhENknTAhJ0XSCqttIpFhP46Izse57U7EI9/EyPv89eyGqodUDh77SGHijVtXR2LAIfGDF8heHsKpGR0FpLEsJeha1Gc6HYDjFHpKvb3qnEciSsB1NknvYLvrKmNBN9rLlqn3traqa3nZstKf84H3vo1773+YaCyEiYkwQ+CIEek/9DiRF90zHgvTVD2P+fPV8dfWqoa1rQ1eeWVkz0m7GtNpdf/pOhhCqO
/xxBPq+1ZVKfFftKj8584mfMGfpRQLfiAfcFDwz5r5TJf54DNTUB2poq4uQNUYYWbFN4aUqijL8eNqWbVq5I1rWcpChqlZ/dOJtuAn4146/sT36QouZNCqJewkIBqky17E/NS+4Z1eVnmYl9z5PL/iU9z4FxD6exB1jVRVQvI4VDz0gLIA/06d9wE7SpPoJ3jnOwgea0GkKqjIC73G2bwFEY9xwF2GHMr3HoJpxOsHiF7XOELUx5o1OxO4rgr/7OoqRGOFQkooa2rKlwXt6FCiCMoFEo8X3DmNjdqlUxptlDz0PZVl2DQNMPVMCIVKbz217xYMFq7rlhYl2NksvPqqur694yC5nGoUhFDXvB5fSibVd33ySbVdZSVcccXsFn9f8GcpXjFTk1XU85GhiOom0DeddmOEw+XTJRQjhLrg02mVz2fv3vI3pbb69Q2grX7LUhf5dFr904V2L9WFEtQlXx0OGYo6gyRDVeAGiXz+L7EWt7C+FV7yRM/oFAALFqjGsDgMtGcogKhwVQnNXBqio2fyuJ3diPpaltKBgczn/pHIrh4iy648zd9+8kipxLq9XblsUin1+9bVKXdfU1P53sbQUMFdoxuEdLog9suWTcz9tHbNUloWKUMkHA6MSumhI48mM3FqLKJRZeiAuo5PnFDnYPPmkRP5pFSNX1eXuh7mz1flTpcsUe7WkycL4h+PK/HXLtLZgi/4sxTt3tDCrW+ysSIwhFCCHA6rG3YyA0zhsCq5ePy4unBPnoQVK0rf3IFAwTrq7VXbaqu/vl6Jw2xLQWvdciP29x4enlBQnzzM0ch5cP11w/706uqR7wmHlYitWVPw99bWQt6PRo9bjejtQTTFCK5dhTgSHI760BgNdchEAnOWTLwpR3+/asT1oLBhqN9x4UIl8mMJdS6nIltAXS/Llqn3a3EUQgnqZK6JYLAw87gY7VrxntLpQgcTSKmOubNTVbnbtWtkgz80pIyjgwcLPYRbblHn74UX1P339NNq20hEiX85N9aZxBf8WYq+OVIpZS2Y5vhpDPRgXl1dIcRxssyfr0T74EEV0VBXp7qv5aipUYu2+vVNfiat/olkI4zcfhspwH78aUgkCFaGMa+4ktTNl6E9YMXHGg6r86h7O2+8AW96k37VJBFtJvKh9xGqBONXwBHlavCK42ybeONlaEhZ0XrOA6gGrblZLeP9dlLCkSOFAealS9V7pFRi6LrjXz/lGCtvvW5QT9cMWVCfrQegV61Sgt3RoepPa3cVqN/74EG17NunxP+mm9TxDw7C88+rc/yzn6klGFRx/itWqPcnvvEt3J89o/xJwSDGddcQ+9Ddp+17+YI/y0kmC4IP4+c+iUTUzbpvn7oRJ5MgShMMKqu2o6Pgtlm+vOBWKsVMWf3eQdHxshFGbr8Nbr+t8P8+dfPq+QLe3kwqVbAy9VjKkSNewVfnt/g75e/bYSY68eZMpdDVicR6ewvuiupq5XpYsGDibhLdsEO+nms+xNYblbN48am7XfS5L3Wt6+vQG2RwOvEO9q5dq3rB+/Ypl4+3B6LHwXbvVo3fDTeo1A2g7uPnn1du01/8Qi10HOei/R2stHPq4rNzuI8/TQJOm+j7gj/L0ZEDExnQ0/5q3TgcOzayFONkmTdPWXz796sue1WVurnHw2v1t7aOtPpbWqY3jFAPiopYjIxjEZpENsKmppFhnV7xfuihQmPZ2akee3tHvl+nhPBampnM6FDU8SbeTKbROhWy2cKYi05foecwLFgwOddIIqGuKxh9PehBfRg9+D9ZdO+iVIZPfX2fKcH3ogd7lyxRjf/Ro8rqP368sE1fnxr8ff11ZSjdfLM61htuUK9ns0r89+wbYkvDrWzh13jviX/In7Ccsvh9wT/3MIxCd9l784xl4evtGhoKQjUVLEtZ+z09yuLv7y903ccjEFDbQuH92vqbLqtfD4pKCUcTDdSHB6iOygllI9RC7Q1JDQbVDXnoEFxwgVrXnd9V8e
Ch6xZ6CPp7FE/jnwi60cpFKjGkxJqGFLq2XZjVqhuqWExZqYsWTWx+RPH+9u9Xzw1DCZk3KEAXCK+snPrMVFDHV+7a0D2xMz3/oJhoVN0ba9ao++L115XlrkNRs1nlBnzjDXUf3HabMqCCQbj2Wrj4a5/HtkK0RVcUdmqYk0/IPwl8wZ/FxGIFX7zOQzNWTnPTLHSha2uV4Pf1jR6MPBVqa5VFt2+fEsNYbHLhZ7W1aim2+gMBtZ9TvXn1oKiIxYgH0nSlK6nKdkxqULS9vSD4lZUFV4UWFl2U3rZHipDrFhoD/fucSvIz3WgdGZxHRSBFU6TvlFLouq4SeB1lAqqX0tysLPnimc0TQUr1e2m31pIlI92E3oZg4cLJNyTlGGs/3vKZs4WqKuWbv+IKdf08/bS6VzSHDsG//qsydO64Q/UuCQax7CyL054NXee0tmS+4M9iotGRg6/jXeChUKF7rrdtb58ewQfVoGhr5sQJ5atsaZlYPLxmLKvfOxV/wsfkGRRtikr2DdVyNF3J8usvmtD76+oKoq3/LxZ8bdkX96yEGGn1C3FqxpnRUEdiQAWWN4TzPpeiSJ5yPn4p1fG3thby1wSDKiRywQL125+qMOrfB0qH6nrDMFeunN5COhUV5ctXzrb5Cl6EUNfxe9+r/PsHD6rEbLrn19UF996rfqN3XPFOFvziv4GcsuxdBxwX46ZrTtvx+YI/i/EKqc5EOZY7pziVQUuL8jEWhwpOlaoqdUMeOKD2Hwopy2+ywqKt/mxWCUdnp1q8Vv94g5nFg6JN1VE6W96E2zyx0BAt+JmM+h51no6BDoe17UJ2Tyg81+kQYGQqgclibrqU1odeQ0TSGBUuMjEyksc+cpTMdx6EgUFkzsZpbaVrfzedl7+TPqN+OHS3vl6J/GQbzWJSKTVADep6WrBg9P6OH1euMJ3vabqJxcr3aMcKHphNmKZqCD/7WXVdPP00bN2qXstm4cGOW2DtDdx8/N9Z279ZXfCbNmJFo6S+ct9pGbz3BX8Wo3t2tl1qpu1oIpGRvnXdYHR05LuQ04hhqItZW3l79px6l9471d5r9budnVS9/Cx1VfaYg5neQdEI0L1bvX/Nmol9D1CfuWjRSLdHT08hi2IwqHzUOr+/FiJdGEZzSu7XBS1YmyLMb30B2dUzKpIn9+TTyJMdJCL1HI0up9+tIjcA4uVjNN1Yz4IF0zPj2XFGuiFKzcNwXRVyCSOjc6abaHRuWvjlCAbh1lvVMjAAX/mKvlYCPDH/Ezwx/xOsWtTNrY//ASMuoYceJvKP/3faRN8X/FmMNxZf5yYfi1KZEaNR5YOebsHXxOOwerWyCFtb1TGuWHHqN6XX6j/0wm56Ak30OmEWuV2EY+rLjTeYuXy56n30909MkGKxgo/au/2BA+r72bZqTHU+fW31g3K56cgXIU5t0PboUTAaGqi++vaSrzt7D0I4zLbg5UgMqsw+Gp2j1J98hsqNF07+A4uQspDxEsqHU3ot//HCdKfKWG7sQODUwo1nC5WV8Cd/op6/+GJhdu7ew9XsPe+/MLMD/P6B3x/O2Jr6P386KmXHqeIL/hwgmSy4dMaiVOTM/PlqUK04Pnw6EUK5dLQg7N2rGpjJjh1kXtiM/cTTyJ5eRG0N8yVYa9cw6MQIGnmFncBgZiCgBOvECXVzjXfe5s1Tg2quO9KNZtuF6lpVVYXYdW9jlkgUfP4w+UFbXYCjsXweNYSQSOBithAki4mLlBmEmHpdQ28oZUPDSJeWF50fxzRVg366B0zHut4nmoV1LnD55WqREv7uT7rJBetwgpV8acW9/NH+j0775/mCPwdIJtWA2XhWc6lBM90lb2srDJaeLiIR5UZpbVUi0t4+8cG8zAubyT3wXYiEoboKmUhAVzdOKEDVSk/Y2gTTErS0KDfT8ePjp5jQDWVXV6Gco55QY5pKxGtqVOhhV9fI75PNKgtfykIxkslw8qR6HCuhmLFyBfbrbxDRSZVyOW
QqjXn+eAVXChSPhdgXXsYxR4VZRaPKnVVKYKUs5MIZq0GYbnTtg1INaCBwemfZzgRCwO8d+CMAXmEdYXrHecepMQe9YecWlqXcBDo391iU6+bW1Z2aq+FUWbiw0Ljs2zcyCqYc9hNPQySMjMXIGUFV9LmyEnnoCDKRUAnHEgnkUAJz06Xj7k8IZbkPDk5MhE1TWbC6wpTGspTlr3tH7e0jfwfXLcyVGC9sthR9feOLV+CtN2A0N6k8+4kkEjCamwi89YYJfYae2CUTCWRdHfs7qzjwo524nZ0sX64ax1Jin8kUxH7p0jMn9ppyBo5lnX2C7+UiXmMtbeNveApMSfCFEP8ghNgthNghhHhYCFFdZrvDQoidQohtQoitU/nMcw3tYtB1SseinODX16vHU82vcyqEQsrar65WkTe7d48tvLKnFyIR+qlmM2+ijQWIqkqEZSFiMWRXDyIWIzCJ2acq0VkhTnws9BhHMDjaraOLsUOh+paXXK5QvrFUsq9y6NQG481etha3EPqNOwlcuAFr2VICF24g9Bt3Tvg8eGcjHxxqRoaiLKgeYvmJZ8r64Xt6CqkxVq+emUyo5QRfV0zzmTxTdek8DfyJlNIWQnwB+BPgj8ts+xYpZVeZ13zKEI0qYdA54MeiXN4Sbb21tU0scmU6aWpSluGBA0p4a2qgLj061FLU1iATCWpifdTTxUFWMJgLsqbRIPTuO07585csUa6YoaGx/b7aYkyn1SCunrjU21uopQvqtygeC7FtJfr5RJwTRpeRnMgA91TqouqJXQCL4x0EDQcpjbJjIQcOFNxYY40tnG7KGTg6hcjZRsVDDzB45/tLrp8upiT4UsqnPP++CLxraofjU4y2Ng2jtODr0DVd3q8cCxcq37q2RM8kgYBqaLq6oOONTk5ufo0lNRmCnlBL4+KNOE/+FAmsieykNTOPQ3IZg6uu5topfHY4rM5ba+vEGruurpG5ZbJZJfjaJZZKjexJeSNz9ADvRNDv0QnnTid6NjKxGEEj33KVGAvxzpqd7IS600Gp63k2za49HUynuJdiOn34HwYeL/OaBJ4SQrwshJj+oeezGG3J5HKlw+C0wAgxtqWordvpyK9zqtTXw5Ljv0JEwhyVS0g5IUQshojHMNIZAu+/S/nu+/pZFOli41vn4TS18NOfTt437kXH+OsB0nLoQtfF0UWWVfDTO87ohlfXodWFOSbC4cPq8XRaqvaRo2S+933sg0ewX38D50R72bGQgYGC2K9cOfNiD+UNHJ9TZ1wLXwjxE6BUFPefSSl/lN/mzwAb+HaZ3VwlpWwTQswDnhZC7JZSPlPm8z4KfBSg5XRM4ZujJJOlffT6BhBiYoO63nTAM4Ho6mJlvUvSCRMy8yEY+VDLyLvvIHTFpuFtY0B1SmUW/PnPVa6SU0m3axjKn9/Toxqdcj0hPeu2WOxUDWH1vFQlsUymfA+sFHqW7mQK1EwWbwZOc2kLbiSIe6wV0hnMpYtHTOw6dkw1WPH47CrMXaox9AV/aox7iUopxwwFEEJ8CLgNuF7K0vaNlLIt/9ghhHgYuAwoKfhSyvuA+wAuueSSCdpLZz/lBF+7c2D8Qd3589UM1HK9hTOBdi/EYp47d4xQy0hEZRb8xS+U8F900dghjOWYN08J/oEDhXJ2o44tfx4zmZGhmYahnusZtsU9qWxWib0u/jEeOpXu6Yo0sY8cJXP/N5EDA/RVLiZVX0NLczNGZSUyp0bOcz9+gmx9PYfnX43R0MCCBbMv8sUX/OlnqlE6NwOfAW6XUibLbBMTQlTo58BNwGtT+dxzkbEs/In48KFwA3lzd59pzE2XIocSkwq1NE24/nolqq+8UsjHPllaWpRgF6c59hKLKXH3NoiZjHqfPr/FqaqzWfX6RMMyh4YKEUTTjbbsc30J9ldcRJddjXXsAG5fH242h7PjNWQiQaq6iQNd1dibt7AkcHTWiT2U7s0ZxtyeZTvTTNWH/29ABcpNs00I8T
UAIcR8IcRj+W0agWeFENuBl4BHpZRPTPFzzymCQWWVlxN8vUxkglN19diCd7qxFrcQuP3WUwq1fPOb1eSfvXtVjdHJol01Oj1AKebNUw2L91z29SkhL9Ur8vYEtNiPZeXrOQmnUvZvIjibt3DUXMaL1W8Dx2Gx1UqD2YPbdgL38BFERZxOs5njqXqCEYtVdV2Il7ecnoOZIqWiqizr7JllOxNMNUpnRZn1x4G35Z8fBDZM5XPOdaJRZUWWc+noos4TCe+bN08JWD5oY0aYSojhBRcowd6/X4VMeksOToSVKwuTwUpNJCoVb64FXAu+d/anjsxx3YJvvtTArqazs3TOo+kgm4WfH1qCiEaor+lgxbHXEMEA0rKQAwOQsxHrzyfnmjSEB6gOJpBy8nn3zxSlhD0Umn2up7mEP9N2DqAt03IRHRP14Xu3PVW3yGxg8WLYuFFFzvz0p5Ob7GSaKr9OZ2d5S7x4ti2obbWIlxJ8UGGNUpafYKbLVU6mcMxEOXRIldozolEuCr/G+vndmKtXQjCIHBhCVFZiXnA+RjDAgmgv1cHE8EFNpljMmaSUsPuCPzV8wZ8DaF9mKetTp1eBiWeo1DM7JxpCOBupq1PVhUAN6OqwyYmgS/DpmaTFNDaWbly9IbIanXrBdQsNTznBP3q08J7pIptVjd7Bg8pNdN3tMeLpblUFrKoSc0kL1oqlhD78AQJvvXHS4yczSamatZHIzNSyPVvwBX8OoAW9lAXvDQecqODraekTyXEzm4lGlV8f4LnnRhcZH4v585VYlspfX1MzWvDT6cJ5TqdHhsOCEnnbHin8XnQjsHjxxI9xPA4eVFY9wGWXKXfXWGMkUxk/mQlKCXs4PDNpHs4W/GyZcxzvYO1kSsxZlppVqvPszFUsC667Dp55RkXwrF49sVjyykoVrVSqUEowOFpUhoYKYx65XCE6SveSdHy+Fv5itHV/KvMIislmC0I/bx6sXz/y9bHGSKYyfnKmKTXG5Iv91PAt/DlEKctRl+HzxuNPBO1Hnmw639mIEMrSr69X2R0nGsGzfLl61AVMNMGgchl5z2c2O9Ka1695J2SBagyKU/rq0M3pmPBWbNUXi/3ZxFhuNZ9Twxf8OUQpcfaGEE7GwteWki5+cTawYYMS8RMnlItnPLyFUrzjGcGgcnsV+9q14HsbXsdR67XIZzKj3US6EPhUYu+1r/7QIdVwXH/92T94WcqAmSv1bGcrvuDPEcp1Zb2lDydbVrCyUrkqziaWLFHCP9EIHp29o82Tftw0C4VQvJQKpdQhmK6rHh1ntOD39k4tdtxr1W/adHZb9V5KBRX4M22nhi/4c4RotLQF743Nn6zg69S3MzkR63RQX6/KxsH4ETxCqPMwNDSyBxUMjrbwvULude/oyVc6BYN3O53zXkcGTYZSVv25NOmoVI/Wn2U7NXzBnyOUE/zKysLg4WRcOlDY3huTrzMspr5yn8q0eOToqR/0DBKLjYzg6esrv63OzeMtlBIIlB64LW5UdV4iPenKtkc2oMePT3xSnJcDB85Nq95LqfKG51KDdzrwBX+OEI2W7s7W1BS6vpMVFVDWrR549JbCE55c9XNV9HUETyAAL7+scuKXQ5dk9Lq4iq3JoaFCI6nPuZSFpGm6MlYiP6dJW/qTqSWcySir/vBh9duca1a9F+9AueZcPRfThR+WOUfQQqPDATXNzYWQv1MR/JoaOPFaJye3vkxs268gYGIsWUKrbMAyXSrCXVS8uGXOhPIVIwRccw1s26YieAYGYO3a0duFQqMLpRQLfql8OlIWIkd0AjVt4euJXRONLDlwoJAnf9Oms0/ccrkcra2tpCc4S85x4MYbR65LpeCNN07Dwc1BwuEwCxcuJDCJkWxf8Oc4TU0ji6BMFvvIUdyXdtIRjrAU5Rty9+5DLmkiGakmIeZzom2Q4G61fTSqIlgqKk6tgZkpNm5Ug5+HDin3zhVXjD5fy5apxGw6cknnvPE2sMUWvm
4EdKZM0yxk15RyYr77dLoQVdTYCOvWTeWbzl5aW1upqKhgyZIliAlcrNmsmivipalpbl13pwspJd3d3bS2trJ0El1IX/DnGF4BEkKJ71RSJDibt7CgKkcbNchYDJHNIIIBFna9gnX+echEAmd+BekGFa+eTKrlxInCPsJhlYWzomLy4whnkmXL1DHu2KEKqlx77UjxMIxCERRtueuB2GK8aZC1duVyqpGw7UIK6vHSAOzfX8jeeTZa9V7S6fSExR4K5Ti917cv9gohBHV1dXROsoSdL/hzDO+9Yhili270DqZo7Rogmc4RDQdYWF9JTUXpKZ5uZzeR+lrm2X2Y85tw9+1X2RV1zpWhBKHrriVWNzK7pOMo90hfn7JQ29tHxvSHQqoxqqyc3twxU6WhQUXwvPiiEv03vWmk66ahQQl+V1dpwdfPvT58Pektm1WDxdksDA6OHXfvteqbmuD886f3e85WJir2MFrw/ZDMkUzmXGpm0a3oMx62zOBKE1A/tGGAECNN6t7BFHuOdRGwTCIhi2zOYc+xLlYvqi8p+roCVVVMQE01YtUK3ENHEIZAxGIjSuF5MU3l//dWn3IcJXT9/crX2tFRmHQEyvVRXa0agZmcQBOLKb/+M88o0S2uotXSoqpj9fb1ksuFcJwApqlMS9se3YUZduVksxw52kkuJzl07Ai33zafxsbR3e1zyaqfClrwVUlskEiSGYegZWKZvql/KviCP0foHUzR3t+HEA1og94M2PQOZoGCkLd2DRCwTIIBJUz6sbVroKTgm5suJffIo+qfaBQRCGDMbz6lpFqmqQTdWwTcdVV0S1+fcgV1do4spG5ZqidQVXVmp80HAiqC59lnR+fgiUahs7uLrS8fIRRYimkaOI4W/CxCGEChW7V7bzvIGjLpHP2JHsKBGnYf2MdPn/lvPvWJu7jtlquAc9eqnwqG4eI4ysARQCZrk8namKaBKQSulEgpEUJgmobfGIyDL/hzhNauAQzLQQiJtvCDEZtX9rWTSC/Ediy27D7Byb4k4aBJZTREJKTM6IBl0DuYZuehk6PcPNbiFrj9VpzNW3A7uzEa6spa9aeCYSiL3uvLllI1Av396rG7e2TmTiGUxV1VdXqTZQkBV1+tBH/PHtU7Oe889doLWx6ls3s+jfWLEUJ6oqRMlK0J5C3QTMYmlewjFIyTyzgETIe+gS7S2Sz//NXvsmzpAoLW0mGr/vLLZ674zGzFdlyytoPjuEgpMW0bN+0iHIthg0YUzrttu9gej4bI/yauKwkHLV/0y+AL/hzhZM8QfUNppKhHT5+wogN09qfI5lxc1+XAiT4AhlLQPZAibBnEo2FMQ5DK2sTCgZJunjOdQVEINXjqzQWj49f7+5Xw9vSoxUt1tWoEpiPjpJeLLiqERPb1KUF+5vmXqKm4jIa6C3CkjYm3+yERCFzpAILevhSm4RIMVmCYIZKpPrq62kikummsb+Lhh4e4/DLfqveiBd51JQJwpcxHOkmEncNMDGGLCAJdNxIM4Xg7ViOey/w+LdMgaztjCn5rayvPPfcc73nPe07HVzvjnzMZfMGfA/QOpugZTOFKMAwHF2W5R2vUDJ9SUTpSQirnkupXZZYsQ5AKB3GDFgPJDOmszSv7jnPRyvllB3TPJEIoX7bXny2lGgvo71dLX9/oGbPaHRSJTG1Qb/ly1QDt3KnSMZxoPwnuMVLpfgJWBCUpIn9cMj9gpt0JJo7M4UobpEtXzwE6e/axsPki5jddQGdPr2/Ve7Adl3TWBiSuK3HcwgUsJYRSKVxh0J/JcXJwgKFcH+FAgOa4SSAyxkmU4Dgu410GP/3pT9m1a9ekhNhxHMxJhqCdyuecbnzBn8XoaJv9bT3oe0KYhXjAcDwz4X05ruRE7xBB0yASDmAI6E9kee61ozTWxFixoG5WCL8XIZQ/PRotVOmCQiPQ11doDLxUVKhGIBabXCMwb54aRP3xY72sX/MuevqO0N6+k9CiGJYZxD
D0SLNAOxEEYAgL20njOlkcN4frulx8wd2ApKNrP4PJZ4nFrpnSuTibyNoOUBB6OfxHYbgOvVmXIz1pLEMQtKLkHIdjPUMsilpURsv7+aRUDcpQKothiFE+/WeffZY/+qM/orq6mieffJL3vve9/OAHPyCVSlFRUcHDDz9MQ77C/J133kltbS3bt2/ntttu44477uBjH/sYfX193H333dx7773s37+fQ4cO8alPfYq2tjYMw+Bb3/oWnZ2dIz7nBz/4AcuWLTsNZ3Ny+II/S9HRNqlMjkyuEBdoWi654ec6NnB8VZMo0c/iEpKSVNbGMITy7w+lx4zkmW1EImppaiqsy2QKjcDgoFq8xOOqEYjHyzcCu3Yf4pvfeYynfrqVhc1vZemiK3h99yOEI1UsWXj5sOBr6z5nZ7AMCwGkMwOYZoiBoXYqY/NI5QbZtfcx0uluFsyvKf2B5yiuqyx7yPdOi3qormHSPpgkaEgCpiCNwDJNBIKOvqExBV/v03VdQJDOjvTpX3XVVVx66aX84z/+I+vWraO7u5vPfOYzAPzVX/0V3/ve9/jkJz8JwM6dO3n3u9/Niy++iG3bXHbZZXz961/nwgsv5OMf/zjr1q0jl8vxkY98hPvuu4/ly5fz2GOPcc899/Cf//mfIz5ntjAlwRdC/CXw24COu/hTKeVjJba7GfgyYAL/IaW8Zyqfey7Q2jWA40pO9iVHrDfMkSkEb/rqPXxzxRdImRXc8tV7cICnPvHZsvt1XJnvTkM4YGGaBrbtErDMspE8c4FQSFno3iIj2WyhBzA0NDoVtHfW8O69h/j//vr/sWvPYQYHk3R0fYcL1t7BwgWXsP/wL2moX01VvHBupJQ4ToZUqpeTnbtJpntIpfrp7NmHY2fJ2gkq441UxhvJZeFP//xR3vLmi1m0oIligsHCEggUnlvW2Rd7bjsujusW3Dgl3JG5UJhUrp+w5fHDSzCDFunsxCrWO67ElWCZukfB8JjB7t17WLFyFQDf+MY3ePDBB8lkMrS3t/N//+//BdQksZ6eHv78z/8cgB/84Ads2LCBCy+8EIC1a9cyb948fvjDH/L6669zxx13qO9n21x99dUA7NmzhzXF5dRmmOmw8L8kpfzHci8KFdbwFeBGoBXYIoR4REo5wbpE5xbajXPkZD/pEgVXQ7EMyXzt1pu+eg/Kq1gIWzPz68cSfdt2iYUDBCwTx1ViH7AMEqnciLJ9pR6Ln3v/1+kEihc9K9X7f6l96e2KX5/Ieo3epnhb/ZrOZplMFtLvSgnPvdhHyLqBDWtyDCW6yeQGSSZ7qIw30ty4jmwuicRF5AfMpXRJZwYYGGznZM8e6qqXkc0O4ro2S1quIBapI2ense00OSfN0SM5fvS/h9h0SZjGedUEAnoeRUHkA4GR9Q3GwrJGNhTeBmO2NhKuK0llcoXfpMwMcdcKEIyEsDNZggbag0ZOQjg4MT+6chNJbEeC45LNOQgBfT09VFZVYrvwn9/4Bi+99BI/+9nPiMfjXHPNNZyfH1V//fXX2bRpE1Z+1uCOHTvYuHHj8P5fe+01br75Zl5++WU+//nP81u/9VsjPr+rq4uqqqrh988WzsTRXAbsl1IeBBBCfBd4O+ALfhHeSVPhgMlQiTz1V/zyP3lsyWcBiZ6C1W+HIBDjS+f9V2HDZ4re6AATuFeePOWjn8s4RIIXsHJJ8XqJ69pkc0l6+47S3XuEuupFgEEi1c3BI88yMHSS5nnnEY3U01C3HMsMEwhEMIT+dQqDvQD79sLBA6WPQojCYhhqCQTU/AbdGOjHUEgtuqh3KDSy4SjVaJhm6Z6EnlF8urFddzgax0DF0JejobaSIyf7sE0DHInjukjpsqB+nFwVHiSFyVtG/ic4dPgQzU3NCAHbd+zkyiuvJB6P8/3vf5/nn3+e9fk81Dt37uSCCy4Y3lddXR179+4FYNu2bTzwwAP88R//McePH+fJJ5/kN3/zNzEMg507d7Ju3ToOHz
7M/FMpgnCamQ7B/10hxAeArcCnpZS9Ra8vADwZ12kFNpXbmRDio8BHAVpazlyo4GzAO2mqKhama2Ck4t/01XtImvkL3i1EI1x4+A94dem9BfNQ5E0iKAi9FvsJCv+5RPmqWAJhGASsCFXVi0gmu+jtayUaqaFv4Bht7a/SULeKpob1VFc25338xeb15FIJeIuigxqbmA50A2KahYpekYhya0UiaoA7HFaLt7cBattSbqfJGq8y/+UEahzEFbKslV8ZDbG4sZqOviESGVtF6cyrHtd/P/pD1WfpPvDKVavp6u7m0osv5Mv/+hV+9+Mf49vf/jY33XQTy5YtI5YPpdq5cyeXXXbZ8G7uvvtubr31VtavX8+1117LkiVLWLZsGR/+8If5+c9/znnnnUckEmHdunU88MADrFmzhq6uLtatW8d9993HlVdeObnjPk0IOUYrCyCE+Akw2vEIfwa8CHShfra/AZqllB8uev+7gJullB/J/383sElK+bvjHdwll1wit27dOpHvcVaw+Y1WIiFrOEfG64dHJka65av3IBF8+bxvQDbDr3V8jZCTJOikCLtDhO0kAZlGIHn8E59lxUu/IpBJkwuGOeksBsC0c9hWgOSGi8jkJJYRIGgGCJsRnFyQREK5O9Jp5QPP5ZT4OE7BTXJ24YxbBhGUC8dxsgwMnWTXvkcJhSoIBeNUVS6gItZIPFqPaQYQwkBg5Gfjlmay4X2zFe2O0r2MWKwwV0In04vFCg3G/gOvsWT5CvVeBLbrlhV8TWXnCTrCKj3FvLTKNz3Q0DzWWzwHqCx7Q4h8PysfSgtYpoEQgmhoYjk+hoaGiOdjhv/hH/6B/v5+/vZv/3Zix3EaeeONNzhPzxbMI4R4WUp5Santx22jpZQ3TOSDhRD/Dvy4xEttwCLP/wvz63yKiIYDZHMOwYBJ/2DpnOGGCjaGwV/SHimdFtUFOg7Og8QF2IEg5EQhAZXhEsqkSCYtKqIm4ZBg2cKqYbdAMFiwAk2zkLxKL7atrM5cTi3ZrGoc9LpsVj3PZNR672u2rRbXLSxef/xsbkyEMDAMC8uwGBg8TnP0Auqrl2EFIwStcD4cM4dphPKCbo4Q/YJjZ5Y62E8BKdXvrYeaxkvceNNNBon+sSz0wgVgODaGzDFkFFw42kFW2XlifNHPC71liuHGRXr2LyWEAhNveL/0pS/x3e9+l0AgwJve9Ca++MUvTvi9s4mpRuk0Syl1otx3AK+V2GwLsFIIsRQl9HcB753K556tLKyvZM8xlQC8e7B8odk/3DuiEzXCSMoKiyd/6zNU2X3UdbQjs5Cx4tgEyREE28UORshmBBkjQCLhEp6E/9brFjDNwuBhTU3Bl2xZIxfTVILubSR070GXsdO9CNsuvK4bE93b0A1GNqv2p3sdeoHSg7berIvF6XYzmVx+8o+NdB0c18F1cyTTfao2gLTpHziB7aQYSnTRduJVFjZfRE1VC4YZIJnsprt7P8FgFFdKIuEqglZUhQUKicDAME0s0yIWi1JfV0coaA4fo/e4yw04l3s+9xivscu/LsE1ArgEsM0owyspiP5ECAVMFQ0kwDSM/DgABEyD0CTTL3zuc5/jc5/73IS3n61M1Yf/90KIjajf4DDwMQAhxHxU+OXbpJS2EOJ3UeOBJnC/lPL1KX7uWUlNRYTVi+rZ39ZNJje6grN2v3tvG5lf743KsXCxglmyK2pYtv0lsiKEHQxhZTMEcxlaN1zOvJY+quMRggGT9UsLVpTjFKxxr/XutdS1OGcyKuLFGyUzHkIUGgFvo6GfR6OFhsLbu/AKu22PFj3va/o7eHsc+jvphkI3Fnv3tdPV248hAkhhYiAQQmDbKVyponKEYZBODjA41E4q08+iBZdQU7WYTDaBaZi4To72rl3U1iwlmewhVF1B0ApiWSHU8CQEggbzGmqJRgq5mIu/Q6kIKe/zUo/FEVDFr5eKairuVZ3ZBqT46h2bWLqDEEkCE5b5wqdIIB
y0yNoOUko1PnaOJ1ebkuBLKe8us/448DbP/48Bo+LzfUrjuJJgwMRxHDxzrnjqE5/1hGLmt6V83P1g/TwObriMxoN7iA4NkIxX0nreBmRjMzXBADnbYVnzyElBWnhP+djzVroWXG/joV09Wmz1tt514zUeQpRuKLy9DZ2uOBweKW56vV6On2jn+S0vcOBIK4ZhkExmEcLCNCwMM4BpBDDNIJYZwHVy9A+eIBCIcOjoC8TXziMerSccjGFZIVwpaWt/lWwuQSLVxWCik9qqFkLBKMuXLlMRJkVRG3rwdLwEcaUahlL/TzaMtnibUutLvVZuO29vStf31Y2K4xTOe2FbvcPylr9FbtJiD2AYAteVWKZxTgt8MbMrSNRnOFKnriJCR18SGKmAY8XXl2Kwfh6D9YXZSAHToCEaojIWGrMwyqnijQI5Fbyun1INRy5XEBJtqWuLvrghKYduKLa+coyGugWEgnFe3vYcPb3tZG0XywoTsMKFRzNETfVi5tWvpKNrL0PJk+x844dsPP/dREIVQAMgCIcreXnHt+nqOYAATnRsIxzpZuXqC6iqjPOe96whl1OFYwYGSkfgGEYhP1BxTd25zhtvqBKOmmTGGc5LJF1JxlbJ0So7T5RtAiYk/QJMQ2AYZ894yXThC/4sI5nOEQlZVOeFuGsgSdZ2p9ztjoUsQkGL5roK1i9tHP8NM4Su4jWVBsPbUBQ3GnrMAKCv/zjV1ZVEotVcfeVV7Hh9P4lEipztIMhx7Ph2crbNwuaLkdKlq+cQFfFG+geP01C3iqOtL7Gs5UqCwZgaoBWCyy/6LV7e8R2S6XaWtlzAqhUbSQxlec8dKhI5EFCVw7zVw1y3kCl0aAh6e9XiJR5XKabj8bOnzF/QMoeTqAlDELQMsjll4IycvTBS6MccsBXkLXqVR8dnJL7gzzK8kTrVFRGqKyIMJDIMJtOksnZ+avrE92cKiIZVat9QwGThJCauzEUMoxAvPh6NTUP0D7RTVRmnLgTZ3AGisQCRSIzLL13LCy9lOHTkOOFQBa6bo7t3P67bQlXFfA4dfY6lLW/i+MmdzG/cSCAQocKYR9rs46rLPsRgaicuu5jfbLLxgmsxRRPt7SPz/3iPuThdNKhxE90bKJUaIhhUjUBV1cxWEDtVLNMY9rG7rsQwDCwLBuc1U9FxYpQ1X1LoReHBMASGXwhlTHzBn2V4I3UClkHOdjENwabzVGRra9cAbZ0DpDM2abt0ALllCKrjYWzXxXYkpiGorYywchZmxJxJbrnpSu69/2EAKuJRgsEAA4MJ1p+/HCFcVi5fRGdXLw11TSSTg5imQW//UQIBg+qqZg4e/SXLWq5BCIfmxosIWBFisVqqKmupqlrF8uV38OY3K/fRyZPKau/rU/l+xqp3q9ETobz5gXQt4YEBNWDe1aUWL7qWcDQ6e9MsaIp97LbjksrkSDTNh+F5WeqJKVSGUlcy3AUwhDilqJtzFV/wZxk6UsdbhHxZc82wUNdURIYbhXTWprMvgZM3hSwDqiuiXLyy2Rf2CbB2zVI+9uF38PhTz9N2vJN1a5dx4mQ3waCF67oEgxZLl8xnUeMqsvZJrr/uPaxeuZg9+46wZ+8AgwPzaO/cQSgI3b0mTQ0XYpkhkkmTykrYt09l7bz5ZuW7njcPWlsLtX4XLBht1Y9HqVrCUqrJcgMDhWRxxSmjI5FCorjZPO/LMg0ioYDH6i+4ZorX+QI/ecadaTuTnGszbSeDTrLWO5gmZ9sETJOayshpGYg9l9i1+9BwA7BgfgPXX3sV3Z2LWbNGCbQXbV03NKiKWSdOwK5dKiJJCFi6VAl+RQW8/e2FMo+OA4cOFQaWFy+e/ipeoI5D9wZK5OHDsgrlJ8/EAHGpWaE+U2PaZ9r6zE5qKiK+sJ8G1q5Zyto1hRnMXV3Q3TmyJq+mvl49dnbCsmVK2A0DXntNCeyhQ7B2rXLnfO97cMstqtEwTVixQg0gHzjAcK3bZcumt5B7MK
iOUR8nFIrK9/ergeJSpSQnUjvgTFDc+N5y05UjfpvZwqlUw5op/D6Rj88YDAyoR2/pRS9aULWlv3EjbNhQmAOwa5cSeduGxx5TJRQ1gQCsWQNLlqj/Dx6EvXvHSuY2dXRR+UWL1GfrZfFi5SYyDNUgtLWpwu67d6vl0CFVaN4b5VQK+8hRMt/7Pqmv3Efme9/HPnL0lI5z1+5D3Hv/w/QPDNHcVEf/wBD33v8wu3YfOqX9efmv//ovLr74Yi644AKuuuoqAHbv3s11113Hxo0bueGGG+jKD4xcccUVHDqkPrOtrY2LL74YUNWwPvaxj3H55Zfzd3/3dxw/fpw77riDCy+8kDVr1vDSSy8BcOjQId7+9rdzySWXcNlll7Fnz54pH/9U8C18H58x6OnxJCAtQ329er2zUz2/7DL1nu3bVaTN9u1w4YXQ3g7PP6+E881vLuwzHFaiOzSkfPz79ikXT0vLmbOwdRUxb5y8bSuXVH+/+h6dnaPz5VRXqwYkEgHn6FFyjzyKiMcQ9bXIRILcI4/C7bdiLZ5c5tvHn3qeqsoYVZWqpdWPjz/1/JSs/MHBQb7whS+wbds2gsEgfX19ZDIZ7rjjDr797W+zceNGvvCFL/ClL32Jv/mbv+HIkSMsybfIO3bsGE6ZXFwN6+KLL+bzn/88t912G8lkUk2aHKMa1kzhC76PzxikUhPzr+u4+s5O9fyqq5Tob9um9vHqq3D55er1XbtUz+HWW0cOoMbjSvj7+lTjsGePcq00Nc2Ma8WySg8QJxKFsQFvYXn72VaksZqoMGhw+wnlUw07m7dMWvDbjnfS3FQ3Yl1FPErb8XEytI2DaZqkUik+/elP88EPfpBLLrmEBx98kKuuumq4wMnatWt55JFHOHDgAEuXLh3OXrtjxw7Wr18/qhrWD3/4Q8477zxuu+02AKJRlf/noYceKlsNa6bwBd/HZwyy2ZGCNxbFon/ttUqod+xQ1vuLL8LVV6v9HTwI3/kOvOtdoxuU6mq1dHSoHkZ/v3IX1dUx4wihGqZ4HLyZIrJZ6HiylcGKJlK2RTZgETJtiEZxO7sn/TkL5jfQPzA0bNkDDA4lWTC/YUrHH41Gee211/jf//1fPvrRj/KRj3yEkydPDhc+AWW9r127lp07d45Yv3XrVj760Y+Oqoa1bds2Lr/88lGftX379pLVsGYS34fv41MGndenYRIaU1enwi+7u5Xle/31yq9fVaVe/9WvVHz8xo3q9QceUNuWYt48WL1aDQZ3dipfuh5TmG0Eg9DQbLJUHGZl5QkqAvn03skkRsPkW6pbbrqS/oEE/QNDuK5L/8AQ/QMJbrlpaoVE9u3bRywW46677uK2224jnU6zYMECdu1SBfgOHjzIt771LT7wgQ/Q09NDdXU1oKJhHn30US644IJR1bCampp4/fVCPsjOvN+rubmZJ598Ml9QXTUkMx0V6Qu+j08Zkvn68fl7fsLU1iqx7ulRg7nXXQcXX1zoKTzzjBqYfdOb1CDo978P+/eX3pcQatB31SqVbuL4cSX8+thmE+amS5FDCWQigZRSPQ4lMDddOul96TkSVZVxTrR3U1UZ52MffseUo3Q+//nPs3r1ai666CIOHTrEJz7xCe6++26OHz/O+vXrueuuu7j//vupq6vjrW99K0888QTve9/7eOihh6irq6OxsXGU4H/oQx/i5MmTnH/++WzcuJEXXngBgA9/+MO4rst55503PDYgZngmnB+H7+NThqNH1QDq9def2vt7e1VIpm4AXn1VuXU6OtTrV1wBy5fD00+r3sSFF6p1Y2HbIxuH6Q7lnCr2kaM4m7fgdnZjNNRhbrp02H/vx+FPP34cvo/PNDFeBafx0Bb9yZNqsPPCC5V75he/UGGPL7ygLPx3vxv+939h61bl3rn11vKDtJalBnYzGRUqefCg2nb58snXmD0dWItbJj1A63Pm8F06Pj5lGBiYelKymhoV6qit/RUrlKAvzXsmtm5VLp
53v1vFwh86BN/+9tjpnUG5d9asUfH0Uiqr//DhyRWj8Tn38AXfx6cMrjt5/30pampUaGVvrwq3bG6GX/s15ZcHFaf/4x+rhuCCC9R299+vYuDHIxZTwt/crGLl9+5Vfv5Z7Kn1mUFmQSfQx2d2kXlhM/YTT5NNX0ZF+BCZxEZCV2ya0j51w9Herh6bmuDXfx0efRRef72Qg+d971MNxC9/qSz9226DhQvH378umqLz+wwMqIihyUQY+Zz9+Ba+j4+HzAubyT3wXTKJLITC1CSOqv9f2DzlfVdXK6HXE6siESX6Gzao1/fvh298Q1n573qX6mE8/LCavDVR6utVKGdlpRoP2L17dOZMn3MXX/B9fDzYTzwNkTA9scUIIQjEIhAJq/XTQHW1cr/09ansmpalMmnmU7Rw5Aj8+7+rbT70IeWy+dWv4PHHJ/4ZQqhJUatWqbQNJ04o4U8kpuUr+MxhfMH38fEge3ohEiHOIM20qpWRiFo/TVRVKUHv71diLIRy3Vx7rXr9+HH46lfVBK0PflC5dPbvh299a3K+ecNQidlWrFCfceyYEv5StXR9zg2m5MMXQjwIrM7/Ww30SSk3ltjuMDAIOIBdLkbUx2emEbU1yESCiphBBfmagqkUonaC+RUmiJ55e+KEEvH581VCtepq+OEPVUjoP/8zfOpT8I53qEie7dvh3nuV5T+Z/PWWpdw82awK48wnf2RJ4Cji5dIx8z5nJ1Oy8KWU75FSbsyL/PeBH4yx+Vvy2/pi7zNrsW6+EVJpNUvUdZGJBKTSav00U1WlhH5gQFn1oPz573+/et7fD//4j+r5NdfAW9+q4vb/4z8Kk7cmQzCoInpaWsDt7GTvw69xqLsCWVc3nNnyVNMZny18+tOfZsOGDfze7/0eqVSKN7/5zTj5fNWtra08+OCDp/0YvJ+TzWa55pprsMeL050g0+LSEWq+8LuB/56O/fn4zBShKzYReP9diFgM+voRsRiB99815SidclRWFkS/rU2tW74cPv5x5YZJJOALX1DrV62Cu+9Wzx98UEX3nArRKCw/8QzNVQnsQJyOTDUiFkPEYzibt0z9S00T05Vb38v5559Pc3MzK1asGF4qKir4sz/7Mw4cOMBzzz3H9u3b+dd//Vfuv/9+3vnOdw4XN/npT3/KK6+8MqnPc06huIH3c4LBINdff/20NTTTklpBCHEN8MVy1rsQ4hDQiyo9fK+U8r4x9vVR4KMALS0tFx/R5YB8fM5itJVfUVEopdjfD//yLypax7Lgz/5Mrbdt+M//VHH3a9eeWuqH1FfuQ9TXjsjtIqVEdvUQ+eRHp+EbjWYyqRXsI4Xc+kSjkEwihxIETiG3vpcvfelLnDhxgr//+78H1HdetWoVP/7xj7nhhhuwbZumpiaeffZZbrzxRr7zne+wZMkSnn32Wd7+9rdTXV1NRUUFP/jBD3jllVf4x3/8R1KpFBUVFTz88MM0NDRw5513Ultby/bt27ntttu44447+NjHPkZfXx9333039957L/v37+fQoUN86lOfoq2tDcMw+Na3vkVnZ+eozxkcHORP/uRPeOyxxyZ0TsdKrTCuhS+E+IkQ4rUSy9s9m/0GY1v3V0kpLwJuAT6ZbyBKIqW8T0p5iZTykgY/iNjnHKGyUgn94KAqggLK5fOZzyixt234279V6y0LPvIRNfC7a5fKuDlZjIa60RnYTjGz5enA2bxFFVKJxRBCTFsP5AMf+AAPPvjgsIvkF7/4BUuWLGH16tV88IMf5G/+5m949dVXCQQCHDx4cLj4yVVXXcWll17Kj370I7Zt28ayZct4y1vewosvvsj27du58cYb+d73vgeorJiNjY28+OKLfPazn+V973sfX/7yl9mxYwcHDx5k3bp1w8VRvvjFL7J161b+8i//knvuuafk56xbt44tW6an5zWu4Espb5BSriux/AhACGEB7wTK9jmklG35xw7gYeCyaTl6H5+zCG3d68pXoFIofP
azKsWD48Bf/ZVaL4SK1b/4YjUz9ytfGb/8oJfpzGx5OnA7u5Vl7+UUc+t7qaur44orruDHP/4xAF//+tf5yEc+Aiih3pCfFNHV1TWcGlmzZ88e1qxZM/z/N77xDS677DI2bNjAV7/6VcLh8KjiKD/4wQ/YsGEDF154IaCKq2zYsIEf/vCHw8VRNm7cyGc+8xnC+ZH44s8xTZNgMMjgRKZej8N0zLS9AdgtpWwt9aIQIgYYUsrB/PObgL+ehs/18Tnr0KLf1gZHtrTTdOg53M5uPr2iji8fuJVUNsxf/RX8xV+o7a+8Uo0B/O//wte+Bm/P/Bc1L/9ctQ6mCZdfSsUffHLU51iLW+D2W0dktrSuu3bWROkYDWogmXzVLGDaeiC//du/zT//8z/zlre8hWeeeYb7778fgNdff51169YBEIlESKfTw+/p6uqiqqpquOjJN7/5TV566SV+9rOfEY/Hueaaazj//PNHFUfZsWPHcCUtgNdee42bb76Zl19+uWRxlOLP0WQymeEGYSpMx6DtXRS5c4QQ84UQ2uHUCDwrhNgOvAQ8KqV8Yho+18fnrKSiAppkK/2/fIXWnshwfdjfi/8n1bEUULD0QcXaf/jD4PZ083DX9ewPn6+6AI4Nz77A4Je/UvJzrMUthN59B5FPfpTQu++YNWIPp7cHct1117F3717+6Z/+iTvvvHPYeg4EAkTy5cdqampwHGdY9A8fPsx8T4mvnTt3cuWVVxKPx/n+97/P888/z/r160flyq+rq2Pv3r2Aqoz1wAMPsGHDhrLFUYo/B6C7u5v6+noCU83kxzQIvpTyQ1LKrxWtOy6lfFv++UEp5Yb8cr6U8vNT/Uwfn7OdyM7NzK8eImVV05WpHPZh/07LY8OlBf/qzx0G73w/g3e+H/dD7+c3X/sDTDfLM82/wXMNvw6mBQJ4cfZE3kwUa3ELgdtvRcRiyK4eFS01xQFbjRCCD33oQ3z+858fdue89tprw9a95qabbuLZZ58FYM2aNXR1dbFu3Tqef/55PvShD/HVr36Vyy67jFdffZVly5YRi8VGCf7dd9/N1q1bWb9+PV//+tdZsmQJy5YtK1scpfhzAH7+859z6623Tvl7g18AxcdnVqKjaDJuEEO4BA1nRBTNf/3esxyuUdVS/uiND45472MLP8aAVcNdrf+oQnykpOJ735qJrzGC2VQAJZFIcOjQoVEi7+WVV17hS1/6Et/61qmfu6GhIeJxVZf3H/7hH+jv7+dv9ej7BHnnO9/JPffcwyqdXtXDtEfp+Pj4nHl0FE3YzBE08rHcHh/2O9u/xtrOpwD44pr/HPHet7V/XYk9gHSVL99nBLFYbEyxB7jooot4y1veckqx9JovfelLw6UPDx8+zOc+97lJvT+bzfLrv/7rJcX+VPDTI/v4zELMTZeSe+RR9Y8nDt267trhbW7u+jaL+l6jN9408s2ODcJQYi+By2dH5M1c5MMf/vCU3v+5z31u0iLvJRgM8oEPfGBKx+DFF3wfn1nIRKNozre3Q9/2kW82rXyUjlU2Ssfn3MQXfB+fWcqp1oet+O9vTP/B+JwV+D58H585SMVDpafXlls/W5jNQSJzjVM5l76F7+MzR5nt4l5MOBymu7uburq6ETl8fCaPlJLu7u5JT8byBd/Hx+eMsHDhQlpbW+ns7JzpQzkrCIfDLJxIwWMPvuD7+PicEQKBAEuXLp3pwzin8X34Pj4+PucIvuD7+Pj4nCP4gu/j4+NzjjCrc+kIITqBM1nyqh7oOoOfN13M1eOGuXvs/nGfWebqccOZP/bFUsqS1aNmteCfaYQQW+dikfW5etwwd4/dP+4zy1w9bphdx+67dHx8fHzOEXzB9/Hx8TlH8AV/JPfN9AGcInP1uGHuHrt/3GeWuXrcMIuO3ffh+/j4+Jwj+Ba+j4+PzzmCL/g+Pj4+5wjntOALIR4UQmzLL4eFENvKbHdYCLEzv92MF9kVQvylEKLNc+xvK7PdzU
KIPUKI/UKIz57p4yyFEOIfhBC7hRA7hBAPCyGqy2w3K875eOdQCBHKX0f7hRCbhRBLZuAwi49pkRDi50KIXUKI14UQf1Bim2uFEP2ea+jPZ+JYixnvdxeKf8mf7x1CiItm4jiLEUKs9pzLbUKIASHEp4q2mflzLqX0FzWO8U/An5d57TBQP9PH6DmevwT+zzjbmMABYBkQBLYDa2fBsd8EWPnnXwC+MFvP+UTOIfAJ4Gv553cBD86Cc9wMXJR/XgHsLXHc1wI/nuljnezvDrwNeBwQwOXA5pk+5jLXTTtqAtSsOufntIWvESo597uB/57pY5lGLgP2SykPSimzwHeBt8/wMSGlfEpKaef/fRGYXH7XM8tEzuHbgf/KP/8f4Hoxw8nepZQnpJSv5J8PAm8AC2bymKaRtwPflIoXgWohRPNMH1QR1wMHpJRnMkvAhPAFX3E1cFJKua/M6xJ4SgjxshDio2fwuMbid/Nd2vuFEDUlXl8AHPP838rsu+k/jLLWSjEbzvlEzuHwNvmGrB+oOyNHNwHyLqYLgc0lXr5CCLFdCPG4EOL8M3tkZRnvd58L1/VdlDceZ/Scn/X58IUQPwGaSrz0Z1LKH+Wf/wZjW/dXSSnbhBDzgKeFELullM9M97F6Geu4gf8H/A3q5vgblDvqw6fzeCbDRM65EOLPABv4dpndnPFzfrYhhIgD3wc+JaUcKHr5FZTLYSg/BvRDYOUZPsRSzOnfXQgRBG4H/qTEyzN+zs96wZdS3jDW60IIC3gncPEY+2jLP3YIIR5GdfVP60U43nFrhBD/Dvy4xEttwCLP/wvz6047EzjnHwJuA66XeedmiX2c8XNegomcQ71Na/5aqgK6z8zhlUcIEUCJ/bellD8oft3bAEgpHxNCfFUIUS+lnNEEZRP43Wfsup4gtwCvSClPFr8wG86579KBG4DdUsrWUi8KIWJCiAr9HDXo+NoZPL5Sx+T1Wb6D0sezBVgphFiatzruAh45E8c3FkKIm4HPALdLKZNltpkt53wi5/AR4IP55+8CflauETtT5McQvg68IaX8YpltmvRYgxDiMpQWzGhDNcHf/RHgA/loncuBfinliTN8qGNR1lswG875WW/hT4BR/jYhxHzgP6SUbwMagYfzv5MFfEdK+cQZP8qR/L0QYiPKpXMY+BiMPG4ppS2E+F3gSVTUwP1Sytdn6Hi9/BsQQnXXAV6UUv7ObDzn5c6hEOKvga1SykdQwvotIcR+oAd1Pc00bwLuhv9/O3dswyAMhFH43SJsww606cgKmSSTsAht5oAmklPYIkiI2sW9r3R1suX/JOtk1viPGr+AAaCU8qY2p2dEfIEdmHo3Km7OPSJmOOpeqJM6H2ADHp1qvWhNaqTdx7Z2rr37nvu1giQl4ZOOJCVh4EtSEga+JCVh4EtSEga+JCVh4EtSEga+JCXxA5ZhPY+LMC8IAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_ot_map(neural_dual, data_target, data_source, inverse=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We further test how close the predicted samples are to the sampled data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, for potential $g$, which transports source samples to the target. Ideally, the resulting {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance between the predictions and the target samples is close to $0$."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sinkhorn distance between predictions and data samples: 1.4665195941925049\n"
]
}
],
"source": [
"pred_target = neural_dual.transport(data_source)\n",
"print(\n",
" f\"Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_target, data_target)}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, for potential $f$, which transports target samples to the source. Again, the resulting {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance between the predictions and the source samples should ideally be close to $0$."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sinkhorn distance between predictions and data samples: 6.814591884613037\n"
]
}
],
"source": [
"pred_source = neural_dual.transport(data_target, forward=False)\n",
"print(\n",
" f\"Sinkhorn distance between predictions and data samples: {sinkhorn_loss(pred_source, data_source)}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides computing the transport and mapping source to target samples or vice versa, we can also compute the overall distance between new source and target samples."
]
},
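{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what this distance estimates, here is a sketch in generic notation, which may differ from the exact normalization used by the solver. For the squared Euclidean cost, writing $\\mu$ and $\\nu$ for the source and target distributions and $\\varphi^*$ for the convex conjugate of a convex potential $\\varphi$, the dual problem can be written as\n",
"\n",
"$$\n",
"W_2^2(\\mu, \\nu) = \\mathbb{E}_{\\mu}\\|X\\|^2 + \\mathbb{E}_{\\nu}\\|Y\\|^2 - 2 \\inf_{\\varphi \\text{ convex}} \\left[ \\mathbb{E}_{\\mu}\\,\\varphi(X) + \\mathbb{E}_{\\nu}\\,\\varphi^*(Y) \\right].\n",
"$$\n",
"\n",
"The symbols $\\mu$, $\\nu$ and $\\varphi$ are introduced here for illustration only. Replacing the potentials by trained networks and the expectations by sample means gives an estimate that can be checked against a primal solver."
]
},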
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Neural dual distance between source and target data: 21.16440200805664\n"
]
}
],
"source": [
"neural_dual_dist = neural_dual.distance(data_source, data_target)\n",
"print(\n",
" f\"Neural dual distance between source and target data: {neural_dual_dist}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, we compute the primal {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` distance between the same source and target samples."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sinkhorn distance between source and target data: 21.226734161376953\n"
]
}
],
"source": [
"sinkhorn_dist = sinkhorn_loss(data_source, data_target)\n",
"print(f\"Sinkhorn distance between source and target data: {sinkhorn_dist}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "a665b5d41d17b532ea9890333293a1b812fa0b73c9c25c950b3cedf1bebd0438"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}