StudyAi
首页
课程
专栏
小组
问答
招聘
登录
注册
江户川柯北
真相只有一个!(#^.^#)
12509
访客
65
文章
219311
字数
3
粉丝
私信
关注
访问
Friendly Introduction
(广告:~~ 小兰的AI专栏 ~~)
PyTorch Dataset 之 Sampler
5个月前
4977字
330阅读
0评论
# PyTorch Dataset: Sampler
#
# A Sampler controls how data is subsequently read by defining the rule
# that generates indices, i.e. the order in which dataset indices are
# produced.

import torch
from torch.utils.data.sampler import Sampler


# ---------------------------------------------------------------------------
# Complex version
# ---------------------------------------------------------------------------
class sampler(Sampler):
    """Yield dataset indices batch by batch, with the batch order shuffled.

    The indices ``0 .. train_size - 1`` are cut into consecutive chunks of
    ``batch_size``.  Every call to ``__iter__`` shuffles the order of the
    chunks while keeping the indices inside each chunk in ascending order.
    Indices that do not fill a whole chunk (the "leftover") are always
    appended, unshuffled, at the end.

    Args:
        train_size: total number of samples in the dataset.
        batch_size: number of consecutive indices per chunk.
    """

    def __init__(self, train_size, batch_size):
        # Stored on the instance because __len__ needs it.
        self.num_data = train_size
        self.num_per_batch = int(train_size // batch_size)
        self.batch_size = batch_size
        # Row vector [0, 1, ..., batch_size - 1]: offsets inside one chunk.
        self.range = torch.arange(0, batch_size).view(1, batch_size).long()
        self.leftover_flag = False
        if self.num_data % batch_size:
            self.leftover = torch.arange(
                self.num_per_batch * batch_size, self.num_data).long()
            self.leftover_flag = True

    def __iter__(self):
        """Return an iterator of plain ``int`` indices for one epoch."""
        # Column vector holding the first index of every chunk, shuffled.
        starts = torch.randperm(self.num_per_batch).view(-1, 1) * self.batch_size
        indices = starts.expand(self.num_per_batch, self.batch_size) + self.range
        indices = indices.view(-1)
        if self.leftover_flag:
            indices = torch.cat((indices, self.leftover), 0)
        # tolist() so callers receive Python ints rather than 0-d tensors.
        return iter(indices.tolist())

    def __len__(self):
        """Return the number of indices produced per epoch."""
        return self.num_data


# ---------------------------------------------------------------------------
# Simplified version
# ---------------------------------------------------------------------------
class xSampler(Sampler):
    """Index sampler with three selectable orderings.

    Args:
        data_size: total number of samples in the dataset.
        batch_size: number of consecutive indices per batch.
        leftover_keep: when true, indices beyond the last full batch are
            kept; when false, only ``batch_size * (data_size // batch_size)``
            indices are produced, in every sampling mode.
        sampling_type: one of

            * ``'batch_random'`` - batches are visited in random order, but
              the indices inside each batch keep their original order;
              kept leftover indices are appended at the end.
            * ``'all_random'`` - all indices are shuffled.
            * ``'no_random'`` - indices are returned in original order.

    Raises:
        ValueError: from ``__iter__`` when ``sampling_type`` is not one of
            the values above.
    """

    def __init__(self, data_size, batch_size, leftover_keep=True,
                 sampling_type='batch_random'):
        self.data_size = data_size
        self.batch_size = batch_size
        self.batch_nums = int(data_size // batch_size)
        self.leftover_keep = leftover_keep
        self.sampling_type = sampling_type  # batch_random, all_random, no_random

    def __iter__(self):
        """Return an iterator of plain ``int`` indices for one epoch."""
        if self.sampling_type == 'batch_random':
            # The batch order is shuffled, while the sample indices inside
            # each batch stay unchanged.
            starts = torch.randperm(self.batch_nums).long().view(-1, 1) * self.batch_size
            offsets = torch.arange(0, self.batch_size).long().view(1, -1)
            indices = (starts + offsets).view(-1)
            if self.leftover_keep and self.data_size % self.batch_size:
                leftover = torch.arange(
                    self.batch_nums * self.batch_size, self.data_size).long()
                indices = torch.cat([indices, leftover], dim=0)
        elif self.sampling_type == 'all_random':
            # All samples are shuffled; truncated to len(self) so that
            # leftover_keep=False drops the surplus here as well.
            indices = torch.randperm(self.data_size).long()[:len(self)]
        elif self.sampling_type == 'no_random':
            # Samples are returned in their original order.
            indices = torch.arange(0, len(self)).long()
        else:
            raise ValueError('Unknown Parameters %s.' % self.sampling_type)
        # tolist() so callers receive Python ints rather than 0-d tensors.
        return iter(indices.tolist())

    def __len__(self):
        """Return the number of indices produced per epoch."""
        if self.leftover_keep:
            return self.data_size
        return self.batch_size * self.batch_nums


# ---------------------------------------------------------------------------
# Example: batch random
# ---------------------------------------------------------------------------
# xsampler = xSampler(20, 3, True)
# list(xsampler)
# Out[81]: [15, 16, 17, 6, 7, 8, 12, 13, 14, 0, 1, 2, 3, 4, 5, 9, 10, 11, 18, 19]
# list(xsampler)
# Out[82]: [3, 4, 5, 15, 16, 17, 0, 1, 2, 6, 7, 8, 12, 13, 14, 9, 10, 11, 18, 19]
# list(xsampler)
# Out[83]: [3, 4, 5, 12, 13, 14, 9, 10, 11, 0, 1, 2, 6, 7, 8, 15, 16, 17, 18, 19]
# list(xsampler)
# Out[84]: [15, 16, 17, 0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 12, 13, 14, 18, 19]
# list(xsampler)
# Out[85]: [3, 4, 5, 9, 10, 11, 15, 16, 17, 6, 7, 8, 0, 1, 2, 12, 13, 14, 18, 19]
#
# ---------------------------------------------------------------------------
# Example: all random
# ---------------------------------------------------------------------------
# xsampler = xSampler(20, 3, True, 'all_random')
# list(xsampler)
# Out[130]: [14, 4, 16, 11, 12, 0, 10, 9, 5, 1, 13, 6, 3, 15, 19, 18, 2, 8, 17, 7]
# list(xsampler)
# Out[131]: [8, 16, 14, 5, 18, 15, 2, 1, 3, 6, 4, 0, 9, 7, 17, 12, 19, 10, 11, 13]
# list(xsampler)
# Out[132]: [4, 0, 9, 15, 5, 17, 12, 8, 2, 6, 14, 19, 7, 1, 16, 11, 13, 10, 18, 3]
#
# ---------------------------------------------------------------------------
# Example: no random
# ---------------------------------------------------------------------------
# xsampler = xSampler(20, 3, True, 'no_random')
# list(xsampler)
# Out[127]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
# list(xsampler)
# Out[128]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
#
# ---------------------------------------------------------------------------
# Example: using next()
# ---------------------------------------------------------------------------
# A sampler is an iterable, not an iterator, so next() cannot be called on
# it directly:
#
# xsampler
# Out[86]: <__main__.xSampler at 0x69997f0>
# next(xsampler)
# Traceback (most recent call last):
#   File "D:\Miniconda2\envs\tch\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
#     exec(code_obj, self.user_global_ns, self.user_ns)
#   File "<ipython-input>", line 1, in <module>
#     next(xsampler)
# TypeError: 'xSampler' object is not an iterator
#
# Every iter(xsampler) call reshuffles, so next(iter(xsampler)) only ever
# returns the first index of a freshly shuffled epoch:
#
# xsampler = xSampler(20, 3, True)
# [next(iter(xsampler)) for _ in range(len(xsampler))]
# Out[104]: [9, 12, 3, 12, 12, 12, 6, 0, 9, 0, 15, 12, 3, 9, 6, 15, 6, 6, 3, 15]
# [next(iter(xsampler)) for _ in range(len(xsampler))]
# Out[105]: [0, 0, 0, 3, 0, 12, 0, 15, 15, 15, 3, 0, 3, 6, 0, 12, 15, 15, 6, 15]
# [next(iter(xsampler)) for _ in range(len(xsampler))]
# Out[106]: [3, 6, 15, 9, 15, 9, 6, 6, 9, 3, 12, 12, 15, 0, 3, 6, 6, 0, 9, 6]
收藏 (0)
打赏
点赞 (0)
热门评论
到顶端
发布^_^
powered by studyai.com 2017