execve函数在当前进程的上下文中加载并运行一个新的程序。它会覆盖当前进程的地址空间,但并没有创建一个新的进程。新的程序仍然有与原程序相同的PID,并且继承了调用execve函数已打开的所有文件描述符,且会覆盖原程序的代码,出现问题。需要fork出一个新的子进程,在子进程中调用!!!
还有需要特别注意:在子线程中执行exit()会导致整个进程退出!!!
//key.cpp
//生成rsa公钥和私钥,并保存到文件public_key,private_key中
#include<gmp.h>
using namespace std;
#include<iostream>
#include<unistd.h>
int main(int args,char* argv[],char* envp[])
{
FILE* pub = nullptr, *pri = nullptr;
pub = fopen("./public_key","w+");
if(pub == nullptr)
{
exit(0);
}
pri = fopen("./private_key","w+");
if(pri == nullptr)
{
exit(0);
}
//设置默认随机状态为跟着时间改变,也就是引入时间随机数种子
gmp_randstate_t state;
gmp_randinit_default(state);
gmp_randseed_ui(state,time(0));
//创建p,q,f,n,d,e,并且初始化
mpz_t key_p,key_q,key_n,key_f,key_d,key_e;
mpz_inits(key_p,key_q,key_f,key_d,key_e,key_n,nullptr);
//随机生成1024位的随机数,128字节
mpz_urandomb(key_p,state,1024);
mpz_urandomb(key_q,state,1024);
//获取质数p,q
mpz_nextprime(key_p,key_p);
mpz_nextprime(key_q,key_q);
if(mpz_cmp(key_q,key_p) == 0)
{
exit(0);
}
//计算 n = p*q
mpz_mul(key_n,key_p,key_q);
//计算 f = (p-1)*(q-1)
mpz_sub_ui(key_p,key_p,1);
mpz_sub_ui(key_q,key_q,1);
mpz_mul(key_f,key_p,key_q);
//计算e,通常选择3、17、65537作为e的值
//使用这些值不会对RSA的安全性造成影响,因为解密数据还需要用到私钥
mpz_init_set_ui(key_e,65537);
//求d, d = (e^-1) mod ((p-1)*(q-1)),逆元运算
mpz_init(key_d);
mpz_invert(key_d,key_e,key_f);
//写入到文件中
mpz_out_str(pub,10,key_e);
fwrite("\n",1,1,pub);
mpz_out_str(pub,10,key_n);
//写入到文件中
mpz_out_str(pri,10,key_d);
fwrite("\n",1,1,pri);
mpz_out_str(pri,10,key_n);
fclose(pri);
fclose(pub);
mpz_clears(key_p,key_n,key_d,key_e,key_q,key_f,nullptr);
return 0;
}
//client.cpp
//客户端程序,模拟tcp连接,使用rsa加密
#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<cstring>
#include<gmp.h>
#include<ctype.h>
#include<memory>
#include<sys/wait.h>
//线程安全的输出函数
mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
lock_guard<mutex> m{print_mutex};
((cout<<args),...);
}
//发生错误函数
void unix_error(const char* msg)
{
print(msg);
exit(0);
}
//线程安全的输入函数
mutex input_mutex{};
template<class T>
void input(T& val)
{
lock_guard<mutex> m{input_mutex};
cin >> val;
}
//安全的socket函数
int Socket(int domain,int type,int protocol)
{
int fd = socket(domain,type,protocol);
if(fd < 0)
{
unix_error("create socket error\n");
}
return fd;
}
//安全的bind函数
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
if(bind(sockfd,my_addr,addrlen) < 0)
{
unix_error("bind error\n");
}
}
//安全的listen函数
void Listen(int soc,int flag)
{
if(listen(soc,flag) < 0)
{
unix_error("listen error\n");
}
}
//安全的close函数
mutex close_mutex{};
void Close(int file)
{
lock_guard<mutex> m{close_mutex};
if(close(file) < 0)
{
print("close file ",file," error\n");
}
}
//安全的accept函数
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
int fd = 0;
if((fd = accept(sockfd,cliaddr,addrlen)) < 0)
{
unix_error("accept error\n");
}
return fd;
}
//安全的inet_pton函数
void Inet_pton(int proto,const char* addr,void* dst)
{
if(inet_pton(proto,addr,dst) < 0)
{
unix_error("address translate error\n");
}
}
//安全的的inet_ntop函数
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
if(!inet_ntop(proto,src,dst,len))
{
print("address translate error\n");
}
}
//判断是否有除了数字之外的字符
bool is_all_digit(const char* str)
{
for(size_t i = 0;str[i] != '\0';++i)
{
if(!isdigit(str[i]))
{
return false;
}
}
return true;
}
int main(int args,char*argv[],char*envp[])
{
//生成公钥和私钥
if(fork() == 0)
{
char* argvs[] = {"./key",nullptr};
char* envps[] = {0,nullptr};
if(execve(argvs[0],argvs,envps) < 0)
{
unix_error("create key error\n");
}
exit(0);
}
//主进程阻塞等待子进程指向完成并且安全退出
int res = 0;
if(!wait(&res))
{
unix_error("create key error\n");
}
if(!WIFEXITED(res))
{
unix_error("create key error\n");
}
//读取公钥
mpz_t key_e,key_n,text;
mpz_inits(key_e,key_n,text,nullptr);
FILE* pub = fopen("./public_key","r");
if(pub == nullptr)
{
unix_error("open public key error\n");
}
if(mpz_inp_str(key_e,pub,10) <= 0)
{
unix_error("read public key error\n");
}
if(mpz_inp_str(key_n,pub,10) <= 0)
{
unix_error("read public key error\n");
}
//建立tcp连接
int fd = Socket(AF_INET,SOCK_STREAM,0);
sockaddr_in addr{};
addr.sin_port = htons(8088);
addr.sin_family = AF_INET;
Inet_pton(AF_INET,"127.0.0.1",&addr.sin_addr);
if(connect(fd,(sockaddr*)&addr,sizeof(addr)) < 0)
{
unix_error("连接失败......\n");
}
else
{
print("连接成功......\n");
}
//设置缓冲区
const size_t size = 1024;
char* buffer = new char[size];
unique_ptr<char> tmp{nullptr};
//通信
while(1)
{
memset(buffer,0,size);
print("请输入需要发送的信息(quit:退出):\n");
input(buffer);
if(strcmp(buffer,"quit") == 0)
{
print("exit successfully!\n");
break;
}
//rsa加密是针对数字而言,出现其他字符会出现错误
if(!is_all_digit(buffer))
{
print("输入的字符串中含有非数字字符,请重新输入。\n");
continue;
}
//str -> mpz
mpz_set_str(text,buffer,10);
//加密
mpz_powm(text,text,key_e,key_n);
//mpz -> str
tmp.reset(mpz_get_str(nullptr,10,text));
int res = send(fd,tmp.get(),strlen(tmp.get())+1,0);
if(res > 0)
{
print("发送成功......\n");
}
else if(res == 0)
{
print("连接断开......\n");
break;
}
else
{
print("发送错误......\n");
}
}
fclose(pub);
delete[]buffer;
mpz_clears(key_n,key_e,text,nullptr);
Close(fd);
return 0;
}
//server.cpp
//服务端程序
#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<vector>
#include<cstring>
#include<gmp.h>
#include<memory>
//线程安全的输出函数
mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
lock_guard<mutex> m{print_mutex};
((cout<<args),...);
}
//发生错误函数
void unix_error(const char* msg)
{
print(msg);
exit(0);
}
//线程安全的输入函数
mutex input_mutex{};
template<class T>
void input(T& val)
{
lock_guard<mutex> m{input_mutex};
cin >> val;
}
//安全的socket函数
int Socket(int domain,int type,int protocol)
{
int fd = socket(domain,type,protocol);
if(fd < 0)
{
unix_error("create socket error\n");
}
return fd;
}
//安全的bind函数
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
if(bind(sockfd,my_addr,addrlen) < 0)
{
unix_error("bind error\n");
}
}
//安全的listen函数
void Listen(int soc,int flag)
{
if(listen(soc,flag) < 0)
{
unix_error("listen error\n");
}
}
//安全的close函数
mutex close_mutex{};
void Close(int file)
{
lock_guard<mutex> m{close_mutex};
if(close(file) < 0)
{
print("close file ",file," error\n");
}
}
//安全的accept函数
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
int fd = 0;
if((fd = accept(sockfd,cliaddr,addrlen)) < 0)
{
unix_error("accept error\n");
}
return fd;
}
//安全的inet_pton函数
void Inet_pton(int proto,const char* addr,void* dst)
{
if(inet_pton(proto,addr,dst) < 0)
{
unix_error("address translate error\n");
}
}
//安全的的inet_ntop函数
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
if(!inet_ntop(proto,src,dst,len))
{
print("address translate error\n");
}
}
//与客户端通信线程入口函数,只收不发
void chat(int client,const string addr,int port)
{
mpz_t key_d,key_n,text;
mpz_inits(key_d,key_n,text,nullptr);
FILE* pri = fopen("./private_key","r");
if(pri == nullptr)
{
print("open private key error\n");
return;
}
if(mpz_inp_str(key_d,pri,10) <= 0)
{
print("read private key rror\n");
return;
}
if(mpz_inp_str(key_n,pri,10) <= 0)
{
print("read private key error\n");
return;
}
print(addr,":",port," connected\n");
const size_t size = 1024;
char* buffer = new char[size];
unique_ptr<char> tmp{nullptr};
int res = 0;
while(1)
{
memset(buffer,0,size);
res = recv(client,buffer,size,0);
if(res > 0)
{
mpz_set_str(text,buffer,10);
mpz_powm(text,text,key_d,key_n);
tmp.reset(mpz_get_str(nullptr,10,text));
print("receive message,size = ",strlen(tmp.get())," : ",tmp.get()," from ",addr,":",port,"\n");
}
else if(res == 0)
{
print(addr,":",port," disconnected\n");
break;
}
else
{
print(addr,":",port," error\n");
break;
}
}
Close(client);
mpz_clears(text,key_n,key_d,nullptr);
fclose(pri);
delete[]buffer;
}
int main(int args,char*argv[],char*envp[])
{
//建立并绑定端口,使用tcp连接
int soc = Socket(AF_INET,SOCK_STREAM,0);
sockaddr_in addr{};
memset(&addr, 0, sizeof(addr));
addr.sin_port = htons(8088);
addr.sin_family = AF_INET;
Inet_pton(AF_INET,"127.0.0.1",&addr.sin_addr);
Bind(soc,(sockaddr*)&addr,sizeof(addr));
Listen(soc,10);
//保存连接客户端的信息
sockaddr_in client;
socklen_t len = sizeof(client);
char client_addr[20];
int client_port = 0;
int client_socket = 0;
//每有一个连接就创建一个线程来进行通信
while(client_socket = Accept(soc,(sockaddr*)&client,&len))
{
memset(client_addr,0,sizeof(client_addr));
Inet_ntop(AF_INET,&client.sin_addr,client_addr,20);
client_port = ntohs(client.sin_port);
thread m(chat,client_socket,client_addr,client_port);
m.detach();
}
Close(soc);
return 0;
}
//makefile
//编译条件
key : key.cpp
g++ key.cpp -o key -w -std=c++2a -O2 -lgmp
clean :
rm -rf key server client
server : server.cpp
g++ server.cpp -o server -w -std=c++2a -O2 -lpthread -lgmp
client : client.cpp
g++ client.cpp -o client -w -std=c++2a -O2 -lpthread -lgmp
Comments NOTHING