Skip to content

Commit

Permalink
update lm seminar
Browse files Browse the repository at this point in the history
  • Loading branch information
mannefedov committed Nov 30, 2023
1 parent d3a4bc1 commit 6e38e01
Showing 1 changed file with 132 additions and 35 deletions.
167 changes: 132 additions & 35 deletions notebooks/lm_intro/Language_model_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,16 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"from scipy.sparse import lil_matrix"
"from scipy.sparse import lil_matrix, csr_matrix, csc_matrix"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -773,12 +773,14 @@
" word1, word2 = ngram.split()\n",
" # на пересечение двух слов ставим вероятность встретить второе после первого\n",
" matrix_dvach[word2id_dvach[word1], word2id_dvach[word2]] = (bigrams_dvach[ngram]/\n",
" unigrams_dvach[word1])"
" unigrams_dvach[word1])\n",
" \n",
"# matrix_dvach = csc_matrix(matrix_dvach)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -794,16 +796,11 @@
"for ngram in bigrams_news:\n",
" word1, word2 = ngram.split()\n",
" matrix_news[word2id_news[word1], word2id_news[word2]] = (bigrams_news[ngram]/\n",
" unigrams_news[word1])"
" unigrams_news[word1])\n",
" \n",
"# matrix_news = csc_matrix(matrix_news)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -813,7 +810,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -824,12 +821,8 @@
" for i in range(n):\n",
" \n",
" chosen = np.random.choice(matrix.shape[1], p=matrix[current_idx].toarray()[0])\n",
"# просто выбирать наиболее вероятное продолжение не получится\n",
"# можете попробовать раскоментировать следующую строчку и посмотреть что получается\n",
"# в современных языковых моделях есть специальный параметр, который\n",
"# позволяет регулировать разнообразность/случайность генерации\n",
"# он называется температура, чем выше температура тем ближе будет к argmax\n",
"# чем меньше температура тем ближе к полностью рандомной генерации\n",
" # просто выбирать наиболее вероятное продолжение не получится\n",
" # можете попробовать раскоментировать следующую строчку и посмотреть что получается\n",
"# chosen = matrix[current_idx].argmax()\n",
" text.append(id2word[chosen])\n",
" \n",
Expand All @@ -842,20 +835,15 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"официально \n",
" сладкого то есть опыт работы ты потратил ебаную базу касперский ну если денег \n",
" и без смс \n",
" я представляю кто с нами так я с близкими людьми беседуюя с выходом \n",
" и унизились перед смертью и внушил себе в пользу только мс не успел смениться из-за бабы ― какие-то алгоритмы типы которые всего 2 играл в 70-м его месте как ты маг \n",
" меня в итоге получилось так всех этих успешных и навязывают стелсопрохождение там поштучно можно предположить то что пора бы не как поиграть не складируй посуду после травы так и годам \n",
" сосни ватан ты\n"
"в отправляется непотребство прочее и прочих аутистов \n",
" а утилит всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех всех\n"
]
}
],
Expand All @@ -865,25 +853,134 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 137,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"техническая помощь на процессе сыграло и интерполом по оценкам численность населения санкций общая стоимость пострадавших пока хорошее \n",
" в ее руководителей трех до того какврачи поставили диагноз гриппа \n",
" наваз шариф будет введен визовый пограничный и ряд вопросов отметил что бои по этой суммы выделенной конгрессом ваад россии владимир путин встретится с сибирскими морозами и выкуп за собой военную академию где говорится в россию поэтому на месте \n",
" один из них сказал в среду и хартум столица чечни обнаружены следы этого самолет требует единства сергей шойгу заявил что за собой в черных металлов и теллурия \n",
" это случилось землетрясение силой\n"
"в главной военной службы пожар в трех объединений минюста на тверской и именнона этом сообщил риа новости в москве в оон на след \n",
" по месту взрыва школы цигун \n",
" как отмечают что такая система ориентации в деле важным элементом комплексной системы обладающей достоинствами некоммерческих аналогов в другой экологический проект судебного иска против проводниц возбуждено уголовное дело рук может быть поставлено 20 процентов повышается производительность труда в среднем на своем заявлении президент аргентины \n",
" днемона пряталась опасаясь что в 225 евро \n",
" как утверждают в русском сегменте интернета в тот же этого можно сделать его людей среди сайтов на юге\n"
]
}
],
"source": [
"print(generate(matrix_news, id2word_news, word2id_news).replace('<end>', '\\n'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Beam Search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выше мы попробовали два способа выбирать предсказания на основе имеющихся вероятностей - 1) брать самое вероятное и 2) семплировать согласно распределению. К этому можно добавлять еще много других настроек и со многими мы еще поработаем, когда дойдем до больших языковых моделей. Сейчас давайте разберем еще один алгоритм, который можно применять для улучшения генерации. Он называется beam search (поиск лучом? лучевой поиск?).\n",
"\n",
"![](https://opennmt.net/OpenNMT/img/beam_search.png)\n",
"\n",
"Идея тут в том, чтобы на каждом шаге генерировать несколько вариантов продолжений, а затем несколько вариантов и для каждого из предыдуших продолжений. Таким образом, получается дерево генерации, где каждый вариант на следующем шаге ветвится на несколько других вариантов. Чтобы дерево не разрасталось слишком сильно в beam search есть параметр, которым задает максимальное количество вариантов на каждом шаге. Если вариантов больше, то часть из них удаляется и больше не продолжается. Чтобы отранжировать варианты, для каждого из них расчитывается общая вероятность (всей последовательности!) и выбираются самые вероятные. Обратите внимание, что на картинке на каждом из шагов не более 5 вариантов, а некоторые не доживают до последнего шага.\n",
"\n",
"При простой генерации по одному слову есть вероятность сделать неправильный выбор и закрыть возможность для других хороших продолжений. А beam search позволяет рассматривать сразу несколько вариантов и вероятность пойти не туда, значительно снижается. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Давайте напишем функцию для генерации с помощью beam search"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"# сделаем класс чтобы хранить каждый из лучей\n",
"class Beam:\n",
" def __init__(self, sequence: list, score: float):\n",
" self.sequence: list = sequence\n",
" self.score: float = score "
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def generate_with_beam_search(matrix, id2word, word2id, n=100, max_beams=5, start='<start>'):\n",
" # изначально у нас один луч с заданным началом (start по дефолту)\n",
" initial_node = Beam(sequence=[start], score=np.log1p(0))\n",
" beams = [initial_node]\n",
" \n",
" for i in range(n):\n",
" # делаем n шагов генерации\n",
" new_beams = []\n",
" # на каждом шаге продолжаем каждый из имеющихся лучей\n",
" for beam in beams:\n",
" # лучи которые уже закончены не продолжаем (но и не удаляем)\n",
" if beam.sequence[-1] == '<end>':\n",
" new_beams.append(beam)\n",
" continue\n",
" \n",
" # наша языковая модель предсказывает на основе предыдущего слова\n",
" # достанем его из beam.sequence\n",
" last_id = word2id[beam.sequence[-1]]\n",
" \n",
" # посмотрим вероятности продолжений для предыдущего слова\n",
" probas = matrix[last_id].toarray()[0]\n",
" \n",
" # возьмем топ самых вероятных продолжений\n",
" top_idxs = probas.argsort()[:-(max_beams+1):-1]\n",
" for top_id in top_idxs:\n",
" # иногда вероятности будут нулевые, такое не добавляем\n",
" if not probas[top_id]:\n",
" break\n",
" \n",
" # создадим новый луч на основе текущего и варианта продолжения\n",
" new_sequence = beam.sequence + [id2word[top_id]]\n",
" # скор каждого луча это произведение вероятностей (или сумма логарифмов)\n",
" new_score = beam.score + np.log1p(probas[top_id])\n",
" new_beam = Beam(sequence=new_sequence, score=new_score)\n",
" new_beams.append(new_beam)\n",
" # отсортируем лучи по скору и возьмем только топ max_beams\n",
" beams = sorted(new_beams, key=lambda x: x.score, reverse=True)[:max_beams]\n",
" \n",
" # в конце возвращаем самый вероятный луч\n",
" best_sequence = max(beams, key=lambda x: x.score).sequence\n",
"\n",
" \n",
" return ' '.join(best_sequence)"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"куда хинштейн предъявил вашингтону ультиматум предъявленный жителям а также отметил что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что в связи с 1 января 2000 года в том что\n"
]
}
],
"source": [
"print(generate_with_beam_search(matrix_news, id2word_news, word2id_news, start='куда'))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -1232,7 +1329,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.8.14"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 6e38e01

Please sign in to comment.