diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 120924266..a9bee6af7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -9,8 +9,29 @@ on:
   workflow_dispatch:
 
 jobs:
+  remove-unneeded-software:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - run: |
+          sudo rm -rf \
+            "$AGENT_TOOLSDIRECTORY" \
+            /opt/google/chrome \
+            /opt/microsoft/msedge \
+            /opt/microsoft/powershell \
+            /opt/pipx \
+            /usr/lib/mono \
+            /usr/local/julia* \
+            /usr/local/lib/android \
+            /usr/local/lib/node_modules \
+            /usr/local/share/chromium \
+            /usr/local/share/powershell \
+            /usr/share/dotnet \
+            /usr/share/swift
+
   deps-torch:
     runs-on: ubuntu-22.04
+    needs: [remove-unneeded-software]
     steps:
       - uses: actions/checkout@v3
       - uses: actions/setup-python@v4
diff --git a/notebooks/cms/cms-validate-onnx.ipynb b/notebooks/cms/cms-validate-onnx.ipynb
index 3223575a3..5d7775d68 100644
--- a/notebooks/cms/cms-validate-onnx.ipynb
+++ b/notebooks/cms/cms-validate-onnx.ipynb
@@ -1,5 +1,25 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "db5a5a5d-8f56-45a8-b649-7933a777f82f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# !pip install onnxscript"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "700d6b0c-9ed2-4dea-be29-a3a3e2321e95",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!pip install onnxconverter-common"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -11,7 +31,7 @@
     "import pickle as pkl\n",
     "import sys\n",
     "import numpy as np\n",
-    "import tqdm\n",
+    "from tqdm import tqdm\n",
     "import tensorflow_datasets as tfds\n",
     "\n",
     "import awkward\n",
@@ -58,11 +78,11 @@
    "outputs": [],
    "source": [
     "#tfds datasets are here:\n",
-    "data_dir = \"/scratch/persistent/joosep/tensorflow_datasets/\"\n",
+    "data_dir = \"/mnt/ceph/users/ewulff/tensorflow_datasets/cms250\"\n",
     "dataset = \"cms_pf_ttbar\"\n",
     "\n",
     "#model checkpoints are here:\n",
-    "outdir = \"../../experiments/pyg-cms_20241101_090645_682892/\"\n",
+    "outdir = \"/mnt/ceph/users/ewulff/hf_particleflow/cms/v2.1.0/pyg-cms_20241101_090645_682892\"\n",
     "\n",
     "#Load model arguments from existing training\n",
     "model_state = torch.load(\n",
@@ -74,7 +94,7 @@
     "#this is needed to configure com.microsoft.MultiHeadAttention\n",
     "NUM_HEADS = model_kwargs[\"num_heads\"]\n",
     "\n",
-    "torch_device = torch.device(\"cuda\")"
+    "torch_device = torch.device(\"cpu\")"
   ]
  },
  {
@@ -122,17 +142,62 @@
     "            nn.Linear(width, output_dim),\n",
     "    )\n",
     "\n",
+    "\n",
     "class RegressionOutput(nn.Module):\n",
     "    def __init__(self, mode, embed_dim, width, act, dropout, elemtypes):\n",
-    "        super().__init__()\n",
+    "        super(RegressionOutput, self).__init__()\n",
     "        self.mode = mode\n",
     "        self.elemtypes = elemtypes\n",
-    "        self.nn = ffn(embed_dim, 2, width, act, dropout)\n",
+    "\n",
+    "        # single output\n",
+    "        if self.mode == \"direct\" or self.mode == \"additive\" or self.mode == \"multiplicative\":\n",
+    "            self.nn = ffn(embed_dim, 1, width, act, dropout)\n",
+    "        elif self.mode == \"direct-elemtype\":\n",
+    "            self.nn = ffn(embed_dim, len(self.elemtypes), width, act, dropout)\n",
+    "        elif self.mode == \"direct-elemtype-split\":\n",
+    "            self.nn = nn.ModuleList()\n",
+    "            for elem in range(len(self.elemtypes)):\n",
+    "                self.nn.append(ffn(embed_dim, 1, width, act, dropout))\n",
+    "        # two outputs\n",
+    "        elif self.mode == \"linear\":\n",
+    "            self.nn = ffn(embed_dim, 2, width, act, dropout)\n",
+    "        elif self.mode == \"linear-elemtype\":\n",
+    "            self.nn1 = ffn(embed_dim, len(self.elemtypes), width, act, dropout)\n",
+    "            self.nn2 = ffn(embed_dim, len(self.elemtypes), width, act, dropout)\n",
     "\n",
     "    def forward(self, elems, x, orig_value):\n",
-    "        nn_out = self.nn(x)\n",
-    "        nn_out = orig_value * nn_out[..., 0:1] + nn_out[..., 1:2]\n",
-    "        return nn_out\n",
+    "        if self.mode == \"direct\":\n",
+    "            nn_out = self.nn(x)\n",
+    "            return nn_out\n",
+    "        elif self.mode == \"direct-elemtype\":\n",
+    "            nn_out = self.nn(x)\n",
+    "            elemtype_mask = torch.cat([elems[..., 0:1] == elemtype for elemtype in self.elemtypes], axis=-1)\n",
+    "            nn_out = torch.sum(elemtype_mask * nn_out, axis=-1, keepdims=True)\n",
+    "            return nn_out\n",
+    "        elif self.mode == \"direct-elemtype-split\":\n",
+    "            elem_outs = []\n",
+    "            for elem in range(len(self.elemtypes)):\n",
+    "                elem_outs.append(self.nn[elem](x))\n",
+    "            elemtype_mask = torch.cat([elems[..., 0:1] == elemtype for elemtype in self.elemtypes], axis=-1)\n",
+    "            elem_outs = torch.cat(elem_outs, axis=-1)\n",
+    "            return torch.sum(elem_outs * elemtype_mask, axis=-1, keepdims=True)\n",
+    "        elif self.mode == \"additive\":\n",
+    "            nn_out = self.nn(x)\n",
+    "            return orig_value + nn_out\n",
+    "        elif self.mode == \"multiplicative\":\n",
+    "            nn_out = self.nn(x)\n",
+    "            return orig_value * nn_out\n",
+    "        elif self.mode == \"linear\":\n",
+    "            nn_out = self.nn(x)\n",
+    "            return orig_value * nn_out[..., 0:1] + nn_out[..., 1:2]\n",
+    "        elif self.mode == \"linear-elemtype\":\n",
+    "            nn_out1 = self.nn1(x)\n",
+    "            nn_out2 = self.nn2(x)\n",
+    "            elemtype_mask = torch.cat([elems[..., 0:1] == elemtype for elemtype in self.elemtypes], axis=-1)\n",
+    "            a = torch.sum(elemtype_mask * nn_out1, axis=-1, keepdims=True)\n",
+    "            b = torch.sum(elemtype_mask * nn_out2, axis=-1, keepdims=True)\n",
+    "            return orig_value * a + b\n",
+    "\n",
     "\n",
     "class SimpleMultiheadAttention(nn.MultiheadAttention):\n",
     "    def __init__(\n",
@@ -185,38 +250,86 @@
     "        attn_output = self.out_proj(attn_output)\n",
     "        return attn_output, None\n",
     "\n",
-    "class SimpleSelfAttentionLayer(nn.Module):\n",
+    "\n",
+    "class SimplePreLnSelfAttentionLayer(nn.Module):\n",
     "    def __init__(\n",
     "        self,\n",
+    "        name=\"\",\n",
     "        activation=\"elu\",\n",
     "        embedding_dim=128,\n",
     "        num_heads=2,\n",
     "        width=128,\n",
     "        dropout_mha=0.1,\n",
     "        dropout_ff=0.1,\n",
+    "        attention_type=\"efficient\",\n",
+    "        learnable_queries=False,\n",
+    "        elems_as_queries=False,\n",
     "    ):\n",
-    "        super().__init__()\n",
+    "        super(SimplePreLnSelfAttentionLayer, self).__init__()\n",
+    "        self.name = name\n",
     "\n",
+    "        # set to False to enable manual override for ONNX export\n",
+    "        self.enable_ctx_manager = False\n",
+    "\n",
+    "        self.attention_type = attention_type\n",
     "        self.act = get_activation(activation)\n",
+    "        # self.mha = torch.nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout_mha, batch_first=True)\n",
     "        self.mha = SimpleMultiheadAttention(embedding_dim, num_heads, dropout=dropout_mha)\n",
     "        self.norm0 = torch.nn.LayerNorm(embedding_dim)\n",
     "        self.norm1 = torch.nn.LayerNorm(embedding_dim)\n",
-    "        self.seq = torch.nn.Sequential(\n",
-    "            nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act()\n",
-    "        )\n",
+    "        self.seq = torch.nn.Sequential(nn.Linear(embedding_dim, width), self.act(), nn.Linear(width, embedding_dim), self.act())\n",
     "        self.dropout = torch.nn.Dropout(dropout_ff)\n",
     "\n",
-    "    def forward(self, x: Tensor, mask: Tensor):\n",
-    "        mha_out = self.mha(x, x, x)[0]\n",
+    "        # params for torch sdp_kernel\n",
+    "        if self.enable_ctx_manager:\n",
+    "            self.attn_params = {\n",
+    "                \"math\": [SDPBackend.MATH],\n",
+    "                \"efficient\": [SDPBackend.EFFICIENT_ATTENTION],\n",
+    "                \"flash\": [SDPBackend.FLASH_ATTENTION],\n",
+    "            }\n",
+    "\n",
+    "        self.learnable_queries = learnable_queries\n",
+    "        self.elems_as_queries = elems_as_queries\n",
+    "        if self.learnable_queries:\n",
+    "            self.queries = nn.Parameter(torch.zeros(1, 1, embedding_dim), requires_grad=True)\n",
+    "            trunc_normal_(self.queries, std=0.02)\n",
+    "\n",
+    "        self.save_attention = False\n",
+    "        self.outdir = \"\"\n",
+    "\n",
+    "    def forward(self, x, mask, initial_embedding):\n",
+    "        mask_ = mask.unsqueeze(-1)\n",
+    "        x = self.norm0(x * mask_)\n",
     "\n",
-    "        x = x + mha_out\n",
-    "        x = self.norm0(x)\n",
-    "        x = x + self.seq(x)\n",
-    "        x = self.norm1(x)\n",
+    "        q = x\n",
+    "        if self.learnable_queries:\n",
+    "            q = self.queries.expand(*x.shape) * mask_\n",
+    "        elif self.elems_as_queries:\n",
+    "            q = initial_embedding * mask_\n",
+    "\n",
+    "        key_padding_mask = None\n",
+    "        if self.attention_type == \"math\":\n",
+    "            key_padding_mask = ~mask\n",
+    "\n",
+    "        # default path, for FlashAttn/Math backend\n",
+    "        if self.enable_ctx_manager:\n",
+    "            with sdpa_kernel(self.attn_params[self.attention_type]):\n",
+    "                mha_out = self.mha(q, x, x)[0]\n",
+    "\n",
+    "        # path for ONNX export\n",
+    "        else:\n",
+    "            mha_out = self.mha(q, x, x)[0]\n",
+    "\n",
+    "        mha_out = mha_out * mask_\n",
+    "\n",
+    "        mha_out = x + mha_out\n",
+    "        x = self.norm1(mha_out)\n",
+    "        x = mha_out + self.seq(x)\n",
     "        x = self.dropout(x)\n",
-    "        x = x * mask.unsqueeze(-1)\n",
+    "        x = x * mask_\n",
     "        return x\n",
     "\n",
+    "\n",
     "class SimpleMLPF(nn.Module):\n",
     "    def __init__(\n",
     "        self,\n",
@@ -228,114 +341,215 @@
     "        dropout_ff=0.0,\n",
     "        activation=\"elu\",\n",
     "        layernorm=True,\n",
+    "        conv_type=\"attention\",\n",
+    "        input_encoding=\"joint\",\n",
+    "        pt_mode=\"linear\",\n",
+    "        eta_mode=\"linear\",\n",
+    "        sin_phi_mode=\"linear\",\n",
+    "        cos_phi_mode=\"linear\",\n",
+    "        energy_mode=\"linear\",\n",
     "        # element types which actually exist in the dataset\n",
     "        elemtypes_nonzero=[1, 4, 5, 6, 8, 9, 10, 11],\n",
+    "        # should the conv layer outputs be concatted (concat) or take the last (last)\n",
+    "        learned_representation_mode=\"last\",\n",
+    "        # gnn-lsh specific parameters\n",
+    "        bin_size=640,\n",
+    "        max_num_bins=200,\n",
+    "        distance_dim=128,\n",
+    "        num_node_messages=2,\n",
+    "        ffn_dist_hidden_dim=128,\n",
+    "        ffn_dist_num_layers=2,\n",
     "        # self-attention specific parameters\n",
     "        num_heads=16,\n",
     "        head_dim=16,\n",
+    "        attention_type=\"flash\",\n",
     "        dropout_conv_reg_mha=0.0,\n",
     "        dropout_conv_reg_ff=0.0,\n",
     "        dropout_conv_id_mha=0.0,\n",
     "        dropout_conv_id_ff=0.0,\n",
+    "        use_pre_layernorm=False,\n",
     "    ):\n",
-    "        super().__init__()\n",
+    "        super(SimpleMLPF, self).__init__()\n",
+    "\n",
+    "        self.conv_type = conv_type\n",
     "\n",
     "        self.act = get_activation(activation)\n",
     "\n",
+    "        self.learned_representation_mode = learned_representation_mode\n",
+    "\n",
+    "        self.input_encoding = input_encoding\n",
+    "\n",
     "        self.input_dim = input_dim\n",
     "        self.num_convs = num_convs\n",
     "\n",
+    "        self.bin_size = bin_size\n",
     "        self.elemtypes_nonzero = elemtypes_nonzero\n",
     "\n",
-    "        embedding_dim = num_heads * head_dim\n",
-    "        width = num_heads * head_dim\n",
+    "        self.use_pre_layernorm = use_pre_layernorm\n",
+    "\n",
+    "        if self.conv_type == \"attention\":\n",
+    "            embedding_dim = num_heads * head_dim\n",
+    "            width = num_heads * head_dim\n",
     "\n",
     "        # embedding of the inputs\n",
-    "        self.nn0_id = nn.ModuleList()\n",
-    "        for ielem in range(len(self.elemtypes_nonzero)):\n",
-    "            self.nn0_id.append(ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff))\n",
-    "        self.nn0_reg = nn.ModuleList()\n",
-    "        for ielem in range(len(self.elemtypes_nonzero)):\n",
-    "            self.nn0_reg.append(ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff))\n",
-    "        \n",
-    "        self.conv_id = nn.ModuleList()\n",
-    "        self.conv_reg = nn.ModuleList()\n",
-    "\n",
-    "        for i in range(num_convs):\n",
-    "            self.conv_id.append(\n",
-    "                SimpleSelfAttentionLayer(\n",
-    "                    activation=activation,\n",
-    "                    embedding_dim=embedding_dim,\n",
-    "                    num_heads=num_heads,\n",
-    "                    width=width,\n",
-    "                    dropout_mha=dropout_conv_id_mha,\n",
-    "                    dropout_ff=dropout_conv_id_ff,\n",
-    "                )\n",
-    "            )\n",
-    "            self.conv_reg.append(\n",
-    "                SimpleSelfAttentionLayer(\n",
-    "                    activation=activation,\n",
-    "                    embedding_dim=embedding_dim,\n",
-    "                    num_heads=num_heads,\n",
-    "                    width=width,\n",
-    "                    dropout_mha=dropout_conv_reg_mha,\n",
-    "                    dropout_ff=dropout_conv_reg_ff,\n",
-    "                )\n",
-    "            )\n",
-    "\n",
-    "        decoding_dim = self.input_dim + embedding_dim\n",
+    "        if self.num_convs != 0:\n",
+    "            if self.input_encoding == \"joint\":\n",
+    "                self.nn0_id = ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff)\n",
+    "                self.nn0_reg = ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff)\n",
+    "            elif self.input_encoding == \"split\":\n",
+    "                self.nn0_id = nn.ModuleList()\n",
+    "                for ielem in range(len(self.elemtypes_nonzero)):\n",
+    "                    self.nn0_id.append(ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff))\n",
+    "                self.nn0_reg = nn.ModuleList()\n",
+    "                for ielem in range(len(self.elemtypes_nonzero)):\n",
+    "                    self.nn0_reg.append(ffn(self.input_dim, embedding_dim, width, self.act, dropout_ff))\n",
+    "\n",
+    "        if self.conv_type == \"attention\":\n",
+    "            self.conv_id = nn.ModuleList()\n",
+    "            self.conv_reg = nn.ModuleList()\n",
+    "\n",
+    "            for i in range(self.num_convs):\n",
+    "                lastlayer = i == self.num_convs - 1\n",
+    "                self.conv_id.append(\n",
+    "                    SimplePreLnSelfAttentionLayer(\n",
+    "                        name=\"conv_id_{}\".format(i),\n",
+    "                        activation=activation,\n",
+    "                        embedding_dim=embedding_dim,\n",
+    "                        num_heads=num_heads,\n",
+    "                        width=width,\n",
+    "                        dropout_mha=dropout_conv_id_mha,\n",
+    "                        dropout_ff=dropout_conv_id_ff,\n",
+    "                        attention_type=attention_type,\n",
+    "                        elems_as_queries=lastlayer,\n",
+    "                        # learnable_queries=lastlayer,\n",
+    "                    )\n",
+    "                )\n",
+    "                self.conv_reg.append(\n",
+    "                    SimplePreLnSelfAttentionLayer(\n",
+    "                        name=\"conv_reg_{}\".format(i),\n",
+    "                        activation=activation,\n",
+    "                        embedding_dim=embedding_dim,\n",
+    "                        num_heads=num_heads,\n",
+    "                        width=width,\n",
+    "                        dropout_mha=dropout_conv_reg_mha,\n",
+    "                        dropout_ff=dropout_conv_reg_ff,\n",
+    "                        attention_type=attention_type,\n",
+    "                        elems_as_queries=lastlayer,\n",
+    "                        # learnable_queries=lastlayer,\n",
+    "                    )\n",
+    "                )\n",
+    "        elif self.conv_type == \"gnn_lsh\":\n",
+    "            self.conv_id = nn.ModuleList()\n",
+    "            self.conv_reg = nn.ModuleList()\n",
+    "            for i in range(self.num_convs):\n",
+    "                gnn_conf = {\n",
+    "                    \"inout_dim\": embedding_dim,\n",
+    "                    \"bin_size\": self.bin_size,\n",
+    "                    \"max_num_bins\": max_num_bins,\n",
+    "                    \"distance_dim\": distance_dim,\n",
+    "                    \"layernorm\": layernorm,\n",
+    "                    \"num_node_messages\": num_node_messages,\n",
+    "                    \"dropout\": dropout_ff,\n",
+    "                    \"ffn_dist_hidden_dim\": ffn_dist_hidden_dim,\n",
+    "                    \"ffn_dist_num_layers\": ffn_dist_num_layers,\n",
+    "                }\n",
+    "                self.conv_id.append(CombinedGraphLayer(**gnn_conf))\n",
+    "                self.conv_reg.append(CombinedGraphLayer(**gnn_conf))\n",
+    "\n",
+    "        if self.learned_representation_mode == \"concat\":\n",
+    "            decoding_dim = self.num_convs * embedding_dim\n",
+    "        elif self.learned_representation_mode == \"last\":\n",
+    "            decoding_dim = embedding_dim\n",
     "\n",
     "        # DNN that acts on the node level to predict the PID\n",
     "        self.nn_binary_particle = ffn(decoding_dim, 2, width, self.act, dropout_ff)\n",
     "        self.nn_pid = ffn(decoding_dim, num_classes, width, self.act, dropout_ff)\n",
-    "        \n",
+    "\n",
     "        # elementwise DNN for node momentum regression\n",
-    "        embed_dim = decoding_dim + num_classes\n",
-    "        self.nn_pt = RegressionOutput(\"linear\", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
-    "        self.nn_eta = RegressionOutput(\"linear\", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
-    "        self.nn_sin_phi = RegressionOutput(\"linear\", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
-    "        self.nn_cos_phi = RegressionOutput(\"linear\", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
-    "        self.nn_energy = RegressionOutput(\"linear\", embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "        embed_dim = decoding_dim\n",
+    "        self.nn_pt = RegressionOutput(pt_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "        self.nn_eta = RegressionOutput(eta_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "        self.nn_sin_phi = RegressionOutput(sin_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "        self.nn_cos_phi = RegressionOutput(cos_phi_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "        self.nn_energy = RegressionOutput(energy_mode, embed_dim, width, self.act, dropout_ff, self.elemtypes_nonzero)\n",
+    "\n",
+    "        if self.use_pre_layernorm:  # add final norm after last attention block as per https://arxiv.org/abs/2002.04745\n",
+    "            self.final_norm_id = torch.nn.LayerNorm(decoding_dim)\n",
+    "            self.final_norm_reg = torch.nn.LayerNorm(embed_dim)\n",
     "\n",
     "    # @torch.compile\n",
     "    def forward(self, X_features, mask):\n",
     "        Xfeat_normed = X_features\n",
+    "        mask = mask.bool()\n",
     "\n",
     "        embeddings_id, embeddings_reg = [], []\n",
-    "        \n",
-    "        embedding_id = torch.stack([nn0(Xfeat_normed) for nn0 in self.nn0_id], axis=-1)\n",
-    "        elemtype_mask = torch.cat([X_features[..., 0:1] == elemtype for elemtype in self.elemtypes_nonzero], axis=-1)\n",
-    "        embedding_id = torch.sum(embedding_id * elemtype_mask.unsqueeze(-2), axis=-1)\n",
-    "\n",
-    "        embedding_reg = torch.stack([nn0(Xfeat_normed) for nn0 in self.nn0_reg], axis=-1)\n",
-    "        elemtype_mask = torch.cat([X_features[..., 0:1] == elemtype for elemtype in self.elemtypes_nonzero], axis=-1)\n",
-    "        embedding_reg = torch.sum(embedding_reg * elemtype_mask.unsqueeze(-2), axis=-1)\n",
-    "\n",
-    "\n",
-    "        for num, conv in enumerate(self.conv_id):\n",
-    "            conv_input = embedding_id if num == 0 else embeddings_id[-1]\n",
-    "            out_padded = conv(conv_input, mask)\n",
-    "            embeddings_id.append(out_padded)\n",
-    "        for num, conv in enumerate(self.conv_reg):\n",
-    "            conv_input = embedding_reg if num == 0 else embeddings_reg[-1]\n",
-    "            out_padded = conv(conv_input, mask)\n",
-    "            embeddings_reg.append(out_padded)\n",
-    "\n",
-    "        final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1)\n",
-    "        preds_id = self.nn_id(final_embedding_id)\n",
-    "\n",
-    "        final_embedding_id = torch.cat([Xfeat_normed] + [embeddings_id[-1]], axis=-1)\n",
-    "        final_embedding_reg = torch.cat([Xfeat_normed] + [embeddings_reg[-1]] + [preds_id], axis=-1)\n",
+    "        if self.num_convs != 0:\n",
+    "            if self.input_encoding == \"joint\":\n",
+    "                embedding_id = self.nn0_id(Xfeat_normed)\n",
+    "                embedding_reg = self.nn0_reg(Xfeat_normed)\n",
+    "            elif self.input_encoding == \"split\":\n",
+    "                embedding_id = torch.stack([nn0(Xfeat_normed) for nn0 in self.nn0_id], axis=-1)\n",
+    "                elemtype_mask = torch.cat([X_features[..., 0:1] == elemtype for elemtype in self.elemtypes_nonzero], axis=-1)\n",
+    "                embedding_id = torch.sum(embedding_id * elemtype_mask.unsqueeze(-2), axis=-1)\n",
+    "\n",
+    "                embedding_reg = torch.stack([nn0(Xfeat_normed) for nn0 in self.nn0_reg], axis=-1)\n",
+    "                elemtype_mask = torch.cat([X_features[..., 0:1] == elemtype for elemtype in self.elemtypes_nonzero], axis=-1)\n",
+    "                embedding_reg = torch.sum(embedding_reg * elemtype_mask.unsqueeze(-2), axis=-1)\n",
+    "\n",
+    "            for num, conv in enumerate(self.conv_id):\n",
+    "                conv_input = embedding_id if num == 0 else embeddings_id[-1]\n",
+    "                out_padded = conv(conv_input, mask, embedding_id)\n",
+    "                embeddings_id.append(out_padded)\n",
+    "            for num, conv in enumerate(self.conv_reg):\n",
+    "                conv_input = embedding_reg if num == 0 else embeddings_reg[-1]\n",
+    "                out_padded = conv(conv_input, mask, embedding_reg)\n",
+    "                embeddings_reg.append(out_padded)\n",
+    "\n",
+    "        # id input\n",
+    "        if self.learned_representation_mode == \"concat\":\n",
+    "            final_embedding_id = torch.cat(embeddings_id, axis=-1)\n",
+    "        elif self.learned_representation_mode == \"last\":\n",
+    "            final_embedding_id = torch.cat([embeddings_id[-1]], axis=-1)\n",
+    "\n",
+    "        if self.use_pre_layernorm:\n",
+    "            final_embedding_id = self.final_norm_id(final_embedding_id)\n",
+    "\n",
+    "        preds_binary_particle = self.nn_binary_particle(final_embedding_id)\n",
+    "        preds_pid = self.nn_pid(final_embedding_id)\n",
+    "\n",
+    "        # pred_charge = self.nn_charge(final_embedding_id)\n",
+    "\n",
+    "        # regression input\n",
+    "        if self.learned_representation_mode == \"concat\":\n",
+    "            final_embedding_reg = torch.cat(embeddings_reg, axis=-1)\n",
+    "        elif self.learned_representation_mode == \"last\":\n",
+    "            final_embedding_reg = torch.cat([embeddings_reg[-1]], axis=-1)\n",
+    "\n",
+    "        # if self.use_pre_layernorm:\n",
+    "        final_embedding_reg = self.final_norm_reg(final_embedding_reg)\n",
     "\n",
     "        # The PFElement feature order in X_features defined in fcc/postprocessing.py\n",
     "        preds_pt = self.nn_pt(X_features, final_embedding_reg, X_features[..., 1:2])\n",
     "        preds_eta = self.nn_eta(X_features, final_embedding_reg, X_features[..., 2:3])\n",
     "        preds_sin_phi = self.nn_sin_phi(X_features, final_embedding_reg, X_features[..., 3:4])\n",
     "        preds_cos_phi = self.nn_cos_phi(X_features, final_embedding_reg, X_features[..., 4:5])\n",
-    "        preds_energy = self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6])\n",
-    "        preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], axis=-1)\n",
     "\n",
-    "        return preds_id, preds_momentum"
+    "        # ensure created particle has positive mass^2 by computing energy from pt and adding a positive-only correction\n",
+    "        pt_real = torch.exp(preds_pt.detach()) * X_features[..., 1:2]\n",
+    "        # sinh is not supported by ONNX so use exp instead of pz_real = pt_real * torch.sinh(preds_eta.detach())\n",
+    "        detached_preds_eta = preds_eta.detach()\n",
+    "        pz_real = pt_real * (torch.exp(detached_preds_eta) - torch.exp(-detached_preds_eta)) / 2\n",
+    "\n",
+    "        e_real = torch.log(torch.sqrt(pt_real**2 + pz_real**2) / X_features[..., 5:6])\n",
+    "        # the regular torch indexing of the mask results in changed tensor shapes in the ONNX model\n",
+    "        # so we use torch.Tensor.masked_scatter_() instead of e_real[~mask] = 0\n",
+    "        e_real.masked_scatter_(~mask.unsqueeze(-1), torch.zeros(size=(0, torch.sum(~mask))))\n",
+    "        e_real[torch.isinf(e_real)] = 0\n",
+    "        e_real[torch.isnan(e_real)] = 0\n",
+    "        preds_energy = e_real + torch.nn.functional.relu(self.nn_energy(X_features, final_embedding_reg, X_features[..., 5:6]))\n",
+    "        preds_momentum = torch.cat([preds_pt, preds_eta, preds_sin_phi, preds_cos_phi, preds_energy], axis=-1)\n",
+    "        # assert list(preds_momentum.size())[-1] == 5, list(preds_momentum.size())\n",
+    "        return preds_binary_particle, preds_pid, preds_momentum\n"
   ]
  },
 {
@@ -345,26 +559,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model_simple = SimpleMLPF(\n",
-    "    input_dim=model_kwargs[\"input_dim\"],\n",
-    "    num_classes=model_kwargs[\"num_classes\"],\n",
-    "    embedding_dim=model_kwargs[\"num_heads\"]*model_kwargs[\"head_dim\"],\n",
-    "    width=model_kwargs[\"num_heads\"]*model_kwargs[\"head_dim\"],\n",
-    "    num_convs=model_kwargs[\"num_convs\"],\n",
-    "    dropout_ff=model_kwargs[\"dropout_ff\"],\n",
-    "    activation=model_kwargs[\"activation\"],\n",
-    "    layernorm=True,\n",
-    "    # element types which actually exist in the dataset\n",
-    "    elemtypes_nonzero=model_kwargs[\"elemtypes_nonzero\"],\n",
-    "    # self-attention specific parameters\n",
-    "    num_heads=model_kwargs[\"num_heads\"],\n",
-    "    head_dim=model_kwargs[\"head_dim\"],\n",
-    "    dropout_conv_reg_mha=model_kwargs[\"dropout_conv_reg_mha\"],\n",
-    "    dropout_conv_reg_ff=model_kwargs[\"dropout_conv_reg_ff\"],\n",
-    "    dropout_conv_id_mha=model_kwargs[\"dropout_conv_id_mha\"],\n",
-    "    dropout_conv_id_ff=model_kwargs[\"dropout_conv_id_ff\"],\n",
-    ")\n",
-    "model_simple.eval();"
+    "model_simple = SimpleMLPF(**model_kwargs)\n",
+    "model_simple.eval()\n",
+    "\n",
+    "#disable attention context manager (disable flash attention)\n",
+    "for conv in model_simple.conv_id + model_simple.conv_reg:\n",
+    "    conv.enable_ctx_manager = False"
   ]
  },
 {
@@ -374,11 +574,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "model_simple.load_state_dict(model_state[\"model_state_dict\"])\n",
     "\n",
     "dummy_features = torch.randn(1, 256, model_kwargs[\"input_dim\"]).float()\n",
-    "dummy_mask = torch.randn(1, 256).float()\n",
+    "dummy_mask = torch.randn(1, 256).float() # < 0.9\n",
     "\n",
     "torch.onnx.export(\n",
     "    model_simple,\n",
@@ -389,10 +588,11 @@
     "    input_names=[\n",
     "        \"Xfeat_normed\", \"mask\",\n",
     "    ],\n",
-    "    output_names=[\"id\", \"momentum\"],\n",
+    "    output_names=[\"bid\", \"id\", \"momentum\"],\n",
     "    dynamic_axes={\n",
     "        \"Xfeat_normed\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"mask\": {0: \"num_batch\", 1: \"num_elements\"},\n",
+    "        \"bid\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"id\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"momentum\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "    },\n",
@@ -461,18 +661,33 @@
     "    input_names=[\n",
     "        \"Xfeat_normed\", \"mask\",\n",
     "    ],\n",
-    "    output_names=[\"id\", \"momentum\"],\n",
+    "    output_names=[\"bid\", \"id\", \"momentum\"],\n",
     "    dynamic_axes={\n",
     "        \"Xfeat_normed\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"mask\": {0: \"num_batch\", 1: \"num_elements\"},\n",
+    "        \"bid\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"id\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "        \"momentum\": {0: \"num_batch\", 1: \"num_elements\"},\n",
     "    },\n",
-    ")\n",
-    "\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08e50424-5a61-4c9c-8f5b-6df132dd1768",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"Available ONNX runtime providers:\", rt.get_available_providers())\n",
     "sess_options = rt.SessionOptions()\n",
-    "onnx_sess_unfused = rt.InferenceSession(\"test_fp32_unfused.onnx\", sess_options, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
-    "onnx_sess_fused = rt.InferenceSession(\"test_fp32_fused.onnx\", sess_options, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])"
+    "sess_options.intra_op_num_threads = 32 # need to explicitly set this to get rid of onnxruntime error\n",
+    "\n",
+    "# removed CUDAExecutionProvider because my HPC system has a cudnn version that's too old\n",
+    "# onnx_sess_unfused = rt.InferenceSession(\"test_fp32_unfused.onnx\", sess_options, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
+    "# onnx_sess_fused = rt.InferenceSession(\"test_fp32_fused.onnx\", sess_options, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
+    "onnx_sess_unfused = rt.InferenceSession(\"test_fp32_unfused.onnx\", sess_options, providers=[\"CPUExecutionProvider\"])\n",
+    "onnx_sess_fused = rt.InferenceSession(\"test_fp32_fused.onnx\", sess_options, providers=[\"CPUExecutionProvider\"])"
   ]
  },
 {
@@ -521,7 +736,7 @@
    "source": [
     "builder = tfds.builder(dataset, data_dir=data_dir)\n",
     "ds = builder.as_data_source(split=\"test\")\n",
-    "max_events = 20\n",
+    "max_events = 10\n",
     "events_per_batch = 1\n",
     "inds = range(0, max_events, events_per_batch)\n",
     "\n",
@@ -533,10 +748,10 @@
     "model = model.to(torch_device)\n",
     "model_simple = model_simple.to(torch_device)\n",
     "\n",
-    "for ind in inds:\n",
+    "for ind in tqdm(inds):\n",
     "    ds_elems = [ds[i] for i in range(ind,ind+events_per_batch)]\n",
     "    X_features = [torch.tensor(elem[\"X\"]).to(torch.float32).to(torch_device) for elem in ds_elems]\n",
-    "    y_targets = [torch.tensor(elem[\"ygen\"]).to(torch.float32).to(torch_device) for elem in ds_elems]\n",
+    "    y_targets = [torch.tensor(elem[\"ytarget\"]).to(torch.float32).to(torch_device) for elem in ds_elems]\n",
     "\n",
     "    #batch the data into [batch_size, num_elems, num_features]\n",
     "    X_features_padded = pad_sequence(X_features, batch_first=True).contiguous()\n",
@@ -548,10 +763,10 @@
     "    with torch.no_grad():\n",
     "        print(\"running base model\")\n",
     "        pred = model(X_features_padded, mask)\n",
-    "        pred = (pred[0].cpu(), pred[1].cpu())\n",
+    "        pred = tuple(pred[x].cpu() for x in range(len(pred)))\n",
     "        print(\"running simplified model\")\n",
     "        pred_simple = model_simple(X_features_padded, mask)\n",
-    "        pred_simple = (pred_simple[0].cpu(), pred_simple[1].cpu())\n",
+    "        pred_simple = tuple(pred_simple[x].cpu() for x in range(len(pred_simple)))\n",
     "\n",
     "    j0 = particles_to_jets(pred, mask.cpu())\n",
     "    jets_mlpf.append(j0)\n",
@@ -567,7 +782,8 @@
     "    print(\"diffs: {:.8f} {:.8f}\".format(*diffs))\n",
     "\n",
     "    print(\"running ONNX unfused model\")\n",
-    "    pred_onnx_unfused = onnx_sess_unfused.run(None, {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
+    "    # pred_onnx_unfused = onnx_sess_unfused.run(None, {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
+    "    pred_onnx_unfused = onnx_sess_unfused.run([\"bid\", \"id\", \"momentum\"], {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
     "    pred_onnx_unfused = tuple(torch.tensor(p) for p in pred_onnx_unfused)\n",
     "    j2 = particles_to_jets(pred_onnx_unfused, mask.cpu())\n",
     "    jets_onnx_unfused.append(j2)\n",
@@ -577,7 +793,8 @@
     "    torch.testing.assert_close(pred[1], pred_onnx_unfused[1], atol=0.01, rtol=0.01)\n",
     "    \n",
     "    print(\"running ONNX fused model\")\n",
-    "    pred_onnx_fused = onnx_sess_fused.run(None, {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
+    "    # pred_onnx_fused = onnx_sess_fused.run(None, {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
+    "    pred_onnx_fused = onnx_sess_fused.run([\"bid\", \"id\", \"momentum\"], {\"Xfeat_normed\": X_features_padded.cpu().numpy(), \"mask\": mask_f.cpu().numpy()})\n",
     "    pred_onnx_fused = tuple(torch.tensor(p) for p in pred_onnx_fused)\n",
     "    j3 = particles_to_jets(pred_onnx_fused, mask.cpu())\n",
     "    jets_onnx_fused.append(j3)\n",
@@ -641,6 +858,33 @@
    "plt.colorbar()"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f501b02f-ca20-41a9-ae94-e6857854e5ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def remove_events_with_different_numbers_of_jets(jets1, jets2):\n",
+    "    jets1_cleaned = []\n",
+    "    jets2_cleaned = []\n",
+    "    for x, y in zip(jets1, jets2):\n",
+    "        if x.shape == y.shape:\n",
+    "            jets1_cleaned.append(x)\n",
+    "            jets2_cleaned.append(y)\n",
+    "        else:\n",
+    "            print(\"removing jets\")\n",
+    "    return jets1_cleaned, jets2_cleaned"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "id": "f1f38b1e-2f2d-490f-9959-bcd2ea7d0b49",
+   "metadata": {},
+   "source": [
+    "jets_mlpf_simple, jets_onnx_unfused = remove_events_with_different_numbers_of_jets(jets_mlpf_simple, jets_onnx_unfused)"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -660,6 +904,14 @@
    "plt.colorbar()"
   ]
  },
+  {
+   "cell_type": "raw",
+   "id": "860cbca5-9219-4d5f-bdd7-214d1b35fb86",
+   "metadata": {},
+   "source": [
+    "jets_onnx_fused, jets_onnx_unfused = remove_events_with_different_numbers_of_jets(jets_onnx_fused, jets_onnx_unfused)"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -705,7 +957,13 @@
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5a354a-ea10-4f47-9577-a6495130d5a1",
-  "metadata": {},
+  "metadata": {
+   "editable": true,
+   "slideshow": {
+    "slide_type": ""
+   },
+   "tags": []
+  },
   "outputs": [],
   "source": [
    "b = np.linspace(10,100,21)\n",
@@ -719,13 +977,21 @@
    "plt.plot(h0.axes[0].centers, (h3/h0).values(), marker=\"o\", ms=2.0, lw=1.0)\n",
    "plt.ylim(0.8,1.2)"
   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c1692425-7157-41f3-bbbe-c48c6d8114f7",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "mlpf",
    "language": "python",
-   "name": "python3"
+   "name": "mlpf"
   },
   "language_info": {
    "codemirror_mode": {
@@ -737,7 +1003,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.11.10"
   }
  },
 "nbformat": 4,
diff --git a/requirements.txt b/requirements.txt
index 6cd2373fc..613d6b837 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,6 @@ matplotlib
 mlcroissant
 mplhep
 networkx
-nevergrad
 notebook
 numba
 numpy
@@ -26,7 +25,7 @@ plotly
 pre-commit
 protobuf
 pyarrow
-ray[train,tune]
+ray[tune]
 scikit-learn
 scikit-optimize
 scipy