Spaces:
Build error
Build error
Commit
·
7275554
1
Parent(s):
b66f09d
Adjusted heatmap size and device printing
Browse files
app.py
CHANGED
|
@@ -39,12 +39,12 @@ def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=No
|
|
| 39 |
return all_single_mutants
|
| 40 |
|
| 41 |
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
|
|
|
|
|
|
|
| 42 |
piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
|
| 43 |
-
fig, ax = plt.subplots(figsize=(50,len(sequence)
|
| 44 |
scores_dict = {}
|
| 45 |
valid_mutant_set=set(scores.mutant)
|
| 46 |
-
if mutation_range_start is None: mutation_range_start=1
|
| 47 |
-
if mutation_range_end is None: mutation_range_end=len(sequence)
|
| 48 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
| 49 |
ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
|
| 50 |
if annotate:
|
|
@@ -63,7 +63,6 @@ def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_rang
|
|
| 63 |
cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
|
| 64 |
heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
|
| 65 |
heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
|
| 66 |
-
#heat.figure.axes[-1].yaxis.set_ticklabels(fontsize=fontsize)
|
| 67 |
heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
|
| 68 |
heat.set_ylabel("Sequence position", fontsize = fontsize*2)
|
| 69 |
heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
|
|
@@ -87,7 +86,6 @@ def suggest_mutations(scores):
|
|
| 87 |
positive_scores = scores[scores.avg_score > 0]
|
| 88 |
positive_scores_position_avg = positive_scores.groupby(['position']).mean()
|
| 89 |
top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
|
| 90 |
-
print(top_positions)
|
| 91 |
position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
|
| 92 |
return intro_message+mutant_recos+position_recos
|
| 93 |
|
|
@@ -115,6 +113,11 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
|
|
| 115 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
|
| 116 |
elif model_type=="Large":
|
| 117 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
model.config.tokenizer = tokenizer
|
| 119 |
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
| 120 |
scores = model.score_mutants(DMS_data=all_single_mutants,
|
|
@@ -205,10 +208,10 @@ with tranception_design:
|
|
| 205 |
)
|
| 206 |
gr.Markdown("<br>")
|
| 207 |
gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
|
| 208 |
-
|
| 209 |
#output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
|
| 210 |
#output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
|
| 211 |
-
output_image = gr.Gallery(label="Fitness predictions
|
| 212 |
|
| 213 |
output_recommendations = gr.Textbox(label="Mutation recommendations")
|
| 214 |
|
|
|
|
| 39 |
return all_single_mutants
|
| 40 |
|
| 41 |
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None,annotate=True,fontsize=20):
|
| 42 |
+
if mutation_range_start is None: mutation_range_start=1
|
| 43 |
+
if mutation_range_end is None: mutation_range_end=len(sequence)
|
| 44 |
piv=scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
|
| 45 |
+
fig, ax = plt.subplots(figsize=(min(len(sequence),50),len(sequence)))
|
| 46 |
scores_dict = {}
|
| 47 |
valid_mutant_set=set(scores.mutant)
|
|
|
|
|
|
|
| 48 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
| 49 |
ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)
|
| 50 |
if annotate:
|
|
|
|
| 63 |
cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
|
| 64 |
heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
|
| 65 |
heat.figure.axes[-1].yaxis.set_ticklabels(heat.figure.axes[-1].yaxis.get_ticklabels(), fontsize=fontsize)
|
|
|
|
| 66 |
heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
|
| 67 |
heat.set_ylabel("Sequence position", fontsize = fontsize*2)
|
| 68 |
heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
|
|
|
|
| 86 |
positive_scores = scores[scores.avg_score > 0]
|
| 87 |
positive_scores_position_avg = positive_scores.groupby(['position']).mean()
|
| 88 |
top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
|
|
|
|
| 89 |
position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
|
| 90 |
return intro_message+mutant_recos+position_recos
|
| 91 |
|
|
|
|
| 113 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium")
|
| 114 |
elif model_type=="Large":
|
| 115 |
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large")
|
| 116 |
+
if torch.cuda.is_available():
|
| 117 |
+
model.cuda()
|
| 118 |
+
print("Inference will take place on GPU")
|
| 119 |
+
else:
|
| 120 |
+
print("Inference will take place on CPU")
|
| 121 |
model.config.tokenizer = tokenizer
|
| 122 |
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
| 123 |
scores = model.score_mutants(DMS_data=all_single_mutants,
|
|
|
|
| 208 |
)
|
| 209 |
gr.Markdown("<br>")
|
| 210 |
gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
|
| 211 |
+
gr.Markdown("Inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones")
|
| 212 |
#output_plot = gr.Plot(label="Fitness predictions for all single amino acid substitutions in mutation range")
|
| 213 |
#output_image = gr.Image(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath")
|
| 214 |
+
output_image = gr.Gallery(label="Fitness predictions for all single amino acid substitutions in mutation range",type="filepath") #Using Gallery to be able to scroll large matrix images
|
| 215 |
|
| 216 |
output_recommendations = gr.Textbox(label="Mutation recommendations")
|
| 217 |
|