CSAPP–第十一章–网络编程(下)

Aki 发布于 2023-01-29 291 次阅读


socket多线程编程服务器端

#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<cstring>



// Serialized console output.
// Streams every argument to std::cout, in order, while holding print_mutex,
// so the pieces of one call are never interleaved with another print() call.
std::mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
	const std::lock_guard<std::mutex> guard{print_mutex};
	// Left fold: ((cout << a0) << a1) << ... ; with no arguments nothing is written.
	(void)(std::cout << ... << args);
}


// Fatal-error helper: writes msg through print() and terminates the process.
// The exit status is non-zero so that a shell, script or supervisor can tell
// a failed run from a normal shutdown (status 0 conventionally means success).
void unix_error(const char* msg)
{
	print(msg);
	exit(1);
}


// Serialized console input.
// Reads one formatted value from std::cin into val while holding input_mutex.
mutex input_mutex{};
template<class T>
void input(T& val)
{
	lock_guard<mutex> m{input_mutex};
	cin >> val;
}

// Overload for fixed-size char buffers.
// Before C++20, `cin >> char_array` decays to the char* extractor, which keeps
// writing past the end of the array unless a field width is set. Setting the
// width to N limits the read to N-1 characters plus the terminating NUL; the
// rest of an over-long word stays in the stream for the next read.
template<std::size_t N>
void input(char (&val)[N])
{
	lock_guard<mutex> m{input_mutex};
	cin.width(static_cast<std::streamsize>(N));
	cin >> val;
}


// Checked wrapper around socket(): returns the new descriptor, or reports
// "create socket error" through unix_error() when the call fails.
int Socket(int domain,int type,int protocol)
{
	const int descriptor = ::socket(domain,type,protocol);
	if(descriptor >= 0)
	{
		return descriptor;
	}
	unix_error("create socket error\n");
	return descriptor;
}

// Checked wrapper around bind(): reports "bind error" through unix_error()
// when the call fails. The call is written as ::bind to name the socket API
// explicitly.
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
	const int status = ::bind(sockfd,my_addr,addrlen);
	if(status >= 0)
	{
		return;
	}
	unix_error("bind error\n");
}


// Checked wrapper around listen(): flag is forwarded as the backlog argument;
// a failing call is reported as "listen error" through unix_error().
void Listen(int soc,int flag)
{
	const bool failed = ::listen(soc,flag) < 0;
	if(!failed)
	{
		return;
	}
	unix_error("listen error\n");
}
	
// Checked wrapper around close(): the call is made while holding close_mutex,
// and a failure is only logged through print() (the process keeps running).
std::mutex close_mutex{};
void Close(int file)
{
	const std::lock_guard<std::mutex> guard{close_mutex};
	const int status = ::close(file);
	if(status >= 0)
	{
		return;
	}
	print("close file ",file," error\n");
}

// Checked wrapper around accept(): returns the connected descriptor, or
// reports "accept error" through unix_error() when the call fails.
// cliaddr/addrlen are forwarded unchanged to accept().
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
	const int connection = ::accept(sockfd,cliaddr,addrlen);
	if(connection >= 0)
	{
		return connection;
	}
	unix_error("accept error\n");
	return connection;
}


// Checked wrapper around inet_pton(): converts the textual address addr of
// family proto into its binary form in dst.
// inet_pton() returns 1 on success, 0 when addr is not a valid address for
// the family, and -1 when the family is unsupported. Both 0 and -1 are
// failures: dst is not filled in, so both are reported through unix_error().
void Inet_pton(int proto,const char* addr,void* dst)
{
	if(inet_pton(proto,addr,dst) <= 0)
	{
		unix_error("address translate error\n");
	}
}



// Checked wrapper around inet_ntop(): converts the binary address at src
// (family proto) into text in dst, which has room for len bytes.
// A failure is logged through print() and is not fatal. In that case dst is
// set to the empty string so callers that go on to read it as a C string
// never touch indeterminate memory.
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
	if(!inet_ntop(proto,src,dst,static_cast<socklen_t>(len)))
	{
		print("address translate error\n");
		if(dst != nullptr && len > 0)
		{
			dst[0] = '\0';
		}
	}
}


// Per-connection thread entry point (receive only, nothing is sent back).
//   client - connected socket descriptor; closed by this function on exit
//   addr   - textual peer address, used only for log output
//   port   - peer port in host byte order, used only for log output
// Logs every chunk received until the peer closes the connection (recv
// returns 0) or recv fails (negative return), then closes the socket.
void chat(int client,const string addr,int port)
{
	print(addr,":",port," connected\n");
	char buffer[512];
	while(true)
	{
		// Receive at most sizeof(buffer)-1 bytes so there is always room for a
		// terminator: TCP is a byte stream, so a chunk may fill the buffer or
		// arrive without the sender's trailing NUL.
		const ssize_t res = recv(client,buffer,sizeof(buffer) - 1,0);
		if(res > 0)
		{
			buffer[res] = '\0';
			print("receive message: ",buffer,". from ",addr,":",port,"\n");
		}
		else if(res == 0)
		{
			print(addr,":",port," disconnected\n");
			break;
		}
		else
		{
			print(addr,":",port," error\n");
			break;
		}
	}
	Close(client);
}


// Server entry point: listens on 0.0.0.0:8088 and starts one detached chat()
// thread per accepted connection.
int main(int args,char*argv[],char*envp[])
{
	int soc = Socket(AF_INET,SOCK_STREAM,0);

	// Value-initialization zeroes the whole structure; no memset is needed.
	sockaddr_in addr{};
	addr.sin_family = AF_INET;
	addr.sin_port = htons(8088);
	Inet_pton(AF_INET,"0.0.0.0",&addr.sin_addr);

	Bind(soc,reinterpret_cast<sockaddr*>(&addr),sizeof(addr));

	Listen(soc,10);

	// Information about the most recently accepted client.
	sockaddr_in client{};
	socklen_t len = sizeof(client);
	char client_addr[20] = {};
	int client_socket = 0;

	// Accept() terminates the process on failure, so the returned descriptor
	// is never negative here. Zero is a valid descriptor and must not end
	// the loop.
	while((client_socket = Accept(soc,reinterpret_cast<sockaddr*>(&client),&len)) >= 0)
	{
		Inet_ntop(AF_INET,&client.sin_addr,client_addr,sizeof(client_addr));
		const int client_port = ntohs(client.sin_port);

		// Build the std::string here, in the accepting thread. Passing the
		// char array itself would hand the worker a pointer into client_addr,
		// and the worker would convert it to a string concurrently with the
		// next iteration overwriting the buffer.
		thread m(chat,client_socket,string(client_addr),client_port);
		m.detach();

		// accept() treats the length as value-result: restore the full
		// buffer size before the next call.
		len = sizeof(client);
	}

	Close(soc);


	return 0;
}

client端

#include<iostream>
using namespace std;
#include<sys/socket.h>
#include<unistd.h>
#include<sys/types.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<mutex>
#include<thread>
#include<vector>
#include<cstring>



// Serialized console output.
// Streams every argument to std::cout, in order, while holding print_mutex,
// so the pieces of one call are never interleaved with another print() call.
std::mutex print_mutex{};
template<class...Args>
void print(const Args&...args)
{
	const std::lock_guard<std::mutex> guard{print_mutex};
	// Left fold: ((cout << a0) << a1) << ... ; with no arguments nothing is written.
	(void)(std::cout << ... << args);
}


// Fatal-error helper: writes msg through print() and terminates the process.
// The exit status is non-zero so that a shell, script or supervisor can tell
// a failed run from a normal shutdown (status 0 conventionally means success).
void unix_error(const char* msg)
{
	print(msg);
	exit(1);
}


// Serialized console input.
// Reads one formatted value from std::cin into val while holding input_mutex.
mutex input_mutex{};
template<class T>
void input(T& val)
{
	lock_guard<mutex> m{input_mutex};
	cin >> val;
}

// Overload for fixed-size char buffers.
// Before C++20, `cin >> char_array` decays to the char* extractor, which keeps
// writing past the end of the array unless a field width is set. Setting the
// width to N limits the read to N-1 characters plus the terminating NUL; the
// rest of an over-long word stays in the stream for the next read.
template<std::size_t N>
void input(char (&val)[N])
{
	lock_guard<mutex> m{input_mutex};
	cin.width(static_cast<std::streamsize>(N));
	cin >> val;
}


// Checked wrapper around socket(): returns the new descriptor, or reports
// "create socket error" through unix_error() when the call fails.
int Socket(int domain,int type,int protocol)
{
	const int descriptor = ::socket(domain,type,protocol);
	if(descriptor >= 0)
	{
		return descriptor;
	}
	unix_error("create socket error\n");
	return descriptor;
}

// Checked wrapper around bind(): reports "bind error" through unix_error()
// when the call fails. The call is written as ::bind to name the socket API
// explicitly.
void Bind(int sockfd ,const sockaddr* my_addr,socklen_t addrlen)
{
	const int status = ::bind(sockfd,my_addr,addrlen);
	if(status >= 0)
	{
		return;
	}
	unix_error("bind error\n");
}


// Checked wrapper around listen(): flag is forwarded as the backlog argument;
// a failing call is reported as "listen error" through unix_error().
void Listen(int soc,int flag)
{
	const bool failed = ::listen(soc,flag) < 0;
	if(!failed)
	{
		return;
	}
	unix_error("listen error\n");
}
	
// Checked wrapper around close(): the call is made while holding close_mutex,
// and a failure is only logged through print() (the process keeps running).
std::mutex close_mutex{};
void Close(int file)
{
	const std::lock_guard<std::mutex> guard{close_mutex};
	const int status = ::close(file);
	if(status >= 0)
	{
		return;
	}
	print("close file ",file," error\n");
}

// Checked wrapper around accept(): returns the connected descriptor, or
// reports "accept error" through unix_error() when the call fails.
// cliaddr/addrlen are forwarded unchanged to accept().
int Accept(int sockfd,sockaddr* cliaddr,socklen_t* addrlen)
{
	const int connection = ::accept(sockfd,cliaddr,addrlen);
	if(connection >= 0)
	{
		return connection;
	}
	unix_error("accept error\n");
	return connection;
}


// Checked wrapper around inet_pton(): converts the textual address addr of
// family proto into its binary form in dst.
// inet_pton() returns 1 on success, 0 when addr is not a valid address for
// the family, and -1 when the family is unsupported. Both 0 and -1 are
// failures: dst is not filled in, so both are reported through unix_error().
void Inet_pton(int proto,const char* addr,void* dst)
{
	if(inet_pton(proto,addr,dst) <= 0)
	{
		unix_error("address translate error\n");
	}
}



// Checked wrapper around inet_ntop(): converts the binary address at src
// (family proto) into text in dst, which has room for len bytes.
// A failure is logged through print() and is not fatal. In that case dst is
// set to the empty string so callers that go on to read it as a C string
// never touch indeterminate memory.
void Inet_ntop(int proto,void* src,char* dst,size_t len)
{
	if(!inet_ntop(proto,src,dst,static_cast<socklen_t>(len)))
	{
		print("address translate error\n");
		if(dst != nullptr && len > 0)
		{
			dst[0] = '\0';
		}
	}
}



// Client entry point: connects to 127.0.0.1:8088, then repeatedly reads one
// whitespace-delimited word from standard input and sends it, including its
// terminating NUL, to the server. The session ends when the user types
// "quit", when standard input is exhausted, or when sending fails.
int main(int args,char*argv[],char*envp[])
{
	int fd = Socket(AF_INET,SOCK_STREAM,0);

	sockaddr_in addr{};
	addr.sin_family = AF_INET;
	addr.sin_port = htons(8088);
	Inet_pton(AF_INET,"127.0.0.1",&addr.sin_addr);

	if(connect(fd,reinterpret_cast<sockaddr*>(&addr),sizeof(addr)) < 0)
	{
		unix_error("连接失败......\n");
	}
	else
	{
		print("连接成功......\n");
	}

	char buffer[512];
	bool running = true;
	while(running)
	{
		memset(buffer,0,sizeof(buffer));
		print("请输入需要发送的信息:\n");
		input(buffer);
		if(!cin)
		{
			// End of input or a stream failure: without this check the loop
			// would spin forever, sending an empty message on every pass.
			print("结束通话......\n");
			break;
		}
		if(strcmp(buffer,"quit") == 0)
		{
			print("结束通话......\n");
			break;
		}

		// send() may accept only part of the data, so keep sending until the
		// whole message (text plus its NUL terminator) has been handed over.
		const char* cursor = buffer;
		size_t remaining = strlen(buffer) + 1;
		while(remaining > 0)
		{
			const ssize_t res = send(fd,cursor,remaining,0);
			if(res > 0)
			{
				cursor += res;
				remaining -= static_cast<size_t>(res);
				continue;
			}
			if(res == 0)
			{
				print("连接断开......\n");
			}
			else
			{
				print("发送错误......\n");
			}
			// The connection is unusable after a failed send; stop instead of
			// prompting for more input that can never be delivered.
			running = false;
			break;
		}
		if(running)
		{
			print("发送成功......\n");
		}
	}

	Close(fd);
	return 0;
}