RSA算法和TCP连接模拟https

Aki 发布于 2023-02-05 277 次阅读


execve函数在当前进程的上下文中加载并运行一个新的程序。它会覆盖当前进程的地址空间,但并没有创建一个新的进程。新的程序仍然有与原程序相同的PID,并且继承了调用execve函数已打开的所有文件描述符,且会覆盖原程序的代码,出现问题。需要fork出一个新的子进程,在子进程中调用!!!

还有需要特别注意:在子线程中执行exit()会导致整个进程退出!!!

//key.cpp
//生成rsa公钥和私钥,并保存到文件public_key,private_key中

#include<gmp.h>
using namespace std;
#include<iostream>
#include<unistd.h>


int main(int args,char* argv[],char* envp[])
{

	FILE* pub = nullptr, *pri = nullptr;

	pub = fopen("./public_key","w+");
	if(pub == nullptr)
	{
		exit(0);
	}

	pri = fopen("./private_key","w+");
	if(pri == nullptr)
	{
		exit(0);
	}

	//设置默认随机状态为跟着时间改变,也就是引入时间随机数种子
	gmp_randstate_t state;
	gmp_randinit_default(state);
	gmp_randseed_ui(state,time(0));

	//创建p,q,f,n,d,e,并且初始化
	mpz_t key_p,key_q,key_n,key_f,key_d,key_e;
	mpz_inits(key_p,key_q,key_f,key_d,key_e,key_n,nullptr);

	//随机生成1024位的随机数,128字节
	mpz_urandomb(key_p,state,1024);
	mpz_urandomb(key_q,state,1024);


	//获取质数p,q
	mpz_nextprime(key_p,key_p);
	mpz_nextprime(key_q,key_q);
	if(mpz_cmp(key_q,key_p) == 0)
	{
		exit(0);
	}


	//计算 n = p*q
	mpz_mul(key_n,key_p,key_q);


	//计算 f = (p-1)*(q-1)
	mpz_sub_ui(key_p,key_p,1);
	mpz_sub_ui(key_q,key_q,1);
	mpz_mul(key_f,key_p,key_q);



	//计算e,通常选择3、17、65537作为e的值
	//使用这些值不会对RSA的安全性造成影响,因为解密数据还需要用到私钥
	mpz_init_set_ui(key_e,65537);



	//求d, d = (e^-1) mod ((p-1)*(q-1)),逆元运算
	mpz_init(key_d);
	mpz_invert(key_d,key_e,key_f);


	//写入到文件中
	mpz_out_str(pub,10,key_e);
	fwrite("\n",1,1,pub);
	mpz_out_str(pub,10,key_n);


        //写入到文件中
	mpz_out_str(pri,10,key_d);
	fwrite("\n",1,1,pri);
	mpz_out_str(pri,10,key_n);

	fclose(pri);
	fclose(pub);
	mpz_clears(key_p,key_n,key_d,key_e,key_q,key_f,nullptr);

	return 0;
}
//client.cpp
//客户端程序,模拟tcp连接,使用rsa加密

#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<cstring>
#include<gmp.h>
#include<ctype.h>
#include<memory>
#include<sys/wait.h>



//线程安全的输出函数
mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
	lock_guard<mutex> m{print_mutex};
	((cout<<args),...);
}


//发生错误函数
void unix_error(const char* msg)
{
	print(msg);
	exit(0);
}


//线程安全的输入函数
mutex input_mutex{};
template<class T>
void input(T& val)
{
	lock_guard<mutex> m{input_mutex};
	cin >> val;
}


//安全的socket函数
int Socket(int domain,int type,int protocol)
{
	int fd = socket(domain,type,protocol);
	if(fd < 0)
	{
		unix_error("create socket error\n");
	}
	return fd;
}

//安全的bind函数
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
	if(bind(sockfd,my_addr,addrlen) < 0)
	{
		unix_error("bind error\n");
	}
}


//安全的listen函数
void Listen(int soc,int flag)
{
	if(listen(soc,flag) < 0)
	{
		unix_error("listen error\n");
	}
}
	
//安全的close函数
mutex close_mutex{};
void Close(int file)
{
	lock_guard<mutex> m{close_mutex};
	if(close(file) < 0)
	{
		print("close file ",file," error\n");
	}
}

//安全的accept函数
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
	int fd = 0;
	if((fd = accept(sockfd,cliaddr,addrlen)) < 0)
	{
		unix_error("accept error\n");
	}
	return fd;
}


//安全的inet_pton函数
void Inet_pton(int proto,const char* addr,void* dst)
{
	if(inet_pton(proto,addr,dst) < 0)
	{
		unix_error("address translate error\n");
	}
}



//安全的的inet_ntop函数
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
	if(!inet_ntop(proto,src,dst,len))
	{
		print("address translate error\n");
	}
}


//判断是否有除了数字之外的字符
bool is_all_digit(const char* str)
{
	for(size_t i = 0;str[i] != '\0';++i)
	{
		if(!isdigit(str[i]))
		{
			return false;
		}
	}
	return true;
}



int main(int args,char*argv[],char*envp[])
{

	//生成公钥和私钥
	if(fork() == 0)
	{
		char* argvs[] = {"./key",nullptr};
		char* envps[] = {0,nullptr};
		if(execve(argvs[0],argvs,envps) < 0)
		{
			unix_error("create key error\n");
		}
		exit(0);
	}

	//主进程阻塞等待子进程指向完成并且安全退出
	int res = 0;
	if(!wait(&res))
	{
		unix_error("create key error\n");
	}

	if(!WIFEXITED(res))
	{
		unix_error("create key error\n");
	}


	//读取公钥
	mpz_t key_e,key_n,text;
	mpz_inits(key_e,key_n,text,nullptr);
	FILE* pub = fopen("./public_key","r");
	if(pub == nullptr)
	{
		unix_error("open public key error\n");
	}

	if(mpz_inp_str(key_e,pub,10) <= 0)
	{
		unix_error("read public key error\n");
	}

	if(mpz_inp_str(key_n,pub,10) <= 0)
	{
		unix_error("read public key error\n");
	}


	//建立tcp连接
	int fd = Socket(AF_INET,SOCK_STREAM,0);

	sockaddr_in addr{};
	addr.sin_port = htons(8088);
	addr.sin_family = AF_INET;
	Inet_pton(AF_INET,"127.0.0.1",&addr.sin_addr);

	if(connect(fd,(sockaddr*)&addr,sizeof(addr)) < 0)
	{
		unix_error("连接失败......\n");
	}
	else
	{
		print("连接成功......\n");
	}

	//设置缓冲区
	const size_t size = 1024;
	char* buffer = new char[size];
	unique_ptr<char> tmp{nullptr};


	//通信
	while(1)
	{
		memset(buffer,0,size);
		print("请输入需要发送的信息(quit:退出):\n");
		input(buffer);

		if(strcmp(buffer,"quit") == 0)
		{
			print("exit successfully!\n");
			break;
		}

		//rsa加密是针对数字而言,出现其他字符会出现错误
		if(!is_all_digit(buffer))
		{
			print("输入的字符串中含有非数字字符,请重新输入。\n");
			continue;
		}

		//str -> mpz
		mpz_set_str(text,buffer,10);
		//加密
		mpz_powm(text,text,key_e,key_n);
		//mpz -> str
		tmp.reset(mpz_get_str(nullptr,10,text));
	
		int res = send(fd,tmp.get(),strlen(tmp.get())+1,0);

		if(res > 0)
		{
			print("发送成功......\n");
		}
		else if(res == 0)
		{
			print("连接断开......\n");
			break;
		}
		else
		{
		        print("发送错误......\n");
		}
	}

	fclose(pub);
	delete[]buffer;
	mpz_clears(key_n,key_e,text,nullptr);
	Close(fd);

	return 0;
}
//server.cpp
//服务端程序

#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<vector>
#include<cstring>
#include<gmp.h>
#include<memory>

//线程安全的输出函数
mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
	lock_guard<mutex> m{print_mutex};
	((cout<<args),...);
}


//发生错误函数
void unix_error(const char* msg)
{
	print(msg);
	exit(0);
}


//线程安全的输入函数
mutex input_mutex{};
template<class T>
void input(T& val)
{
	lock_guard<mutex> m{input_mutex};
	cin >> val;
}


//安全的socket函数
int Socket(int domain,int type,int protocol)
{
	int fd = socket(domain,type,protocol);
	if(fd < 0)
	{
		unix_error("create socket error\n");
	}
	return fd;
}

//安全的bind函数
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
	if(bind(sockfd,my_addr,addrlen) < 0)
	{
		unix_error("bind error\n");
	}
}


//安全的listen函数
void Listen(int soc,int flag)
{
	if(listen(soc,flag) < 0)
	{
		unix_error("listen error\n");
	}
}
	
//安全的close函数
mutex close_mutex{};
void Close(int file)
{
	lock_guard<mutex> m{close_mutex};
	if(close(file) < 0)
	{
		print("close file ",file," error\n");
	}
}

//安全的accept函数
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
	int fd = 0;
	if((fd = accept(sockfd,cliaddr,addrlen)) < 0)
	{
		unix_error("accept error\n");
	}
	return fd;
}


//安全的inet_pton函数
void Inet_pton(int proto,const char* addr,void* dst)
{
	if(inet_pton(proto,addr,dst) < 0)
	{
		unix_error("address translate error\n");
	}
}



//安全的的inet_ntop函数
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
	if(!inet_ntop(proto,src,dst,len))
	{
		print("address translate error\n");
	}
}


//与客户端通信线程入口函数,只收不发
void chat(int client,const string addr,int port)
{

	mpz_t key_d,key_n,text;
	mpz_inits(key_d,key_n,text,nullptr);

	FILE* pri = fopen("./private_key","r");
	if(pri == nullptr)
	{
		print("open private key error\n");
		return;
	}

	if(mpz_inp_str(key_d,pri,10) <= 0)
	{
		print("read private key rror\n");
		return;
	}

	if(mpz_inp_str(key_n,pri,10) <= 0)
	{
		print("read private key error\n");
		return;
	}


	print(addr,":",port," connected\n");

	const size_t size = 1024;
	char* buffer = new char[size];
	unique_ptr<char> tmp{nullptr};
	int res = 0;
	

	while(1)
	{
		memset(buffer,0,size);
		res = recv(client,buffer,size,0);
		if(res > 0)
		{
			mpz_set_str(text,buffer,10);
			mpz_powm(text,text,key_d,key_n);
			tmp.reset(mpz_get_str(nullptr,10,text));
			print("receive message,size = ",strlen(tmp.get())," : ",tmp.get(),"  from ",addr,":",port,"\n");
		}
		else if(res == 0)
		{
  	                print(addr,":",port," disconnected\n");
		        break;
		}
		else
		{
			print(addr,":",port," error\n");
			break;
		}
	}

	Close(client);
	mpz_clears(text,key_n,key_d,nullptr);
	fclose(pri);
	delete[]buffer;
}


int main(int args,char*argv[],char*envp[])
{

	//建立并绑定端口,使用tcp连接
	int soc = Socket(AF_INET,SOCK_STREAM,0);

	sockaddr_in addr{};
	memset(&addr, 0, sizeof(addr));
	addr.sin_port = htons(8088);
	addr.sin_family = AF_INET;
	Inet_pton(AF_INET,"127.0.0.1",&addr.sin_addr); 

	Bind(soc,(sockaddr*)&addr,sizeof(addr));

	Listen(soc,10);

	//保存连接客户端的信息
	sockaddr_in client;
	socklen_t len = sizeof(client);
	char client_addr[20];
	int client_port = 0;
	int client_socket = 0;


	//每有一个连接就创建一个线程来进行通信
	while(client_socket = Accept(soc,(sockaddr*)&client,&len))
	{
		memset(client_addr,0,sizeof(client_addr));
		Inet_ntop(AF_INET,&client.sin_addr,client_addr,20);
		client_port = ntohs(client.sin_port);
		thread m(chat,client_socket,client_addr,client_port);
		m.detach();
	}

	Close(soc);

	return 0;
}
//makefile
//编译条件

key : key.cpp
	 g++ key.cpp -o key -w -std=c++2a  -O2 -lgmp

clean :
	rm -rf  key server client

server : server.cpp
	g++ server.cpp -o server -w -std=c++2a  -O2 -lpthread -lgmp


client : client.cpp
	g++ client.cpp -o client -w -std=c++2a  -O2 -lpthread -lgmp