DCGAN on CelebA Dataset using Libtorch (PyTorch C++ Frontend API)
dataset.hpp
Go to the documentation of this file.
1 //
2 // dataset.hpp
3 // DCGAN
4 //
5 // Created by Kushashwa Ravi Shrimali on 21/08/19.
6 // Copyright © 2019 Kushashwa Ravi Shrimali. All rights reserved.
7 //
8 
9 #ifndef dataset_hpp
10 #define dataset_hpp
11 
12 #include <iostream>
13 #include <opencv2/opencv.hpp>
14 #include <torch/torch.h>
15 #include <dirent.h>
16 #include <torch/script.h>
17 
/// @brief Function to return image read at location given as type torch::Tensor.
/// Defined in dataset.cpp.
/// @param location Location (path) of the image to read.
/// @param resize   Presumably the size the image is resized to before conversion — TODO confirm in dataset.cpp.
/// @return The image as a torch::Tensor.
20 torch::Tensor read_data(std::string location, int resize);
21 
/// @brief Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for
///        n-class classification) as a torch::Tensor.
/// Defined in dataset.cpp.
/// @param label Integer class label.
/// @return The label as a torch::Tensor.
24 torch::Tensor read_label(int label);
25 
/// @brief Function returns vector of tensors (images) read from the list of images in a folder.
/// Defined in dataset.cpp.
/// @param list_images List of images to read (presumably file paths handed to read_data — TODO confirm).
/// @param resize      Resize setting; presumably forwarded to read_data — TODO confirm.
/// @return One tensor per image that was read.
28 std::vector<torch::Tensor> process_images(std::vector<std::string> list_images, int resize);
29 
/// @brief Function returns vector of tensors (labels) read from the list of labels.
/// Defined in dataset.cpp.
/// @param list_labels Integer labels (see read_label for the expected values).
/// @return The labels as tensors.
32 std::vector<torch::Tensor> process_labels(std::vector<int> list_labels);
33 
/// @brief Function to load data from given folder(s) name(s) (folders_name).
/// Defined in dataset.cpp.
/// @param folders_name Names of the folders to load from.
/// @return Pair of a vector of strings and a vector of ints; presumably
///         (image paths, per-image labels) suitable for CustomDataset — TODO confirm.
37 std::pair<std::vector<std::string>, std::vector<int>> load_data_from_folder(std::vector<std::string> folders_name);
38 
/// @brief Function to train the network on train data.
/// @tparam Dataloader  Type of the data loader that is iterated.
/// @param net          TorchScript module being trained.
/// @param lin          Linear layer used together with `net` (presumably a classification head on top of it — TODO confirm).
/// @param data_loader  Loader over the training data.
/// @param optimizer    Optimizer used for the parameter updates.
/// @param dataset_size Number of training samples (presumably used to normalise loss/accuracy — TODO confirm).
/// NOTE(review): this template is only declared here and no definition is visible in this
///               header. If it is defined in a .cpp file, every Dataloader type used by
///               callers needs an explicit instantiation there (or the definition must
///               move into the header) — confirm.
41 template<typename Dataloader>
42 void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, \
43 torch::optim::Optimizer& optimizer, size_t dataset_size);
44 
/// @brief Function to test the network on test data.
/// @tparam Dataloader Type of the data loader that is iterated.
/// @param network     TorchScript module being evaluated.
/// @param lin         Linear layer used together with `network` (see train()).
/// @param loader      Loader over the test data.
/// @param data_size   Number of test samples (presumably used to normalise accuracy — TODO confirm).
/// NOTE(review): declared without a visible definition; same instantiation caveat as train().
47 template<typename Dataloader>
48 void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader& loader, size_t data_size);
49 
60 class CustomDataset : public torch::data::Dataset<CustomDataset> {
61 private:
62  /* data */
63  std::vector<torch::Tensor> states, labels;
64  size_t ds_size;
65 public:
66  CustomDataset(std::vector<std::string> list_images, std::vector<int> list_labels, int resize=224) {
67  states = process_images(list_images, resize);
68  labels = process_labels(list_labels);
69  ds_size = states.size();
70  };
71 
77  torch::data::Example<> get(size_t index) override {
78  torch::Tensor sample_img = states.at(index);
79  torch::Tensor sample_label = labels.at(index);
80  return {sample_img.clone(), sample_label.clone()};
81  };
82 
89  void show_batch(int grid_size = 3) {
90  cv::Mat* img_varray = new cv::Mat[grid_size*grid_size];
91  for(int i = 0; i < grid_size * grid_size; i++) {
92  torch::Tensor out_tensor = get(i).data.squeeze().detach().permute({1, 2, 0});
93  out_tensor = out_tensor.clamp(0, 255).to(torch::kCPU).to(torch::kU8);
94  *(img_varray + i) = cv::Mat::eye(out_tensor.sizes()[0], out_tensor.sizes()[1], CV_8UC3);
95  std::memcpy((img_varray + i)->data, out_tensor.data_ptr(), sizeof(torch::kU8) * out_tensor.numel());
96  }
97  cv::Mat out(256, 256, CV_8UC3);
98  cv::Mat temp_out(256, 256, CV_8UC3);
99  for(int vconcat_times = 0; vconcat_times < grid_size; vconcat_times++) {
100  cv::cvtColor(*(img_varray + vconcat_times*grid_size), out, cv::COLOR_BGR2RGB);
101  for(int hconcat_times = vconcat_times*grid_size; hconcat_times < (vconcat_times+1)*grid_size - 1; hconcat_times++) {
102  cv::cvtColor(*(img_varray + hconcat_times + 1), *(img_varray + hconcat_times + 1), cv::COLOR_BGR2RGB);
103  cv::hconcat(out, *(img_varray + hconcat_times + 1), out);
104  }
105  if(vconcat_times == 0)
106  temp_out = out;
107  else {
108  cv::vconcat(temp_out, out, temp_out);
109  }
110  }
111  cv::cvtColor(temp_out, temp_out, cv::COLOR_BGR2RGB);
112  cv::imwrite("out.jpg", temp_out);
113  std::cout << "Image saved as out.jpg" << std::endl;
114  }
115 
118  void show_sample(int index) {
119  torch::Tensor out_tensor_ = get(index).data.squeeze().detach().permute({1, 2, 0});
120  out_tensor_ = out_tensor_.clamp(0, 255).to(torch::kCPU).to(torch::kU8);
121  cv::Mat sample_img(out_tensor_.sizes()[0], out_tensor_.sizes()[1], CV_8UC3);
122  std::memcpy(sample_img.data, out_tensor_.data_ptr(), sizeof(torch::kU8) * out_tensor_.numel());
123  cv::imwrite("sample.jpg", sample_img);
124  std::cout << "Image saved as sample.jpg" << std::endl;
125  }
126 
129  torch::optional<size_t> size() const override {
130  return int(ds_size);
131  };
132 };
133 
134 #endif /* dataset_hpp */
test
void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader &loader, size_t data_size)
Function to test the network on test data.
CustomDataset
This class allows loading a Custom Dataset in Libtorch.
Definition: dataset.hpp:60
process_labels
std::vector< torch::Tensor > process_labels(std::vector< int > list_labels)
Function returns vector of tensors (labels) read from the list of labels.
Definition: dataset.cpp:55
CustomDataset::CustomDataset
CustomDataset(std::vector< std::string > list_images, std::vector< int > list_labels, int resize=224)
Definition: dataset.hpp:66
CustomDataset::get
torch::data::Example<> get(size_t index) override
Definition: dataset.hpp:77
read_data
torch::Tensor read_data(std::string location, int resize)
Function to return image read at location given as type torch::Tensor.
Definition: dataset.cpp:17
load_data_from_folder
std::pair< std::vector< std::string >, std::vector< int > > load_data_from_folder(std::vector< std::string > folders_name)
Function to load data from given folder(s) name(s) (folders_name). Returns pair of vectors of string (...
Definition: dataset.cpp:69
CustomDataset::show_batch
void show_batch(int grid_size=3)
Visualize batch of data (by default 3x3)
Definition: dataset.hpp:89
CustomDataset::size
torch::optional< size_t > size() const override
Definition: dataset.hpp:129
train
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader &data_loader, torch::optim::Optimizer &optimizer, size_t dataset_size)
Function to train the network on train data.
read_label
torch::Tensor read_label(int label)
Function to return label from int (0, 1 for binary and 0, 1, ..., n-1 for n-class classification) as ...
Definition: dataset.cpp:30
process_images
std::vector< torch::Tensor > process_images(std::vector< std::string > list_images, int resize)
Function returns vector of tensors (images) read from the list of images in a folder.
Definition: dataset.cpp:41
CustomDataset::show_sample
void show_sample(int index)
Definition: dataset.hpp:118