DCGAN on CelebA Dataset using Libtorch (PyTorch C++ Frontend API)
network.hpp
Go to the documentation of this file.
1 //
2 // network.hpp
3 // DCGAN
4 //
5 // Created by Kushashwa Ravi Shrimali on 23/08/19.
6 // Copyright © 2019 Kushashwa Ravi Shrimali. All rights reserved.
7 //
8 
9 #ifndef network_hpp
10 #define network_hpp
11 
12 #include <iostream>
13 #include <torch/torch.h>
14 #include <opencv2/opencv.hpp>
15 #include <torch/script.h>
16 
17 // DCGAN uses convolutional layers in the Discriminator
18 // and convolutional-transpose layers in the Generator.
19 
22 struct ConvTranspose2dWrapperImpl : public torch::nn::ConvTranspose2dImpl {
23  using torch::nn::ConvTranspose2dImpl::ConvTranspose2dImpl;
24 
25  torch::Tensor forward(const torch::Tensor& input) {
26  return torch::nn::ConvTranspose2dImpl::forward(input, c10::nullopt);
27  }
28 };
29 
30 TORCH_MODULE(ConvTranspose2dWrapper);
31 
32 
37 class Generator : public torch::nn::Module {
38  private:
39  int nc;
40  int nz;
41  int ngf;
42  int ndf;
43  torch::nn::Sequential main;
44  public:
45  Generator(int nc_ = 3, int nz_ = 100, int ngf_ = 64) {
46  nc = nc_;
47  nz = nz_;
48  ngf = ngf_;
49 
50  // TODO: ConvTranspose2dWrapper will be replaced with ConvTranspose2d in Libtorch 1.5 Version
51  main = torch::nn::Sequential(
52  ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(nz, ngf*8, 4).stride(1).padding(0).bias(false)),
53  torch::nn::BatchNorm2d(ngf*8),
54  torch::nn::Functional(torch::relu),
55  ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(ngf*8, ngf*4, 4).stride(2).padding(1).bias(false)),
56  torch::nn::BatchNorm2d(ngf*4),
57  torch::nn::Functional(torch::relu),
58  ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(ngf*4, ngf*2, 4).stride(2).padding(1).bias(false)),
59  torch::nn::BatchNorm2d(ngf*2),
60  torch::nn::Functional(torch::relu),
61  ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(ngf*2, ngf, 4).stride(2).padding(1).bias(false)),
62  torch::nn::BatchNorm2d(ngf),
63  torch::nn::Functional(torch::relu),
64  ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(ngf, nc, 4).stride(2).padding(1).bias(false)),
65  torch::nn::Functional(torch::tanh)
66  );
67  }
68 
69  torch::nn::Sequential get_module() {
70  return main;
71  }
72 };
73 
80 class Discriminator : public torch::nn::Module {
81  private:
82  int nc;
83  int nz;
84  int ngf;
85  int ndf;
86  torch::nn::Sequential main;
87  public:
88  Discriminator(int nc_ = 3, int ngf_ = 64, int ndf_ = 64) {
89  nc = nc_;
90  ndf = ndf_;
91 
92  main = torch::nn::Sequential(
93  torch::nn::Conv2d(torch::nn::Conv2dOptions(nc, ndf, 4).stride(2).padding(1).bias(false)),
94  torch::nn::Functional(torch::leaky_relu, 0.2),
95  torch::nn::Conv2d(torch::nn::Conv2dOptions(ndf, ndf*2, 4).stride(2).padding(1).bias(false)),
96  torch::nn::BatchNorm2d(ndf*2),
97  torch::nn::Functional(torch::leaky_relu, 0.2),
98  torch::nn::Conv2d(torch::nn::Conv2dOptions(ndf*2, ndf*4, 4).stride(2).padding(1).bias(false)),
99  torch::nn::BatchNorm2d(ndf*4),
100  torch::nn::Functional(torch::leaky_relu, 0.2),
101  torch::nn::Conv2d(torch::nn::Conv2dOptions(ndf*4, ndf*8, 4).stride(2).padding(1).bias(false)),
102  torch::nn::BatchNorm2d(ndf*8),
103  torch::nn::Functional(torch::leaky_relu, 0.2),
104  torch::nn::Conv2d(torch::nn::Conv2dOptions(ndf*8, 1, 4).stride(1).padding(0).bias(false)),
105  torch::nn::Functional(torch::sigmoid)
106  );
107  }
108 
109  torch::nn::Sequential get_module() {
110  return main;
111  }
112 };
113 
114 #endif /* network_hpp */
Discriminator
Definition: network.hpp:80
ConvTranspose2dWrapperImpl::forward
torch::Tensor forward(const torch::Tensor &input)
Definition: network.hpp:25
Generator::get_module
torch::nn::Sequential get_module()
Definition: network.hpp:69
TORCH_MODULE
TORCH_MODULE(ConvTranspose2dWrapper)
Discriminator::get_module
torch::nn::Sequential get_module()
Definition: network.hpp:109
Discriminator::Discriminator
Discriminator(int nc_=3, int ngf_=64, int ndf_=64)
Definition: network.hpp:88
ConvTranspose2dWrapperImpl
Definition: network.hpp:22
main
int main(int argc, const char *argv[])
Definition: main.cpp:49
Generator::Generator
Generator(int nc_=3, int nz_=100, int ngf_=64)
Definition: network.hpp:45
ngf
int ngf
Definition: main.cpp:46
ndf
int ndf
Definition: main.cpp:47
Generator
Definition: network.hpp:37