Skip to content

Commit 39e838a

Browse files
authored
Merge pull request #203 from roboflow/fix/add-fine-tuned-models
Add fine-tuned models
2 parents 1144f66 + 3eb1e43 commit 39e838a

File tree

2 files changed

+24
-20
lines changed

2 files changed

+24
-20
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from roboflow.models import CLIPModel, GazeModel
1414
from roboflow.util.general import write_line
1515

16-
__version__ = "1.1.8"
16+
__version__ = "1.1.9"
1717

1818

1919
def check_key(api_key, model, notebook, num_retries=0):

roboflow/models/inference.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def predict_video(
239239
signed_url = video_path
240240

241241
url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)
242-
243242
if model_class in ("CLIPModel", "GazeModel"):
244243
if model_class == "CLIPModel":
245244
model = "clip"
@@ -257,6 +256,14 @@ def predict_video(
257256
],
258257
}
259258
]
259+
else:
260+
models = [
261+
{
262+
"model_id": self.dataset_id,
263+
"model_version": self.version,
264+
"inference_type": self.type,
265+
}
266+
]
260267

261268
for model in additional_models:
262269
models.append(SUPPORTED_ADDITIONAL_MODELS[model])
@@ -308,28 +315,28 @@ def poll_for_video_results(self, job_id: str = None) -> dict:
308315
url = urljoin(
309316
API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + self.job_id
310317
)
311-
312318
try:
313319
response = requests.get(url, headers={"Content-Type": "application/json"})
314320
except Exception as e:
315321
raise Exception(f"Error getting video inference results: {e}")
316322

317323
if not response.ok:
318324
raise Exception(f"Error getting video inference results: {response.text}")
319-
320325
data = response.json()
326+
if "status" not in data:
327+
return {} # No status available
328+
if data.get("status") > 1:
329+
return data # Error
330+
elif data.get("status") == 1:
331+
return {} # Still running
332+
else: # done
333+
output_signed_url = data["output_signed_url"]
334+
inference_data = requests.get(
335+
output_signed_url, headers={"Content-Type": "application/json"}
336+
)
321337

322-
if data.get("status") != 0:
323-
return {}
324-
325-
output_signed_url = data["output_signed_url"]
326-
327-
inference_data = requests.get(
328-
output_signed_url, headers={"Content-Type": "application/json"}
329-
)
330-
331-
# frame_offset and model name are top-level keys
332-
return inference_data.json()
338+
# frame_offset and model name are top-level keys
339+
return inference_data.json()
333340

334341
def poll_until_video_results(self, job_id) -> dict:
335342
"""
@@ -357,14 +364,11 @@ def poll_until_video_results(self, job_id) -> dict:
357364
job_id = self.job_id
358365

359366
attempts = 0
360-
367+
print(f"Checking for video inference results for job {job_id} every 60s")
361368
while True:
369+
time.sleep(60)
362370
print(f"({attempts * 60}s): Checking for inference results")
363-
364371
response = self.poll_for_video_results()
365-
366-
time.sleep(60)
367-
368372
attempts += 1
369373

370374
if response != {}:

0 commit comments

Comments (0)