diff --git a/README.md b/README.md index f32a559a..feca7bfc 100644 --- a/README.md +++ b/README.md @@ -85,25 +85,18 @@ activity = client.get_activity(123) Streams represent the raw data of the uploaded file. Activities, efforts, and segments all have streams. There are many types of streams, if activity does -not have requested stream type, it will not be part of returned set. +not have requested stream type, returned set simply won't include it. ```python -# Activities have many streams, you can request desired stream type -streams = client.get_activity_streams(123, types=['altitude',], resolution='low') - -# Result is always an enum object -stream = next(streams) -print(stream.type) -print(stream.data) - -# You can request many stream types +# Activities can have many streams, you can request n desired stream types types = ['time', 'latlng', 'altitude', 'heartrate', 'temp', ] -for stream in client.get_activity_streams(123, types=types, resolution='medium'): - print(stream.type) - print(stream.data) +streams = client.get_activity_streams(123, types=types, resolution='medium') +# Result is a dictionary object. The dict's key are the stream type. +if 'altitude' in streams.keys(): + print(streams['altitude'].data) ``` diff --git a/stravalib/client.py b/stravalib/client.py index 2f20b4f4..1d85e193 100644 --- a/stravalib/client.py +++ b/stravalib/client.py @@ -18,6 +18,11 @@ from stravalib.util import limiter from stravalib import unithelper +try: + unicode +except: + unicode = str + class Client(object): """ Main client class for interacting with the exposed Strava v3 API methods. @@ -641,6 +646,7 @@ def get_starred_segment(self, limit=None): :rtype: :class:`stravalib.model.Segment` """ + params = {} if limit is not None: params["limit"] = limit @@ -765,12 +771,12 @@ def get_segment_efforts(self, segment_id, athlete_id=None, params['athlete_id'] = athlete_id if start_date_local: - if isinstance(start_date_local, str): + if isinstance(start_date_local, (str, unicode)): start_date_local = dateparser.parse(start_date_local, ignoretz=True) params["start_date_local"] = start_date_local.strftime("%Y-%m-%dT%H:%M:%SZ") if end_date_local: - if isinstance(end_date_local, str): + if isinstance(end_date_local, (str, unicode)): end_date_local = dateparser.parse(end_date_local, ignoretz=True) params["end_date_local"] = end_date_local.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -869,13 +875,18 @@ def get_activity_streams(self, activity_id, types=None, result_fetcher = functools.partial( self.protocol.get, - '/activities/{id}/streams/{types}'.format(id=activity_id, - types=types), - **params) + '/activities/{id}/streams/{types}'.format( + id=activity_id, + types=types), + **params) + + streams = BatchedResultsIterator(entity=model.Stream, + bind_client=self, + result_fetcher=result_fetcher) + + # Pack streams into dictionary + return {i.type : i for i in streams} - return BatchedResultsIterator(entity=model.Stream, - bind_client=self, - result_fetcher=result_fetcher) def get_effort_streams(self, effort_id, types=None, resolution=None, series_type=None): @@ -928,10 +939,12 @@ def get_effort_streams(self, effort_id, types=None, types=types), **params) - return BatchedResultsIterator(entity=model.Stream, - bind_client=self, - result_fetcher=result_fetcher) + streams = BatchedResultsIterator(entity=model.Stream, + bind_client=self, + result_fetcher=result_fetcher) + # Pack streams into dictionary + return {i.type : i for i in streams} def get_segment_streams(self, segment_id, types=None, resolution=None, series_type=None): @@ -984,9 +997,12 @@ def get_segment_streams(self, segment_id, types=None, types=types), **params) - return BatchedResultsIterator(entity=model.Stream, - bind_client=self, - result_fetcher=result_fetcher) + streams = BatchedResultsIterator(entity=model.Stream, + bind_client=self, + result_fetcher=result_fetcher) + + # Pack streams into dictionary + return {i.type : i for i in streams} diff --git a/stravalib/tests/functional/test_client.py b/stravalib/tests/functional/test_client.py index dff51377..4315f3c5 100644 --- a/stravalib/tests/functional/test_client.py +++ b/stravalib/tests/functional/test_client.py @@ -6,6 +6,17 @@ import datetime class ClientTest(FunctionalTestBase): + def test_get_starred_segment(self): + """ + Test get_starred_segment + """ + i = 0 + for segment in self.client.get_starred_segment(limit=5): + self.assertIsInstance(segment, model.Segment) + i+=1 + self.assertGreater(i, 0) # star at least one segment + self.assertLessEqual(i, 5) + def test_get_activity(self): """ Test basic activity fetching. """ @@ -88,23 +99,21 @@ def test_activity_streams(self): 'heartrate', 'cadence', 'watts', 'temp', 'moving', 'grade_smooth'] - d = {} - for stream in self.client.get_activity_streams(152668627, stypes, 'low'): - d[stream.type] = stream + streams = self.client.get_activity_streams(152668627, stypes, 'low') - self.assertGreater(d.keys(), 3) - for k in d.keys(): + self.assertGreater(len(streams.keys()), 3) + for k in streams.keys(): self.assertIn(k, stypes) # time stream - self.assertIsInstance(d['time'].data[0], int) - self.assertGreater(d['time'].original_size, 100) - self.assertEqual(d['time'].resolution, 'low') - self.assertEqual(len(d['time'].data), 100) + self.assertIsInstance(streams['time'].data[0], int) + self.assertGreater(streams['time'].original_size, 100) + self.assertEqual(streams['time'].resolution, 'low') + self.assertEqual(len(streams['time'].data), 100) # latlng stream - self.assertIsInstance(d['latlng'].data[0], list) - self.assertIsInstance(d['latlng'].data[0][0], float) + self.assertIsInstance(streams['latlng'].data, list) + self.assertIsInstance(streams['latlng'].data[0][0], float) def test_effort_streams(self): """ @@ -112,21 +121,18 @@ def test_effort_streams(self): """ stypes = ['distance'] - activity = self.client.get_activity(152668627) #165479860 + activity = self.client.get_activity(165479860) #152668627) streams = self.client.get_effort_streams(activity.segment_efforts[0].id, stypes, 'medium') - d = {} - for stream in streams: - d[stream.type] = stream - self.assertEqual(d.keys(), ['distance']) + self.assertIn('distance', streams.keys()) # distance stream - self.assertIsInstance(d['distance'].data[0], float) - self.assertEqual(d['distance'].resolution, 'medium') - self.assertEqual(len(d['distance'].data), - min(1000, d['distance'].original_size)) + self.assertIsInstance(streams['distance'].data[0], float) #xxx + self.assertEqual(streams['distance'].resolution, 'medium') + self.assertEqual(len(streams['distance'].data), + min(1000, streams['distance'].original_size)) def test_get_curr_athlete(self): @@ -186,7 +192,7 @@ def test_get_segment_leaderboard(self): for i,e in enumerate(lb): print '{0}: {1}'.format(i, e) - self.assertEquals(15, len(lb.entries)) # 10 top results, 5 bottom results + self.assertEquals(10, len(lb.entries)) # 10 top results self.assertIsInstance(lb.entries[0], model.SegmentLeaderboardEntry) self.assertEquals(1, lb.entries[0].rank) self.assertTrue(lb.effort_count > 8000) # At time of writing 8206 @@ -214,7 +220,7 @@ def test_get_segment(self): # Fetch leaderboard lb = segment.leaderboard - self.assertEquals(15, len(lb)) # 10 top results, 5 bottom results + self.assertEquals(10, len(lb)) # 10 top results, 5 bottom results def test_get_segment_efforts(self): # test with string