From da99cf61a94628d3f0f5a4c6886efdc682dfac9b Mon Sep 17 00:00:00 2001 From: Josh Abelman Date: Fri, 16 Aug 2019 23:27:46 -0400 Subject: [PATCH] Fixed indentation bug in sequential generation --- .../bert-babble-checkpoint.ipynb | 835 ++++++++++++++++++ bert-babble.ipynb | 2 +- 2 files changed, 836 insertions(+), 1 deletion(-) create mode 100644 .ipynb_checkpoints/bert-babble-checkpoint.ipynb diff --git a/.ipynb_checkpoints/bert-babble-checkpoint.ipynb b/.ipynb_checkpoints/bert-babble-checkpoint.ipynb new file mode 100644 index 0000000..4d666a7 --- /dev/null +++ b/.ipynb_checkpoints/bert-babble-checkpoint.ipynb @@ -0,0 +1,835 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import torch\n", + "from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 231508/231508 [00:00<00:00, 2738215.23B/s]\n" + ] + } + ], + "source": [ + "# Load pre-trained model (weights)\n", + "model_version = 'bert-base-uncased'\n", + "model = BertForMaskedLM.from_pretrained(model_version)\n", + "model.eval()\n", + "cuda = torch.cuda.is_available()\n", + "if cuda:\n", + " model = model.cuda(0)\n", + "\n", + "# Load pre-trained model tokenizer (vocabulary)\n", + "tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=model_version.endswith(\"uncased\"))\n", + "\n", + "def tokenize_batch(batch):\n", + " return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]\n", + "\n", + "def untokenize_batch(batch):\n", + " return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]\n", + "\n", + "def detokenize(sent):\n", + " \"\"\" Roughly detokenizes (mainly undoes wordpiece) \"\"\"\n", + " new_sent = []\n", + " for i, tok in enumerate(sent):\n", + " if tok.startswith(\"##\"):\n", + " new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:]\n", + " else:\n", + " new_sent.append(tok)\n", + " return new_sent\n", + "\n", + "CLS = '[CLS]'\n", + "SEP = '[SEP]'\n", + "MASK = '[MASK]'\n", + "mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]\n", + "sep_id = tokenizer.convert_tokens_to_ids([SEP])[0]\n", + "cls_id = tokenizer.convert_tokens_to_ids([CLS])[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):\n", + " \"\"\" Generate a word from from out[gen_idx]\n", + " \n", + " args:\n", + " - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size\n", + " - gen_idx (int): location for which to generate for\n", + " - top_k (int): if >0, only sample from the top k most probable words\n", + " - sample (Bool): if True, sample from full distribution. Overridden by top_k \n", + " \"\"\"\n", + " logits = out[:, gen_idx]\n", + " if temperature is not None:\n", + " logits = logits / temperature\n", + " if top_k > 0:\n", + " kth_vals, kth_idx = logits.topk(top_k, dim=-1)\n", + " dist = torch.distributions.categorical.Categorical(logits=kth_vals)\n", + " idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)\n", + " elif sample:\n", + " dist = torch.distributions.categorical.Categorical(logits=logits)\n", + " idx = dist.sample().squeeze(-1)\n", + " else:\n", + " idx = torch.argmax(logits, dim=-1)\n", + " return idx.tolist() if return_list else idx" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Generation modes as functions\n", + "import math\n", + "import time\n", + "\n", + "def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):\n", + " \"\"\" Get initial sentence by padding seed_text with either masks or random words to max_len \"\"\"\n", + " batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]\n", + " #if rand_init:\n", + " # for ii in range(max_len):\n", + " # init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))\n", + " \n", + " return tokenize_batch(batch)\n", + "\n", + "def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,\n", + " cuda=False, print_every=10, verbose=True):\n", + " \"\"\" Generate for one random position at a timestep\n", + " \n", + " args:\n", + " - burnin: during burn-in period, sample from full distribution; afterwards take argmax\n", + " \"\"\"\n", + " seed_len = len(seed_text)\n", + " batch = get_init_text(seed_text, max_len, batch_size)\n", + " \n", + " for ii in range(max_iter):\n", + " kk = np.random.randint(0, max_len)\n", + " for jj in range(batch_size):\n", + " batch[jj][seed_len+kk] = mask_id\n", + " inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)\n", + " out = model(inp)\n", + " topk = top_k if (ii >= burnin) else 0\n", + " idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))\n", + " for jj in range(batch_size):\n", + " batch[jj][seed_len+kk] = idxs[jj]\n", + " \n", + " if verbose and np.mod(ii+1, print_every) == 0:\n", + " for_print = tokenizer.convert_ids_to_tokens(batch[0])\n", + " for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]\n", + " print(\"iter\", ii+1, \" \".join(for_print))\n", + " \n", + " return untokenize_batch(batch)\n", + "\n", + "def parallel_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True, \n", + " cuda=False, print_every=10, verbose=True):\n", + " \"\"\" Generate for all positions at a time step \"\"\"\n", + " seed_len = len(seed_text)\n", + " batch = get_init_text(seed_text, max_len, batch_size)\n", + " \n", + " for ii in range(max_iter):\n", + " inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)\n", + " out = model(inp)\n", + " for kk in range(max_len):\n", + " idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, temperature=temperature, sample=sample)\n", + " for jj in range(batch_size):\n", + " batch[jj][seed_len+kk] = idxs[jj]\n", + " \n", + " if verbose and np.mod(ii, print_every) == 0:\n", + " print(\"iter\", ii+1, \" \".join(tokenizer.convert_ids_to_tokens(batch[0])))\n", + " \n", + " return untokenize_batch(batch)\n", + " \n", + "def sequential_generation(seed_text, batch_size=2, max_len=15, leed_out_len=15, \n", + " top_k=0, temperature=None, sample=True, cuda=False):\n", + " \"\"\" Generate one word at a time, in L->R order \"\"\"\n", + " seed_len = len(seed_text)\n", + " batch = get_init_text(seed_text, max_len, batch_size)\n", + " batch = batch.cuda() if cuda else batch\n", + " \n", + " for ii in range(max_len):\n", + " inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]\n", + " inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)\n", + " out = model(inp)\n", + " idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, temperature=temperature, sample=sample)\n", + " for jj in range(batch_size):\n", + " batch[jj][seed_len+ii] = idxs[jj]\n", + " \n", + " return untokenize_batch(batch)\n", + "\n", + "\n", + "def generate(n_samples, seed_text=\"[CLS]\", batch_size=10, max_len=25, \n", + " sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,\n", + " cuda=False, print_every=1):\n", + " # main generation function to call\n", + " sentences = []\n", + " n_batches = math.ceil(n_samples / batch_size)\n", + " start_time = time.time()\n", + " for batch_n in range(n_batches):\n", + " batch = parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k,\n", + " temperature=temperature, burnin=burnin, max_iter=max_iter, \n", + " cuda=cuda, verbose=False)\n", + " \n", + " #batch = sequential_generation(seed_text, batch_size=20, max_len=max_len, top_k=top_k, temperature=temperature, leed_out_len=leed_out_len, sample=sample)\n", + " #batch = parallel_generation(seed_text, max_len=max_len, top_k=top_k, temperature=temperature, sample=sample, max_iter=max_iter)\n", + " \n", + " if (batch_n + 1) % print_every == 0:\n", + " print(\"Finished batch %d in %.3fs\" % (batch_n + 1, time.time() - start_time))\n", + " start_time = time.time()\n", + " \n", + " sentences += batch\n", + " return sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Utility functions\n", + "\n", + "def printer(sent, should_detokenize=True):\n", + " if should_detokenize:\n", + " sent = detokenize(sent)[1:-1]\n", + " print(\" \".join(sent))\n", + " \n", + "def read_sents(in_file, should_detokenize=False):\n", + " sents = [sent.strip().split() for sent in open(in_file).readlines()]\n", + " if should_detokenize:\n", + " sents = [detokenize(sent) for sent in sents]\n", + " return sents\n", + "\n", + "def write_sents(out_file, sents, should_detokenize=False):\n", + " with open(out_file, \"w\") as out_fh:\n", + " for sent in sents:\n", + " sent = detokenize(sent[1:-1]) if should_detokenize else sent\n", + " out_fh.write(\"%s\\n\" % \" \".join(sent))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Finished batch 1 in 47.214s\n", + "Finished batch 2 in 47.047s\n", + "Finished batch 3 in 47.064s\n", + "Finished batch 4 in 47.087s\n", + "Finished batch 5 in 47.094s\n", + "Finished batch 6 in 47.096s\n", + "Finished batch 7 in 47.094s\n", + "Finished batch 8 in 47.097s\n", + "Finished batch 9 in 47.093s\n", + "Finished batch 10 in 47.090s\n", + "Finished batch 11 in 47.087s\n", + "Finished batch 12 in 47.070s\n", + "Finished batch 13 in 47.070s\n", + "Finished batch 14 in 46.996s\n", + "Finished batch 15 in 46.999s\n", + "Finished batch 16 in 46.999s\n", + "Finished batch 17 in 46.994s\n", + "Finished batch 18 in 46.991s\n", + "Finished batch 19 in 46.978s\n", + "Finished batch 20 in 46.978s\n" + ] + } + ], + "source": [ + "n_samples = 1000\n", + "batch_size = 50\n", + "max_len = 40\n", + "top_k = 100\n", + "temperature = 0.7\n", + "\n", + "leed_out_len = 5 # max_len\n", + "burnin = 250\n", + "sample = True\n", + "max_iter = 500\n", + "\n", + "# Choose the prefix context\n", + "seed_text = \"[CLS]\".split()\n", + "\n", + "for temp in [1.0]:\n", + " bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,\n", + " sample=sample, top_k=top_k, temperature=temp, burnin=burnin, max_iter=max_iter,\n", + " cuda=True)\n", + " out_file = \"data/%s-len%d-burnin%d-topk%d-temp%.3f.txt\" % (model_version, max_len, burnin, top_k, temp)\n", + " write_sents(out_file, bert_sents, should_detokenize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "in_file = \"data/generations-len20-burnin200-temp0.700.txt\"\n", + "bert_sents = read_sents(in_file, should_detokenize=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "in 2006 , the army announced plans to overhaul its schedule , notably reducing the march - october period , and shifting away from summer and winter exercises , which were annual , and the february - april period .\n", + "fuck , fuck , i kept it in that - \" fuck , fuck , fuck , fuck , fuck . \" * * * home , i mean , the house was a large part of my life .\n", + "my \" father \" ( matthew 15 : 17 - 26 ) ' s ornithological name ( \" e \" or \" a \" ) ; my \" mother \" , or my \" sister \" ;\n", + "london : macmillan . 1 june . thompson , thomas ( rev . 2003 ) . toronto ccp . volume 1 : the municipal elections in greater toronto , monday , 5 november 1962 . official party site .\n", + "b : b ( a + 1 ) = 3 and b : b ( a ) = = 2 c : a + ab c = 3 and are the values of all these variables defined as and .\n", + "a : barcelona , 1978 . el gimpy ( poet ) a : barcelona , 1982 . his books published posthumously include : la casa de historia catalan ( house of catalan history ) , impr .\n", + "this department offers graduates of \" marketing and commerce \" and communication and law \" ( communication and law department was created ) . english - language specialization was started after three years of undergraduate studies in this department .\n", + "captains : thomas potter ( 4 ) ( ws ) , john johnson , ( ws ) , ( fn ) status : decommissioned 6 / 02 / 2039 ( taken out of service ) ;\n", + "20 april 2007 . photographs , paintings , drawing , and sculpture . the springfield gallery of art . 264 pages . - ( 1999 ) . take it home , a celebration . the springfield gallery of art .\n", + "( also available as a free download ) bmx ( commodore 94 / 95 and later ) : discontinued from version 0 . 2 onwards , ( since 1 . 1 ) with additional support for operating systems .\n", + "[UNK] , [UNK] 。 」 [UNK] 。 ; [UNK] , [UNK] , [UNK] , [UNK] , , [UNK] [ 1 [UNK] ] ; [UNK] [ 1 + 1 + 1 + 1 + 1 + 1 = 1 ] ;\n", + "2 - h hospital - \" kuringai \" ( 2 - h ) acts as a collective refuge in extreme trauma areas , and provides health care and family support services for crime victims and accident victims .\n", + "porter , m . \" benjamin porter , deceased . \" porter , m . \" benjamin robert porter \" . \" america ' s patriot , benjamin robert porter \" . porter was the first land claims advocate .\n", + "( this belongs to you , not beyond you . ) faith in god and the all - knowing world guides his conduct throughout the year . every year , master garg visits the jama masjid again .\n", + "as long as his [ kids ] and his ' bopping around and comin ' in and jesus , jesus , jesus , this is him ( jesus , jesus ) ... and his [ kids ] .\n", + "20 . 5 % discounting factor ( top value : 20 . 5 % ) during the regression plot , a subset of associated parameters are placed . the highest associated parameter ( top value ) : - .\n", + "ex . 56 , 7 hob . 77 : 30 - 31 [ 1 cor . : ps . 37 hob . 14 : 38 ] ; pygm . 31 : 6 acad .\n", + "non - national scouts are represented internationally by najd which is comprises a commissioner , a scouter , a deputy commissioner , deputy chief scouter and two nurses . scout country is 100 % canadian .\n", + "the city , at that time , was looting . for this reason , a group of soldiers were stationed outside the royal palace where the negotiations were being held . many of the public houses had been looted .\n", + "\" the spirit of her father jesus christ has gone into jesus ' body , taking over his individual soul . \" but he is still able to , physically and spiritually , live in his own body .\n", + "the fortress units of the bavarian army served on the eastern front before being disbanded along with other formations , while the corps administered the garrison artillery which were subordinate to the army , navy , and air force .\n", + "all at once , in fact , are you realizing that everything is going to be as it was yesterday . the world is made up of where something can be on its own and where it must be .\n", + "morris , j . - ann ( \" margie \" morris ) ; alnwick , l . noblewomen and gentlemen , medieval , anglo - saxon , elizabethan ( pub . 1981 ) .\n", + "they started with a 7th place finish after 22 games and were relegated 6 may 2013 , however they had a poor season back despite having a good young squad which made them end up with a 8th place .\n", + "by composting the human waste ; by adding the seeds of sugarcane and that of water ( part 1 ) ; by adding the oil of sugarcane , halibut , and all their necessary uses ;\n", + "the \" family \" was continuously killing . as the \" family \" was fighting the kabaka , \" the family \" continued killing . despite it being almost all dead , the casualties were still high .\n", + "there is either an out - count in left field , if one is not on third , or a single on third base . the count is immediately changed so that the batter now has the ball back .\n", + "on this occasion , the act would now expressly state that \" g e f \" should not be a prefix for standard sign - language , instead being a substitute for \" s n s p \" .\n", + "it eats insects and other terrestrial animals ( fruit e . g . ) , and is an uncommon nocturnal species in hawaii . the small heads and back of adults are dark , lighter above and darker below .\n", + "sin yz ( y ) or sin iz ( a ) ( spanish : x : ' x ' ; y : ' es ' ; a : ' i / a ' ) are long vowels .\n", + "why not have me killed ? maybe they kept the children in a small , deserted building , so close to the plantation ' s main road that it had to be cleaned out . maybe you are right .\n", + "both new writers j . t . kenny first published in the irish press and john yeats who later wrote his novel the madwolf were both first published in the quarterly portion of the irish press .\n", + "you are the ones who are afraid ! \" come , come , come , out of sight ! - down , down , down ! meanwhile , people are coming down from the hill , from the sea .\n", + "the song ends with the following line : \" this night ' s wonderful stuff ... wonderfully good stuff ! \" stand - up comedian eugene wells performs a highly emotional sketch that revolves around his own life .\n", + "the \" report \" under article 8 and article 13 is a \" summary accompanying the main article \" or first part of a claim ; from 8 the \" case report \" and article 13 \" publication \" .\n", + "mr . fraker , william newton whitney , president of heyman ; george g . whitney , president , the west coal company ; mr . william whitney , president of the washington coal company nc . ;\n", + "in : g . m . mineson , mineralogical series , british museum , london . p 171 . kaplinsky , mineralogy part i and part ii , 22 - 24 , p22 .\n", + "indianapolis : william morgan publishing . edited by tom c . o ' sullivan . published in black and white , black ink , and photographs . volume 1 of indiana historical magazines . each volume has supplemental articles .\n", + "cambridge , ma , 1980 critical education : global perspectives . cambridge , ma , 1979 inside the new order a century of women : women using the new social workers ' profession : lessons from the early years .\n", + "retired through 2016 . listed from list of career milestones prior to the 1996 season , his hitting line had been to allow only 5 hits ( 1 , 2 , 3 , 3 ) per base hit .\n", + "on the mediterranean coasts of france and the british isles , she is usually recorded only as the lady living in the palace of the king . the names of her firstborn son are likely ralph and edward .\n", + "cox , anne marie ; fisherman , john ; foster , david p . \" the care and treatment of the elderly \" . medicine 12 . 2 : 53 - 69 . fisherman , john m ;\n", + "place the glock on the table . \" he did , and i looked around at the kids ( most are in class and none are doing their homework right now ) , and said , \" michael !\n", + "the last straw is a drama film by director joshua bell about sambo singer janelle roldan and friends working together , which according to bell , fell apart after \" a couple of decades \" .\n", + "algebra iii - algebra v algebra ii analysis , statistics , and higher statistics analysis , applied to statistics , and higher statistics symbols for terms identified with multiplicity are : ( 0 , 1 , ... ) .\n", + "set and , and and . for the first row , and , and and . for the second row , and . set and , and , and , and , and , and , and , and .\n", + "§ 19 . 8 . 6 of the 1957 indian relations act . § 13b . , 13c . and 13 . 8 . of the act of 1933 . section 3 ( preliminary investigatory ) .\n", + "( 8b . ) e a , e b , e c , e e d , e , e e , e e ; ( 10c , 8d . ) e l , e f ;\n", + "edited by f . w . clarke ( dublin school ) : : clarke , f . ( dublin school , 1921 ) . edited by jones : : jones , j . ( dublin school , 1933 ) .\n", + "41b : \" in thy neighbor , come together in love and grant thee comfort \" farrell , jl . w . \" what is good and what is evil ? \" , modern review , vol .\n" + ] + } + ], + "source": [ + "for i in range(50):\n", + " printer(sents[i], should_detokenize=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from nltk.translate import bleu_score as bleu" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Quality Measures" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How similar are the generated sentences to the original training data (Toronto Book Corpus and Wikipedia dumps). We follow Yu et al., (2017) and compute the BLEU between the generations and the test sets of both corpora by treating the test set as the references for each generation. The tests sets are large; we subsample 5000 examples from each." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(data_file, replacements={}, uncased=True):\n", + " data = [d.strip().split() for d in open(data_file, 'r').readlines()]\n", + " if uncased:\n", + " data = [[t.lower() for t in sent] for sent in data]\n", + " \n", + " for k, v in replacements.items():\n", + " data = [[t if t != k else v for t in sent] for sent in data]\n", + " \n", + " return data\n", + "\n", + "def prepare_wiki(data_file, uncased=True):\n", + " replacements = {\"@@unknown@@\": \"[UNK]\"}\n", + " return prepare_data(data_file, replacements=replacements, uncased=uncased)\n", + "\n", + "def prepare_tbc(data_file): \n", + " replacements = {\"``\": \"\\\"\", \"\\'\\'\": \"\\\"\"}\n", + " return prepare_data(data_file, replacements=replacements)\n", + "\n", + "def corpus_bleu(generated, references):\n", + " \"\"\" Compute similarity between two corpora as measured by\n", + " comparing each sentence of `generated` against all sentences in `references` \n", + " \n", + " args:\n", + " - generated (List[List[str]]): list of sentences (split into tokens)\n", + " - references (List[List[str]]): list of sentences (split into tokens)\n", + " \n", + " returns:\n", + " - bleu (float)\n", + " \"\"\" \n", + " return bleu.corpus_bleu([references for _ in range(len(generated))], generated)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "wiki103_file = 'data/wiki103.5k.txt'\n", + "tbc_file = 'data/tbc.5k.txt'\n", + "\n", + "wiki_data = prepare_wiki(wiki103_file)\n", + "tbc_data = prepare_tbc(tbc_file)\n", + "#sents = [detokenize(sent) for sent in sents]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BERT-TBC BLEU: 7.06\n", + "BERT-Wiki103 BLEU: 7.80\n", + "BERT-{TBC + Wiki103} BLEU: 8.27\n" + ] + } + ], + "source": [ + "print(\"BERT-TBC BLEU: %.2f\" % (100 * corpus_bleu(bert_sents, tbc_data)))\n", + "print(\"BERT-Wiki103 BLEU: %.2f\" % (100 * corpus_bleu(bert_sents, wiki_data)))\n", + "print(\"BERT-{TBC + Wiki103} BLEU: %.2f\" % (100 * corpus_bleu(bert_sents, tbc_data[:2500] + wiki_data[:2500])))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing to existing models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The OpenAI Generative Pretraining Transformer is another pretrained model successfully used for transfer learning. Since the model is a unidirectional language model, we can straightforwardly generate from the model. See [this repo](https://github.com/huggingface/pytorch-openai-transformer-lm) by Thomas Wolf at Huggingface for instructions for setting up the model." + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.insert(1, os.path.join(\".\", \"pytorch-openai-transformer-lm\"))\n", + "\n", + "from model_pytorch import LMModel, load_openai_pretrained_model, DEFAULT_CONFIG\n", + "from text_utils import TextEncoder\n", + "\n", + "def load_openai_gpt(n_special=1, n_ctx=512):\n", + " text_encoder = TextEncoder(\"pytorch-openai-transformer-lm/model/encoder_bpe_40000.json\", \n", + " \"pytorch-openai-transformer-lm/model/vocab_40000.bpe\")\n", + " encoder = text_encoder.encoder\n", + " n_vocab = len(text_encoder.encoder)\n", + " vocab = n_vocab + n_special + n_ctx\n", + "\n", + " args = DEFAULT_CONFIG\n", + " lm_model = LMModel(args, vocab, n_ctx, return_probs=True)\n", + " load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special,\n", + " path=\"pytorch-openai-transformer-lm/model/\",\n", + " path_names=\"pytorch-openai-transformer-lm/\")\n", + " #lm_model.to(device)\n", + " lm_model.return_probs = False\n", + " lm_model.eval()\n", + " return lm_model, text_encoder\n", + "\n", + "def make_batch(X, n_vocab, n_special, batch_size):\n", + " X = np.array(X)\n", + " assert X.ndim in [1, 2]\n", + " if X.ndim == 1:\n", + " X = np.expand_dims(X, axis=0)\n", + " pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])\n", + " pos_enc = np.tile(pos_enc, (batch_size, pos_enc.shape[-1])) #np.expand_dims(pos_enc, axis=0)\n", + " batch = np.stack([X, pos_enc], axis=-1)\n", + " batch = torch.tensor(batch, dtype=torch.long)#.to(device)\n", + " return batch\n", + "\n", + "def append_batch(X, next_idx):\n", + " next_pos = X[:, -1:, 1] + 1\n", + " next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)\n", + " return torch.cat((X, next_x), 1)\n", + "\n", + "def _generate_sentence_openai(model, text_encoder, seed_text, batch_size=10, gen_len=20, \n", + " topk=100, sample=True, n_special=0):\n", + " n_vocab = len(text_encoder.encoder)\n", + " #X = np.random.randint(n_vocab, size=(batch_size, 1)).tolist()\n", + " #sents = [[text_encoder.decoder[X[i][0]]].replace('', '') for i in range(batch_size)]\n", + " X = [[n_vocab - 1] for _ in range(batch_size)]\n", + " sents = [[] for _ in range(batch_size)]\n", + " if seed_text:\n", + " seed_ids = text_encoder.encode([seed_text,])\n", + " X = [X[i] + seed_ids[0] for i in range(batch_size)]\n", + " sents = [[seed_text] for _ in range(batch_size)]\n", + " XMB = make_batch(X, n_vocab, n_special, batch_size=batch_size)\n", + "\n", + "\n", + " for step_n in range(gen_len):\n", + " out = model(XMB) + model.pos_emb_mask\n", + " next_idxs = generate_step(out, gen_idx=step_n, top_k=topk, sample=sample, return_list=False)\n", + " idxs = next_idxs.tolist()\n", + " for i in range(batch_size):\n", + " next_token = idxs[i]\n", + " if next_token == n_vocab:\n", + " next_token = \"\"\n", + " else:\n", + " next_token = text_encoder.decoder[next_token].replace('', '')\n", + " sents[i].append(next_token)\n", + " XMB = append_batch(XMB, next_idxs.unsqueeze(-1))\n", + " \n", + " return [[tok for tok in sent if tok != '\\n'] for sent in sents]\n", + "\n", + "def generate_openai(model, text_encoder, n_samples, seed_text, \n", + " batch_size=10, gen_len=20, \n", + " topk=100, temperature=temperature, sample=sample,\n", + " n_special=0, print_every=1):\n", + " sents = []\n", + " start_time = time.time()\n", + " n_batches = math.ceil(n_samples / batch_size)\n", + " for batch_n in range(n_batches):\n", + " batch_sents = _generate_sentence_openai(model, text_encoder, seed_text,\n", + " batch_size=batch_size, gen_len=gen_len, \n", + " topk=topk, sample=sample,\n", + " n_special=n_special)\n", + " sents += batch_sents\n", + " if (batch_n + 1) % print_every == 0:\n", + " print(\"Generated batch %d of %d in %.3fs\" % (batch_n + 1, n_batches, time.time() - start_time))\n", + " start_time = time.time()\n", + " return sents" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading weights...\n" + ] + } + ], + "source": [ + "gpt_model, gpt_text_encoder = load_openai_gpt(n_special=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated batch 1 of 20 in 137.174s\n", + "Generated batch 2 of 20 in 136.927s\n", + "Generated batch 3 of 20 in 137.247s\n", + "Generated batch 4 of 20 in 137.561s\n", + "Generated batch 5 of 20 in 137.296s\n", + "Generated batch 6 of 20 in 136.965s\n", + "Generated batch 7 of 20 in 136.721s\n", + "Generated batch 8 of 20 in 144.196s\n", + "Generated batch 9 of 20 in 182.620s\n", + "Generated batch 10 of 20 in 136.998s\n", + "Generated batch 11 of 20 in 133.163s\n", + "Generated batch 12 of 20 in 132.677s\n", + "Generated batch 13 of 20 in 133.620s\n", + "Generated batch 14 of 20 in 133.342s\n", + "Generated batch 15 of 20 in 132.501s\n", + "Generated batch 16 of 20 in 132.636s\n", + "Generated batch 17 of 20 in 133.013s\n", + "Generated batch 18 of 20 in 132.587s\n", + "Generated batch 19 of 20 in 133.053s\n", + "Generated batch 20 of 20 in 132.624s\n" + ] + } + ], + "source": [ + "n_samples = 1000\n", + "batch_size = 50\n", + "max_len = 40\n", + "top_k = 100\n", + "temperature = 0.7\n", + "\n", + "leed_out_len = 5 # max_len\n", + "burnin = 250\n", + "sample = True\n", + "max_iter = 500\n", + "\n", + "openai_sents = generate_openai(gpt_model, gpt_text_encoder, seed_text=\"\", \n", + " n_samples=n_samples, batch_size=batch_size, gen_len=max_len,\n", + " topk=top_k, temperature=temperature, sample=sample,\n", + " n_special=1, print_every=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\" you ca n't tell me you do n't like me ! \" \" of course i like you . \" \" then why are you trying to run away with me ? \" she asked .\n" + ] + } + ], + "source": [ + "printer(openai_sents[9], should_detokenize=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "print(\"GPT-TBC BLEU: %.2f\" % (100 * corpus_bleu(openai_sents, tbc_data)))\n", + "print(\"GPT-Wiki103 BLEU: %.2f\" % (100 * corpus_bleu(openai_sents, wiki_data)))\n", + "print(\"GPT-{TBC + Wiki103} BLEU: %.2f\" % (100 * corpus_bleu(openai_sents, tbc_data[:2500] + wiki_data[:2500])))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Diversity Measures" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Self-BLEU: treat each sentence as a hypothesis and treat rest of corpus as reference. Lower is better." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import Counter\n", + "from nltk.util import ngrams\n", + "\n", + "def self_bleu(sents):\n", + " return bleu.corpus_bleu([[s for (j, s) in enumerate(sents) if j != i] for i in range(len(sents))], sents)\n", + "\n", + "def get_ngram_counts(sents, max_n=4):\n", + " size2count = {}\n", + " for i in range(1, max_n + 1):\n", + " size2count[i] = Counter([n for sent in sents for n in ngrams(sent, i)])\n", + " return size2count\n", + "\n", + "def ref_unique_ngrams(preds, refs, max_n=4):\n", + " # get # of *distinct* pred ngrams that don't appear in ref\n", + " pct_unique = {}\n", + " pred_ngrams = get_ngram_counts(preds, max_n)\n", + " ref_ngrams = get_ngram_counts(refs, max_n)\n", + " for i in range(1, max_n + 1):\n", + " pred_ngram_counts = set(pred_ngrams[i].keys())\n", + " total = sum(pred_ngrams[i].values())\n", + " ref_ngram_counts = set(ref_ngrams[i].keys())\n", + " pct_unique[i] = len(pred_ngram_counts.difference(ref_ngram_counts)) / total\n", + " return pct_unique\n", + " \n", + "def self_unique_ngrams(preds, max_n=4):\n", + " # get # of pred ngrams with count 1\n", + " pct_unique = {}\n", + " pred_ngrams = get_ngram_counts(preds, max_n)\n", + " for i in range(1, max_n + 1):\n", + " n_unique = len([k for k, v in pred_ngrams[i].items() if v == 1])\n", + " total = sum(pred_ngrams[i].values())\n", + " pct_unique[i] = n_unique / total\n", + " return pct_unique" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BERT self-BLEU: 10.06\n" + ] + } + ], + "source": [ + "print(\"BERT self-BLEU: %.2f\" % (100 * self_bleu(bert_sents)))\n", + "print(\"OpenAI self-BLEU: %.2f\" % (100 * self_bleu(openai_sents)))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/beegfs/aw3272/software/miniconda3/envs/mtl-sent-rep/lib/python3.6/site-packages/ipykernel_launcher.py:10: DeprecationWarning: generator 'ngrams' raised StopIteration\n", + " # Remove the CWD from sys.path while we load stuff.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unique 1-grams relative to Wiki: 8.01\n", + "Unique 2-grams relative to Wiki: 57.90\n", + "Unique 3-grams relative to Wiki: 91.72\n", + "Unique 4-grams relative to Wiki: 98.55\n", + "Unique 1-grams relative to TBC: 10.55\n", + "Unique 2-grams relative to TBC: 60.94\n", + "Unique 3-grams relative to TBC: 92.04\n", + "Unique 4-grams relative to TBC: 98.56\n" + ] + } + ], + "source": [ + "max_n = 4\n", + "\n", + "pct_uniques = ref_unique_ngrams(bert_sents, wiki_data, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"BERT unique %d-grams relative to Wiki: %.2f\" % (i, 100 * pct_uniques[i]))\n", + "pct_uniques = ref_unique_ngrams(bert_sents, tbc_data, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"BERT unique %d-grams relative to TBC: %.2f\" % (i, 100 * pct_uniques[i]))\n", + "pct_uniques = self_unique_ngrams(bert_sents, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"BERT unique %d-grams relative to self: %.2f\" % (i, 100 * pct_uniques[i]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pct_uniques = ref_unique_ngrams(openai_sents, wiki_data, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"GPT unique %d-grams relative to Wiki: %.2f\" % (i, 100 * pct_uniques[i]))\n", + "pct_uniques = ref_unique_ngrams(openai_sents, tbc_data, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"GPT unique %d-grams relative to TBC: %.2f\" % (i, 100 * pct_uniques[i]))\n", + "pct_uniques = self_unique_ngrams(openai_sents, max_n)\n", + "for i in range(1, max_n + 1):\n", + " print(\"GPT unique %d-grams relative to self: %.2f\" % (i, 100 * pct_uniques[i]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:bert]", + "language": "python", + "name": "conda-env-bert-py" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/bert-babble.ipynb b/bert-babble.ipynb index 7e37f45..4d666a7 100644 --- a/bert-babble.ipynb +++ b/bert-babble.ipynb @@ -190,7 +190,7 @@ " for jj in range(batch_size):\n", " batch[jj][seed_len+ii] = idxs[jj]\n", " \n", - " return untokenize_batch(batch)\n", + " return untokenize_batch(batch)\n", "\n", "\n", "def generate(n_samples, seed_text=\"[CLS]\", batch_size=10, max_len=25, \n",