13 #include <opencv2/opencv.hpp>
14 #include <torch/torch.h>
16 #include <torch/script.h>
20 torch::Tensor
read_data(std::string location,
int resize);
28 std::vector<torch::Tensor>
process_images(std::vector<std::string> list_images,
int resize);
32 std::vector<torch::Tensor>
process_labels(std::vector<int> list_labels);
/// Collects (strings, integer labels) from the folders named in `folders_name`.
/// NOTE(review): presumably `.first` holds file paths and `.second` the index
/// of the folder each path came from (the shapes match CustomDataset's
/// list_images / list_labels parameters) — TODO confirm against the definition.
auto load_data_from_folder(std::vector<std::string> folders_name)
    -> std::pair<std::vector<std::string>, std::vector<int>>;
41 template<
typename Dataloader>
42 void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, \
43 torch::optim::Optimizer& optimizer,
size_t dataset_size);
47 template<
typename Dataloader>
48 void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader& loader,
size_t data_size);
63 std::vector<torch::Tensor> states, labels;
// CustomDataset constructor.
// NOTE(review): the statements that fill `states` and `labels` from
// `list_images` / `list_labels` (presumably using `resize`) are not visible in
// this excerpt — TODO confirm. Both vectors are taken by value, i.e. copied
// at the call site.
66 CustomDataset(std::vector<std::string> list_images, std::vector<int> list_labels,
int resize=224) {
// Cache the number of samples; presumably this is what size() reports.
69 ds_size = states.size();
// Returns sample `index` as an Example {data, target}.
// vector::at() bounds-checks `index` (throws std::out_of_range when it is past
// the end). clone() hands back independent copies, so in-place edits by the
// caller do not alter the cached tensors.
77 torch::data::Example<>
get(
size_t index)
override {
78 torch::Tensor sample_img = states.at(index);
79 torch::Tensor sample_label = labels.at(index);
80 return {sample_img.clone(), sample_label.clone()};
// NOTE(review): the enclosing method's signature and the declaration of
// grid_size are not visible in this excerpt. Visible behaviour:
// - converts the first grid_size*grid_size samples (via get(i)) into 8-bit
//   3-channel cv::Mat tiles;
// - tiles them into a grid_size x grid_size mosaic using hconcat (within a
//   row) and vconcat (between rows);
// - writes the mosaic to "out.jpg".
// If get() is the at()-based accessor, a dataset with fewer than
// grid_size*grid_size samples makes this throw std::out_of_range.
// NOTE(review): raw new[] — no matching delete[] is visible here. If it is not
// on an elided line, the tiles leak. A std::vector<cv::Mat> would own them.
90 cv::Mat* img_varray =
new cv::Mat[grid_size*grid_size];
// Convert each sample tensor to an H x W x C uint8 tile.
91 for(
int i = 0; i < grid_size * grid_size; i++) {
// NOTE(review): permute() returns a strided view, and .to() keeps those
// strides. Unless the stored tensor's memory is already laid out H x W x C,
// the bytes at data_ptr() are not in the order cv::Mat expects. Calling
// .contiguous() before the memcpy would remove that dependency on how the
// tensor was built — TODO confirm.
92 torch::Tensor out_tensor =
get(i).data.squeeze().detach().permute({1, 2, 0});
93 out_tensor = out_tensor.clamp(0, 255).to(torch::kCPU).to(torch::kU8);
// Only the H x W CV_8UC3 allocation matters here. The identity values written
// by cv::Mat::eye are overwritten by the memcpy below.
94 *(img_varray + i) = cv::Mat::eye(out_tensor.sizes()[0], out_tensor.sizes()[1], CV_8UC3);
// NOTE(review): the Mat buffer holds H*W*3 bytes, but numel() bytes are
// copied. This assumes exactly 3 channels after squeeze(); otherwise the copy
// over- or under-runs the buffer.
// sizeof(torch::kU8) is the size of the ScalarType enumerator, not of a tensor
// element. It only equals 1 because that enum is one byte wide;
// out_tensor.element_size() states the intent.
95 std::memcpy((img_varray + i)->data, out_tensor.data_ptr(),
sizeof(torch::kU8) * out_tensor.numel());
// NOTE(review): cvtColor reallocates `out` to each source tile's size, so the
// 256x256 shape here does not constrain the tiles. How temp_out is seeded for
// the first row is on elided lines — TODO confirm.
97 cv::Mat out(256, 256, CV_8UC3);
98 cv::Mat temp_out(256, 256, CV_8UC3);
// Build one mosaic row per iteration. Start from the row's first tile, then
// hconcat the row's remaining grid_size - 1 tiles onto it.
99 for(
int vconcat_times = 0; vconcat_times < grid_size; vconcat_times++) {
100 cv::cvtColor(*(img_varray + vconcat_times*grid_size), out, cv::COLOR_BGR2RGB);
101 for(
int hconcat_times = vconcat_times*grid_size; hconcat_times < (vconcat_times+1)*grid_size - 1; hconcat_times++) {
102 cv::cvtColor(*(img_varray + hconcat_times + 1), *(img_varray + hconcat_times + 1), cv::COLOR_BGR2RGB);
103 cv::hconcat(out, *(img_varray + hconcat_times + 1), out);
// The first-row branch is on elided lines. For later rows, the finished row
// is stacked beneath the mosaic built so far.
105 if(vconcat_times == 0)
108 cv::vconcat(temp_out, out, temp_out);
// NOTE(review): every tile went through BGR2RGB above and the whole mosaic
// goes through BGR2RGB again here. The two swaps appear to cancel, so the
// saved channel order matches the tensors' — confirm this is intended.
111 cv::cvtColor(temp_out, temp_out, cv::COLOR_BGR2RGB);
112 cv::imwrite(
"out.jpg", temp_out);
113 std::cout <<
"Image saved as out.jpg" << std::endl;
// NOTE(review): the enclosing method's signature (and the origin of `index`)
// is not visible in this excerpt. Visible behaviour: converts sample `index`
// (via get(index)) to an 8-bit H x W x 3 cv::Mat and writes it to
// "sample.jpg".
// NOTE(review): permute() returns a strided view, and .to() keeps those
// strides. Unless the stored tensor's memory is already H x W x C, the bytes
// at data_ptr() are not in cv::Mat order. Call .contiguous() before the
// memcpy — TODO confirm.
119 torch::Tensor out_tensor_ =
get(index).data.squeeze().detach().permute({1, 2, 0});
120 out_tensor_ = out_tensor_.clamp(0, 255).to(torch::kCPU).to(torch::kU8);
// Allocate an H x W CV_8UC3 image; the memcpy below fills its buffer.
// NOTE(review): numel() bytes go into an H*W*3-byte buffer, which assumes
// exactly 3 channels after squeeze().
// sizeof(torch::kU8) is the size of the ScalarType enumerator (1), not an
// element size; out_tensor_.element_size() states the intent.
121 cv::Mat sample_img(out_tensor_.sizes()[0], out_tensor_.sizes()[1], CV_8UC3);
122 std::memcpy(sample_img.data, out_tensor_.data_ptr(),
sizeof(torch::kU8) * out_tensor_.numel());
// No channel conversion is done here. imwrite interprets 3-channel data as
// BGR, so the tensor's channel order is written as-is.
123 cv::imwrite(
"sample.jpg", sample_img);
124 std::cout <<
"Image saved as sample.jpg" << std::endl;
129 torch::optional<size_t>
size()
const override {