Skip to content

Commit e8f73fe

Browse files
committed
undo changes in usage.rst
1 parent c82695e commit e8f73fe

File tree

1 file changed

+1
-94
lines changed

1 file changed

+1
-94
lines changed

docs/source/usage.rst

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,97 +1315,4 @@ Below is the documentation on the available arguments.
13151315
Interval of validation in training
13161316
--train-ratio 0.8 Ratio of train dataset. The remaining will be used for valid and test split.
13171317
--valid-ratio 0.1 Ratio of validation set after the train data split. The remaining will be test split
1318-
--share-model
1319-
1320-
Model initialization using the Torch API
1321-
----------------------------------------
1322-
1323-
The scikit-learn API provides parametrization to many common use cases.
1324-
The Torch API however allows for more flexibility and customization, for e.g.
1325-
sampling, criterions, and data loaders.
1326-
1327-
In this minimal example we show how to initialize a CEBRA model using the Torch API.
1328-
Here the :py:class:`cebra.data.single_session.DiscreteDataLoader`
1329-
gets initilized which also allows the `prior` to be directly parametrized.
1330-
1331-
👉 For an example notebook using the Torch API check out the :doc:`demo_notebooks/Demo_Allen`.
1332-
1333-
1334-
.. testcode::
1335-
1336-
import numpy as np
1337-
import cebra.datasets
1338-
from cebra import plot_embedding
1339-
import torch
1340-
1341-
if torch.cuda.is_available():
1342-
device = "cuda"
1343-
else:
1344-
device = "cpu"
1345-
1346-
neural_data = cebra.load_data(file="neural_data.npz", key="neural")
1347-
1348-
discrete_label = cebra.load_data(
1349-
file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
1350-
)
1351-
1352-
# 1. Define Cebra Dataset
1353-
InputData = cebra.data.TensorDataset(
1354-
torch.from_numpy(neural_data).type(torch.FloatTensor),
1355-
discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
1356-
).to(device)
1357-
1358-
# 2. Define Cebra Model
1359-
neural_model = cebra.models.init(
1360-
name="offset10-model",
1361-
num_neurons=InputData.input_dimension,
1362-
num_units=32,
1363-
num_output=2,
1364-
).to(device)
1365-
1366-
InputData.configure_for(neural_model)
1367-
1368-
# 3. Define Loss Function Criterion and Optimizer
1369-
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
1370-
temperature=0.001,
1371-
min_temperature=0.0001
1372-
).to(device)
1373-
1374-
Opt = torch.optim.Adam(
1375-
list(neural_model.parameters()) + list(Crit.parameters()),
1376-
lr=0.001,
1377-
weight_decay=0,
1378-
)
1379-
1380-
# 4. Initialize Cebra Model
1381-
solver = cebra.solver.init(
1382-
name="single-session",
1383-
model=neural_model,
1384-
criterion=Crit,
1385-
optimizer=Opt,
1386-
tqdm_on=True,
1387-
).to(device)
1388-
1389-
# 5. Define Data Loader
1390-
loader = cebra.data.single_session.DiscreteDataLoader(
1391-
dataset=InputData, num_steps=10, batch_size=200, prior="uniform"
1392-
)
1393-
1394-
# 6. Fit Model
1395-
solver.fit(loader=loader)
1396-
1397-
# 7. Transform Embedding
1398-
TrainBatches = np.lib.stride_tricks.sliding_window_view(
1399-
neural_data, neural_model.get_offset().__len__(), axis=0
1400-
)
1401-
1402-
X_train_emb = solver.transform(
1403-
torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to(device)
1404-
).to(device)
1405-
1406-
# 8. Plot Embedding
1407-
plot_embedding(
1408-
X_train_emb,
1409-
discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
1410-
markersize=10,
1411-
)
1318+
--share-model

0 commit comments

Comments
 (0)