import h5py
import numpy as np
from functools import partial

from utils.gen_utils import map_nlist, vround
import regex as re
from spacyface.simple_spacy_token import SimpleSpacyToken

ZERO_BUFFER = 12 # Number of digits each index key is zero-padded to
main_key = r"{:0" + str(ZERO_BUFFER) + r"}"
suppl_attn_key = r"{:0" + str(ZERO_BUFFER) + r"}_attn"
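# Example of what the key formats above produce (standard str.format behavior):
#   main_key.format(3)       -> "000000000003"
#   suppl_attn_key.format(3) -> "000000000003_attn"
# These zero-padded strings are assumed to be the group names used to index
# sentences inside the corpus hdf5 file.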
def zip_len_check(*iters):
    """Zip iterables, checking that they all have the same length"""
    if len(iters) < 2:
        raise ValueError(f"Expected at least 2 iterables to combine. Got {len(iters)} iterables")

    n = len(iters[0])
    for i in iters:
        n_ = len(i)
        if n_ != n:
            raise ValueError(f"Expected all iterables to have len {n} but found {n_}")

    return zip(*iters)
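# Illustrative usage (kept as comments so importing this module has no side effects):
#   >>> list(zip_len_check([1, 2], ["a", "b"]))
#   [(1, 'a'), (2, 'b')]
#   >>> zip_len_check([1, 2], ["a"])
#   ValueError: Expected all iterables to have len 2 but found 1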
class SentenceH5Data:
    def __init__(self, grp):
        self.grp = grp

    @property
    def n_layers(self):
        return self.embeddings.shape[0] - 1 # The first entry is the input embedding, not a hidden layer

    @property
    def sentence(self):
        return self.grp.attrs['sentence']

    @property
    def embeddings(self):
        return self.grp['embeddings'][:]

    @property
    def zero_special_embeddings(self):
        out = self.embeddings.copy()
        out[:, self.mask_is_special] = np.zeros(out[:, self.mask_is_special].shape)
        return out

    @property
    def contexts(self):
        return self.grp['contexts'][:]

    @property
    def zero_special_contexts(self):
        out = self.contexts.copy()
        out[:, self.mask_is_special] = np.zeros(out[:, self.mask_is_special].shape)
        return out

    @property
    def attentions(self):
        """Return all attentions. [CLS] and [SEP] attentions are present only if the hdf5 file
        was created with them."""
        return self.grp['attentions'][:] # Converts to numpy array

    @property
    def mask_is_special(self):
        return np.logical_or(self.deps == '', self.poss == '')

    @property
    def tokens(self):
        return self.grp.attrs['token']

    @property
    def poss(self):
        return self.grp.attrs['pos']

    @property
    def deps(self):
        return self.grp.attrs['dep']

    @property
    def is_ents(self):
        return self.grp.attrs['is_ent']

    @property
    def heads(self):
        """Not the attention heads, but rather the head word of each token in the original sentence"""
        return self.grp.attrs['head']

    @property
    def norms(self):
        return self.grp.attrs['norm']

    @property
    def tags(self):
        return self.grp.attrs['tag']

    @property
    def lemmas(self):
        return self.grp.attrs['lemma']

    def __len__(self):
        return len(self.tokens)

    def __repr__(self):
        sent_len = 40
        if len(self.sentence) > sent_len:
            s = self.sentence[:(sent_len - 3)] + '...'
        else:
            s = self.sentence
        return f"SentenceH5Data({s})"
class TokenH5Data(SentenceH5Data):
    """A wrapper around the HDF5 storage, allowing easy access to information about each
    processed sentence.

    Sometimes, an index of -1 is used to represent the entire object in memory
    """
    def __init__(self, grp, index):
        """Represents a single token, as returned from the refmap of the CorpusEmbedding class"""
        if isinstance(grp, SentenceH5Data):
            super().__init__(grp.grp)
        elif isinstance(grp, h5py.Group):
            super().__init__(grp)
        self.index = index

    @property
    def embedding(self):
        return self.embeddings[:, self.index, :]

    @property
    def context(self):
        return self.contexts[:, self.index, :]

    @property
    def attentions_out(self):
        """Access all attention OUT of this token"""
        return self.attentions[:, :, self.index, :]

    @property
    def attentions_in(self):
        """Access all attention INTO this token"""
        new_attention = self.attentions.transpose((0, 1, 3, 2))
        return new_attention[:, :, self.index, :]

    def _select_from_attention(self, layer, heads):
        if type(heads) is int:
            heads = [heads]

        # Select the requested layer, then average over the selected heads
        modified_attentions = self.attentions[layer, heads].mean(0)
        attentions_out = modified_attentions
        attentions_in = modified_attentions.transpose()
        return attentions_out, attentions_in
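    # For example, assuming self.attentions has shape (n_layers, n_heads, N, N):
    # self.attentions[2, [0, 3]] has shape (2, N, N), and .mean(0) collapses the
    # two selected heads into a single (N, N) matrix, where row i is the attention
    # OUT of token i and column i is the attention INTO token i.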
    def _calc_offset_single(self, attention):
        """Get the offset from this token to the location of max attention"""
        curr_idx = self.index
        max_atts = np.argmax(attention)
        return max_atts - curr_idx

    # Define metadata properties.
    # Right now, this needs manual curation of the fields from SimpleSpacyToken.
    # Ideally, it would be automated.
    @property
    def token(self):
        return self.tokens[self.index]

    @property
    def pos(self):
        return self.poss[self.index]

    @property
    def dep(self):
        return self.deps[self.index]

    @property
    def is_ent(self):
        return bool(self.is_ents[self.index])

    @property
    def norm(self):
        return self.norms[self.index]

    @property
    def head(self):
        return self.heads[self.index]

    @property
    def lemma(self):
        return self.lemmas[self.index]

    @property
    def tag(self):
        return self.tags[self.index]
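    # For example (names hypothetical), given tok = TokenH5Data(sent_grp, 3):
    #   tok.token, tok.pos, tok.dep, tok.is_ent -> metadata for the 4th token
    #   tok.lemma, tok.norm, tok.tag, tok.head  -> the remaining spacy-derived fields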
    def to_json(self, layer, heads, top_k=5, ndigits=4):
        """
        Convert token information and attention into a dict to return to the frontend.

        `layer` and `heads` select which attention to report; attention values are
        rounded to `ndigits`. (`top_k` is currently unused.)

        Output:
        {
            sentence: str
            index: number
            match: str
            next_index: number
            match_plus_1: str
            matched_att: {
                in:  { att: number[]
                     , offset_to_max: number
                     }
                out: { att: number[]
                     , offset_to_max: number
                     }
            },
            matched_att_plus_1: {
                in:  { att: number[]
                     , offset_to_max: number
                     }
                out: { att: number[]
                     , offset_to_max: number
                     }
            }
            tokens: List[
                { token: string
                , pos: string
                , dep: string
                , is_ent: boolean
                , inward: number[]
                , outward: number[]
                , is_match: boolean
                , is_next_word: boolean
                }
            ]
        }
        """
        keys = [
            "token",
            "pos",
            "dep",
            "is_ent",
            "inward",
            "outward",
        ]

        token_arr = []
        matched_attentions = {}
        N = len(self)

        # Per-token metadata to iterate through alongside the attentions
        tokens = self.tokens.tolist()
        poss = [p.lower() for p in self.poss.tolist()]
        deps = [d.lower() for d in self.deps.tolist()]
        ents = self.is_ents.tolist()

        attentions_out, attentions_in = self._select_from_attention(layer, heads)

        matched_att_plus_1 = None
        next_index = None

        for i, tok_info in enumerate(zip_len_check(
            tokens
            , poss
            , deps
            , ents
            , attentions_out.tolist()
            , attentions_in.tolist())):

            def get_interesting_attentions():
                return {
                    "in": {
                        "att": att_in,
                        "offset_to_max": self._calc_offset_single(att_in).item(),
                        # "loc_of_max": np.argmax(att_in), # Broken
                    },
                    "out": {
                        "att": att_out,
                        "offset_to_max": self._calc_offset_single(att_out).item(),
                        # "loc_of_max": np.argmax(att_out), # Broken
                    }
                }

            # Round the attentions
            rounder = partial(round, ndigits=ndigits)
            att_out = map_nlist(rounder, tok_info[-2])
            att_in = map_nlist(rounder, tok_info[-1])

            obj = {k: v for (k, v) in zip_len_check(keys, tok_info)}

            IS_LAST_TOKEN = i == (N - 1)

            if (i == self.index) or ((i - 1) == self.index):
                interesting_attentions = get_interesting_attentions()
                if i == self.index:
                    obj['is_match'] = True
                    matched_attentions = interesting_attentions
                elif (i - 1) == self.index:
                    matched_att_plus_1 = interesting_attentions
                    obj['is_next_word'] = True
                    next_index = i

            else:
                obj['is_match'] = False
                obj['is_next_word'] = False

                # Edge case for the final iteration through the sentence
                if IS_LAST_TOKEN and (matched_att_plus_1 is None):
                    obj['is_next_word'] = True
                    matched_att_plus_1 = get_interesting_attentions()
                    next_index = i
                    print("Saving matched_att_plus_1 to: ", matched_att_plus_1)

            token_arr.append(obj)

        next_token = self.tokens[next_index]

        obj = {
            "sentence": self.sentence,
            "index": self.index,
            "match": self.token,
            "next_index": next_index,
            "match_plus_1": next_token,
            "matched_att": matched_attentions,
            "matched_att_plus_1": matched_att_plus_1,
            "tokens": token_arr,
        }

        return obj
    def __repr__(self):
        return f"{self.token}: [{self.pos}, {self.dep}, {self.is_ent}]"