13 #include <torch/torch.h>
14 #include <opencv2/opencv.hpp>
15 #include <torch/script.h>
23 using torch::nn::ConvTranspose2dImpl::ConvTranspose2dImpl;
25 torch::Tensor
forward(
const torch::Tensor& input) {
26 return torch::nn::ConvTranspose2dImpl::forward(input, c10::nullopt);
43 torch::nn::Sequential main;
45 Generator(
int nc_ = 3,
int nz_ = 100,
int ngf_ = 64) {
51 main = torch::nn::Sequential(
52 ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(nz,
ngf*8, 4).stride(1).padding(0).bias(
false)),
53 torch::nn::BatchNorm2d(
ngf*8),
54 torch::nn::Functional(torch::relu),
55 ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(
ngf*8,
ngf*4, 4).stride(2).padding(1).bias(
false)),
56 torch::nn::BatchNorm2d(
ngf*4),
57 torch::nn::Functional(torch::relu),
58 ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(
ngf*4,
ngf*2, 4).stride(2).padding(1).bias(
false)),
59 torch::nn::BatchNorm2d(
ngf*2),
60 torch::nn::Functional(torch::relu),
61 ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(
ngf*2,
ngf, 4).stride(2).padding(1).bias(
false)),
62 torch::nn::BatchNorm2d(
ngf),
63 torch::nn::Functional(torch::relu),
64 ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(
ngf, nc, 4).stride(2).padding(1).bias(
false)),
65 torch::nn::Functional(torch::tanh)
86 torch::nn::Sequential main;
92 main = torch::nn::Sequential(
93 torch::nn::Conv2d(torch::nn::Conv2dOptions(nc,
ndf, 4).stride(2).padding(1).bias(
false)),
94 torch::nn::Functional(torch::leaky_relu, 0.2),
95 torch::nn::Conv2d(torch::nn::Conv2dOptions(
ndf,
ndf*2, 4).stride(2).padding(1).bias(
false)),
96 torch::nn::BatchNorm2d(
ndf*2),
97 torch::nn::Functional(torch::leaky_relu, 0.2),
98 torch::nn::Conv2d(torch::nn::Conv2dOptions(
ndf*2,
ndf*4, 4).stride(2).padding(1).bias(
false)),
99 torch::nn::BatchNorm2d(
ndf*4),
100 torch::nn::Functional(torch::leaky_relu, 0.2),
101 torch::nn::Conv2d(torch::nn::Conv2dOptions(
ndf*4,
ndf*8, 4).stride(2).padding(1).bias(
false)),
102 torch::nn::BatchNorm2d(
ndf*8),
103 torch::nn::Functional(torch::leaky_relu, 0.2),
104 torch::nn::Conv2d(torch::nn::Conv2dOptions(
ndf*8, 1, 4).stride(1).padding(0).bias(
false)),
105 torch::nn::Functional(torch::sigmoid)