|
4 | 4 | import json |
5 | 5 | import os |
6 | 6 | import random |
7 | | -import sys |
8 | 7 | import urllib |
9 | 8 |
|
10 | 9 | import cv2 |
11 | 10 | import numpy as np |
12 | 11 | import requests |
13 | 12 | from PIL import Image |
14 | | -from tqdm import tqdm |
15 | 13 |
|
16 | | -from roboflow.config import API_URL, OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL |
| 14 | +from roboflow.config import OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL |
17 | 15 | from roboflow.models.inference import InferenceModel |
18 | 16 | from roboflow.util.image_utils import check_image_url |
19 | 17 | from roboflow.util.prediction import PredictionGroup |
@@ -461,56 +459,6 @@ def view(button): |
461 | 459 | else: |
462 | 460 | view(stopButton) |
463 | 461 |
|
464 | | - def download(self, format="pt", location="."): |
465 | | - """ |
466 | | - Download the weights associated with a model. |
467 | | -
|
468 | | - Args: |
469 | | - format (str): The format of the output. |
470 | | - - 'pt': returns a PyTorch weights file |
471 | | - location (str): The location to save the weights file to |
472 | | - """ |
473 | | - supported_formats = ["pt"] |
474 | | - if format not in supported_formats: |
475 | | - raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}") |
476 | | - |
477 | | - workspace, project, version = self.id.rsplit("/") |
478 | | - |
479 | | - # get pt url |
480 | | - pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile" |
481 | | - |
482 | | - r = requests.get(pt_api_url, params={"api_key": self.__api_key}) |
483 | | - |
484 | | - r.raise_for_status() |
485 | | - |
486 | | - pt_weights_url = r.json()["weightsUrl"] |
487 | | - |
488 | | - def bar_progress(current, total, width=80): |
489 | | - progress_message = ( |
490 | | - "Downloading weights to " |
491 | | - + location |
492 | | - + "/weights.pt" |
493 | | - + ": %d%% [%d / %d] bytes" % (current / total * 100, current, total) |
494 | | - ) |
495 | | - sys.stdout.write("\r" + progress_message) |
496 | | - sys.stdout.flush() |
497 | | - |
498 | | - response = requests.get(pt_weights_url, stream=True) |
499 | | - |
500 | | - # write the zip file to the desired location |
501 | | - with open(location + "/weights.pt", "wb") as f: |
502 | | - total_length = int(response.headers.get("content-length")) |
503 | | - for chunk in tqdm( |
504 | | - response.iter_content(chunk_size=1024), |
505 | | - desc=f"Downloading weights to {location}/weights.pt", |
506 | | - total=int(total_length / 1024) + 1, |
507 | | - ): |
508 | | - if chunk: |
509 | | - f.write(chunk) |
510 | | - f.flush() |
511 | | - |
512 | | - return |
513 | | - |
514 | 462 | def __exception_check(self, image_path_check=None): |
515 | 463 | # Check if Image path exists exception check |
516 | 464 | # (for both hosted URL and local image) |
|
0 commit comments