DsC*_*Cpp 6 pytorch dataloader
我需要BatchSampler
在 pytorch 中使用 a DataLoader
,而不是__getitem__
多次调用数据集(远程数据集,每个查询都很昂贵)。
我无法理解如何将批量采样器与任何给定的数据集一起使用。
例如
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, idx):
return self.ddf[idx] --------> This is as expensive as a batch call
def get_batch(self, batch_idx):
return self.ddf[batch_idx]
my_loader = DataLoader(MyDataset(remote_ddf),
batch_sampler=BatchSampler(Sampler(), batch_size=3))
Run Code Online (Sandbox Code Playgroud)
我不明白的是,在网上或火炬文档中都没有找到任何示例,我不明白的是如何使用我的get_batch
函数而不是 __getitem__ 函数。
编辑:根据 Szymon Maszke 的回答,这就是我尝试过的,但\_\_get_item__
每次调用都会获取一个索引,而不是大小列表batch_size
class Dataset(Dataset):
def __init__(self):
...
def __len__(self):
...
def __getitem__(self, batch_idx): ------> here I get only one index
return self.wiki_df.loc[batch_idx]
loader = DataLoader(
dataset=dataset,
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
num_workers=self.hparams.num_data_workers,
)
Run Code Online (Sandbox Code Playgroud)
你不能使用get_batch
代替__getitem__
,我不认为这样做有什么意义。
torch.utils.data.BatchSampler
Sampler()
从您的实例中获取索引(在本例中3
是它们)并将其返回,以便list
可以在您的方法中使用它们MyDataset
__getitem__
(检查源代码,大多数采样器和数据相关实用程序都很容易遵循,以防您需要)。
我假设您self.ddf
支持列表切片(例如,self.ddf[[25, 44, 115]]
正确返回值并且仅使用一次昂贵的调用)。在这种情况下,只需切换get_batch
到__getitem__
即可。
class MyDataset(Dataset):
def __init__(self, remote_ddf, ):
self.ddf = remote_ddf
def __len__(self):
return len(self.ddf)
def __getitem__(self, batch_idx):
return self.ddf[batch_idx] -> batch_idx is a list
Run Code Online (Sandbox Code Playgroud)
编辑:您必须指定batch_sampler
as sampler
,否则批次将分为单个索引。这应该没问题:
loader = DataLoader(
dataset=dataset,
# This line below!
batch_sampler=BatchSampler(
SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
),
num_workers=self.hparams.num_data_workers,
)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
22302 次 |
最近记录: |