Spaces:
Build error
Build error
Update backend_utils.py
Browse files- backend_utils.py +6 -0
backend_utils.py
CHANGED
|
@@ -469,6 +469,10 @@ def make_predictions(input_query,
|
|
| 469 |
'''
|
| 470 |
library_ids, library_names = retrieve_libraries(model_retrieval, input_query, db_metadata)
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
predictions = generate_api_usage_patterns_batch(
|
| 473 |
model_generative,
|
| 474 |
tokenizer_generative,
|
|
@@ -480,6 +484,7 @@ def make_predictions(input_query,
|
|
| 480 |
config.get('max_length_generate')
|
| 481 |
)
|
| 482 |
|
|
|
|
| 483 |
hw_configs = predict_hw_config(
|
| 484 |
model_classifier,
|
| 485 |
tokenizer_classifier,
|
|
@@ -492,6 +497,7 @@ def make_predictions(input_query,
|
|
| 492 |
for output_dict, hw_config in zip(predictions, hw_configs):
|
| 493 |
output_dict['hw_config'] = hw_config
|
| 494 |
|
|
|
|
| 495 |
predictions = get_metadata_library(predictions, db_metadata)
|
| 496 |
|
| 497 |
return predictions
|
|
|
|
| 469 |
'''
|
| 470 |
library_ids, library_names = retrieve_libraries(model_retrieval, input_query, db_metadata)
|
| 471 |
|
| 472 |
+
if len(library_ids) == 0:
|
| 473 |
+
return {'status': 999}
|
| 474 |
+
|
| 475 |
+
print("generate usage patterns")
|
| 476 |
predictions = generate_api_usage_patterns_batch(
|
| 477 |
model_generative,
|
| 478 |
tokenizer_generative,
|
|
|
|
| 484 |
config.get('max_length_generate')
|
| 485 |
)
|
| 486 |
|
| 487 |
+
print("generate hw configs")
|
| 488 |
hw_configs = predict_hw_config(
|
| 489 |
model_classifier,
|
| 490 |
tokenizer_classifier,
|
|
|
|
| 497 |
for output_dict, hw_config in zip(predictions, hw_configs):
|
| 498 |
output_dict['hw_config'] = hw_config
|
| 499 |
|
| 500 |
+
print("finished the predictions")
|
| 501 |
predictions = get_metadata_library(predictions, db_metadata)
|
| 502 |
|
| 503 |
return predictions
|