@@ -1315,97 +1315,4 @@ Below is the documentation on the available arguments.
1315
1315
Interval of validation in training
1316
1316
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
1317
1317
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
1318
- --share-model
1319
-
1320
- Model initialization using the Torch API
1321
- ----------------------------------------
1322
-
1323
- The scikit-learn API provides parametrization to many common use cases.
1324
- The Torch API however allows for more flexibility and customization, for e.g.
1325
- sampling, criterions, and data loaders.
1326
-
1327
- In this minimal example we show how to initialize a CEBRA model using the Torch API.
1328
- Here the :py:class: `cebra.data.single_session.DiscreteDataLoader `
1329
- gets initilized which also allows the `prior ` to be directly parametrized.
1330
-
1331
- 👉 For an example notebook using the Torch API check out the :doc: `demo_notebooks/Demo_Allen `.
1332
-
1333
-
1334
- .. testcode ::
1335
-
1336
- import numpy as np
1337
- import cebra.datasets
1338
- from cebra import plot_embedding
1339
- import torch
1340
-
1341
- if torch.cuda.is_available():
1342
- device = "cuda"
1343
- else:
1344
- device = "cpu"
1345
-
1346
- neural_data = cebra.load_data(file="neural_data.npz", key="neural")
1347
-
1348
- discrete_label = cebra.load_data(
1349
- file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
1350
- )
1351
-
1352
- # 1. Define Cebra Dataset
1353
- InputData = cebra.data.TensorDataset(
1354
- torch.from_numpy(neural_data).type(torch.FloatTensor),
1355
- discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
1356
- ).to(device)
1357
-
1358
- # 2. Define Cebra Model
1359
- neural_model = cebra.models.init(
1360
- name="offset10-model",
1361
- num_neurons=InputData.input_dimension,
1362
- num_units=32,
1363
- num_output=2,
1364
- ).to(device)
1365
-
1366
- InputData.configure_for(neural_model)
1367
-
1368
- # 3. Define Loss Function Criterion and Optimizer
1369
- Crit = cebra.models.criterions.LearnableCosineInfoNCE(
1370
- temperature=0.001,
1371
- min_temperature=0.0001
1372
- ).to(device)
1373
-
1374
- Opt = torch.optim.Adam(
1375
- list(neural_model.parameters()) + list(Crit.parameters()),
1376
- lr=0.001,
1377
- weight_decay=0,
1378
- )
1379
-
1380
- # 4. Initialize Cebra Model
1381
- solver = cebra.solver.init(
1382
- name="single-session",
1383
- model=neural_model,
1384
- criterion=Crit,
1385
- optimizer=Opt,
1386
- tqdm_on=True,
1387
- ).to(device)
1388
-
1389
- # 5. Define Data Loader
1390
- loader = cebra.data.single_session.DiscreteDataLoader(
1391
- dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
1392
- )
1393
-
1394
- # 6. Fit Model
1395
- solver.fit(loader=loader)
1396
-
1397
- # 7. Transform Embedding
1398
- TrainBatches = np.lib.stride_tricks.sliding_window_view(
1399
- neural_data, neural_model.get_offset().__len__(), axis=0
1400
- )
1401
-
1402
- X_train_emb = solver.transform(
1403
- torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
1404
- ).to(device)
1405
-
1406
- # 8. Plot Embedding
1407
- plot_embedding(
1408
- X_train_emb,
1409
- discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
1410
- markersize=10,
1411
- )
1318
+ --share-model
0 commit comments