I am trying to run the `predict` function locally (`TabPFNMode.LOCAL`) on my data, which is in the expected format, but I receive the following error:

```
Cell In[86], line 2
1 predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.LOCAL)
----> 2 pred = predictor.predict(train_tsdf, test_tsdf)
File ~/SageMaker/env/lib/python3.10/site-packages/tabpfn_time_series/predictor.py:48, in TabPFNTimeSeriesPredictor.predict(self, train_tsdf, test_tsdf)
40 """
41 Predict on each time series individually (local forecasting).
42 """
44 logger.info(
45 f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.config}"
46 )
---> 48 return self.tabpfn_worker.predict(train_tsdf, test_tsdf)
File ~/SageMaker/env/lib/python3.10/site-packages/tabpfn_time_series/tabpfn_worker.py:172, in LocalTabPFN.predict(self, train_tsdf, test_tsdf)
166 item_ids_chunks = np.array_split(
167 np.random.permutation(train_tsdf.item_ids),
168 min(total_num_workers, len(train_tsdf.item_ids)),
169 )
171 # Run predictions in parallel
--> 172 predictions = Parallel(n_jobs=len(item_ids_chunks), backend="loky")(
173 delayed(self._prediction_routine_per_gpu)(
174 train_tsdf.loc[chunk],
175 test_tsdf.loc[chunk],
176 gpu_id=i
177 % torch.cuda.device_count(), # Alternate between available GPUs
178 )
179 for i, chunk in enumerate(item_ids_chunks)
180 )
182 predictions = pd.concat(predictions)
184 # Sort predictions according to original item_ids order
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:2007, in Parallel.__call__(self, iterable)
2001 # The first item from the output is blank, but it makes the interpreter
2002 # progress until it enters the Try/Except block of the generator and
2003 # reaches the first `yield` statement. This starts the asynchronous
2004 # dispatch of the tasks to the workers.
2005 next(output)
-> 2007 return output if self.return_generator else list(output)
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:1650, in Parallel._get_outputs(self, iterator, pre_dispatch)
1647 yield
1649 with self._backend.retrieval_context():
-> 1650 yield from self._retrieve()
1652 except GeneratorExit:
1653 # The generator has been garbage collected before being fully
1654 # consumed. This aborts the remaining tasks if possible and warn
1655 # the user if necessary.
1656 self._exception = True
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:1754, in Parallel._retrieve(self)
1747 while self._wait_retrieval():
1748
1749 # If the callback thread of a worker has signaled that its task
1750 # triggered an exception, or if the retrieval loop has raised an
1751 # exception (e.g. `GeneratorExit`), exit the loop and surface the
1752 # worker traceback.
1753 if self._aborting:
-> 1754 self._raise_error_fast()
1755 break
1757 # If the next job is not ready for retrieval yet, we just wait for
1758 # async callbacks to progress.
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:1789, in Parallel._raise_error_fast(self)
1785 # If this error job exists, immediately raise the error by
1786 # calling get_result. This job might not exists if abort has been
1787 # called directly or if the generator is gc'ed.
1788 if error_job is not None:
-> 1789 error_job.get_result(self.timeout)
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:745, in BatchCompletionCallBack.get_result(self, timeout)
739 backend = self.parallel._backend
741 if backend.supports_retrieve_callback:
742 # We assume that the result has already been retrieved by the
743 # callback thread, and is stored internally. It's just waiting to
744 # be returned.
--> 745 return self._return_or_raise()
747 # For other backends, the main thread needs to run the retrieval step.
748 try:
File ~/SageMaker/env/lib/python3.10/site-packages/joblib/parallel.py:763, in BatchCompletionCallBack._return_or_raise(self)
761 try:
762 if self.status == TASK_ERROR:
--> 763 raise self._result
764 return self._result
765 finally:
AttributeError: 'numpy.float64' object has no attribute 'nunique'```
I am trying to run the predict function locally on my data that is in the right format. I am receiving the following error: