没啥好说的，直接看代码：
import os
import gzip
import shutil
import paramiko
import multiprocessing
from pathlib import Path
from loguru import logger
from config import globalconf
class SSHConnection(object):
    """Wrapper around a paramiko ``Transport`` plus an SFTP client opened on it.

    Use as a context manager: ``__enter__`` connects and returns the instance,
    ``__exit__`` closes it. Exceptions raised inside the ``with`` body are
    propagated to the caller.
    """

    def __init__(self, host=None, port=None, username=None, pwd=None, pk_path=None):
        """
        :param host: server IP / hostname
        :param port: SSH port
        :param username: login name
        :param pwd: password; used only when ``pk_path`` is empty
        :param pk_path: path to an RSA private key file used for authentication
        """
        self.host = host
        self.port = port
        self.username = username
        self.pwd = pwd
        self.pk_path = pk_path
        # Filled in by connect(); None means "not connected".
        self.__transport = None
        self.sftp = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Returning False lets exceptions from the with-body reach the caller
        # instead of being silently swallowed.
        return False

    def connect(self):
        """Open the transport, authenticate, and create the SFTP client.

        Authenticates with the RSA key at ``pk_path`` when one is given,
        otherwise with ``pwd``. If authentication or SFTP setup fails, the
        transport is closed and the exception is re-raised.
        """
        transport = paramiko.Transport((self.host, self.port))
        try:
            if self.pk_path:
                pk = paramiko.RSAKey.from_private_key_file(self.pk_path)
                transport.connect(username=self.username, pkey=pk)
            else:
                transport.connect(username=self.username, password=self.pwd)
            sftp = paramiko.SFTPClient.from_transport(transport)
        except Exception:
            transport.close()
            raise
        self.__transport = transport
        self.sftp = sftp

    def close(self):
        """Close the SFTP client, then the transport. A no-op when not connected."""
        if self.sftp is not None:
            self.sftp.close()
            self.sftp = None
        if self.__transport is not None:
            self.__transport.close()
            self.__transport = None

    def upload(self, local_path, target_path):
        """Upload ``local_path`` to ``target_path`` on the server."""
        self.sftp.put(local_path, target_path)

    def download(self, remote_path, local_path):
        """Download ``remote_path`` to ``local_path`` with ``SFTPClient.get``."""
        self.sftp.get(remote_path, local_path)

    def listdir(self, path):
        """Return the entry names of the remote directory ``path``."""
        return self.sftp.listdir(path=path)

    def listdir_attr(self, path):
        """Return the entries (with attributes) of the remote directory ``path``."""
        return self.sftp.listdir_attr(path=path)

    def download_slowly(self, remote_path, local_path):
        """Download by streaming the remote file through ``shutil.copyfileobj``.

        ``sftp.get`` was observed to fail on large files with "Server
        connection dropped"; a plain sequential read avoids paramiko's prefetch
        cache at the cost of speed. Both file handles are closed (so the local
        file is flushed) before returning.
        """
        with self.sftp.open(remote_path, "rb") as remote_fp, open(local_path, "wb") as local_fp:
            shutil.copyfileobj(remote_fp, local_fp)

    def cmd(self, command):
        """Run ``command`` on the server and return its stdout decoded as UTF-8.

        The command runs on a new session channel of the already-authenticated
        transport. stderr and the exit status are not inspected.
        """
        channel = self.__transport.open_session()
        try:
            channel.exec_command(command)
            with channel.makefile("rb") as stdout:
                result = stdout.read()
        finally:
            channel.close()
        return str(result, encoding="utf-8")
class SSHConnectionManager(object):
    """Context manager that creates an ``SSHConnection``, connects it, and closes it on exit.

    ``__enter__`` returns the connection object, not the manager. Exceptions
    raised inside the ``with`` body propagate to the caller.
    """

    def __init__(self, host, port, username, pwd, pk_path):
        # Keyword arguments forwarded verbatim to SSHConnection.
        self.ssh_args = {
            "host": host,
            "port": port,
            "username": username,
            "pk_path": pk_path,
            "pwd": pwd,
        }
        self.ssh = None

    def __enter__(self):
        self.ssh = SSHConnection(**self.ssh_args)
        self.ssh.connect()
        return self.ssh

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.ssh is not None:
            self.ssh.close()
        # Returning False propagates errors instead of silently swallowing them.
        return False
def clear_dir(path):
    """
    Make ``path`` an empty directory: create it (with parents) when missing,
    otherwise delete it together with its contents and recreate it.

    :param path: directory path
    :return: True on success; False if a filesystem operation failed
        (e.g. ``path`` is an existing regular file, or permission is denied)
    """
    try:
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
    except OSError:
        return False
    return True
def decompress_gz(gz_file):
    """Lazily yield the decompressed lines of the gzip file ``gz_file`` as bytes.

    The file stays open while the generator is being consumed and is closed
    once iteration finishes or the generator is closed.
    """
    with gzip.GzipFile(gz_file) as reader:
        yield from reader
def get_file_sftp(tar_path):
    """
    Download ``tar_path`` from the configured SFTP server into
    ``globalconf.SFTP_TEMP_PATH``, keeping the remote base name.

    Exceptions that reach this function are logged rather than raised, and the
    destination path is returned either way, so callers should check that the
    file actually exists.

    :param tar_path: remote file path
    :return: local destination path
    """
    destination = os.path.join(globalconf.SFTP_TEMP_PATH, os.path.basename(tar_path))
    try:
        connection = SSHConnection(
            host=globalconf.SFTP_SERVER_HOST,
            port=globalconf.SFTP_SERVER_PORT,
            username=globalconf.SFTP_USERNAME,
            pwd="",
            pk_path=globalconf.SFTP_KEY_PATH,
        )
        with connection as client:
            client.download_slowly(tar_path, destination)
    except Exception as err:
        logger.error(f"sftp 下载文件出错: {err}")
    return destination
def format_file_list(file_list):
    """Split a list of paths into ``(directory of the first path, set of base names)``.

    Only the first entry determines the directory; every entry contributes
    its base name. An empty list raises IndexError.
    """
    parent = os.path.dirname(file_list[0])
    names = {os.path.basename(entry) for entry in file_list}
    return parent, names
def is_sftp_file_exists(task_id, path):
    """
    Check whether the remote directory containing ``path`` holds an entry whose
    name contains ``str(task_id)``.

    :param task_id: identifier searched for, as a substring, in each remote file name
    :param path: a remote file path or a list of them; only the directory of
        the first one is listed
    :return: True if some name matches; False if none matches, ``path`` is
        empty, or the listing failed (the error is logged)
    """
    # Nothing to look up; without this, format_file_list("") / ([]) raises IndexError.
    if not path:
        return False
    if isinstance(path, str):
        path = [path]
    rootdir, _ = format_file_list(path)
    needle = str(task_id)
    exists = False
    try:
        with SSHConnection(
            host=globalconf.SFTP_SERVER_HOST,
            port=globalconf.SFTP_SERVER_PORT,
            username=globalconf.SFTP_USERNAME,
            pwd="",
            pk_path=globalconf.SFTP_KEY_PATH,
        ) as sftp:
            list_file_names = sftp.listdir(rootdir)
            # Check each name separately. Searching str(list) would also match
            # brackets, quotes, commas, and text spanning two adjacent names.
            exists = any(needle in name for name in list_file_names)
    except Exception as err:
        logger.error(f"sftp 获取文件列表错误: {err}")
    return exists
class SFTPFileManager_Tool(object):
    """Helpers for querying and downloading remote paths over SFTP.

    Every public call opens its own connection through SSHConnectionManager.
    An instance holds only the connection arguments, so its bound methods can
    be sent to pool worker processes.
    """

    def __init__(self, host, port, username, pwd, pk_path):
        """
        :param host: server IP
        :param port: port
        :param username: user name
        :param pwd: password
        :param pk_path: path to the private key file
        """
        self.ssh_args = {"host": host, "port": port, "username": username, "pwd": pwd, "pk_path": pk_path}

    def _remote_kind(self, path):
        """Classify a remote path with SFTP ``lstat`` (a final symlink is not followed).

        Querying over SFTP avoids building shell commands from the path
        (which broke on spaces and allowed injection) and avoids a recursive
        ``find`` just to test existence.

        :return: "file", "dir" or "other" (symlink, device, ...), or None if
            the path cannot be stat-ed (missing, permission denied, ...)
        """
        import stat

        attrs = None
        with SSHConnectionManager(**self.ssh_args) as ssh:
            try:
                attrs = ssh.sftp.lstat(path)
            except OSError:
                # paramiko raises IOError (an OSError alias) for SFTP status
                # errors such as "no such file" or "permission denied".
                return None
        if attrs is None:
            return None
        mode = attrs.st_mode or 0  # servers may omit the permission bits
        if stat.S_ISREG(mode):
            return "file"
        if stat.S_ISDIR(mode):
            return "dir"
        return "other"

    def exists(self, path):
        """
        Return True if the remote path exists (a dangling symlink counts as existing).

        :param path: remote path
        :return: bool
        """
        return self._remote_kind(path) is not None

    def is_file(self, path):
        """
        Return True if the remote path is a regular file.

        :param path: remote path
        :return: bool
        """
        return self._remote_kind(path) == "file"

    def is_dir(self, path):
        """
        Return True if the remote path is a directory.

        :param path: remote path
        :return: bool
        """
        return self._remote_kind(path) == "dir"

    def download_file(self, remote_path, local_path):
        """
        Download one remote file, creating the local parent directory if needed.

        :param remote_path: remote file path
        :param local_path: local file path (str or Path)
        :return: None
        """
        print(f"正下载文件{remote_path}...")
        with SSHConnectionManager(**self.ssh_args) as ssh:
            # exist_ok: several pool workers may create the same parent at once.
            Path(local_path).parent.mkdir(parents=True, exist_ok=True)
            ssh.download_slowly(remote_path=remote_path, local_path=local_path)

    def download_folder(self, remote_folder, local_folder):
        """
        Download the non-hidden, non-directory entries of a remote folder into
        ``local_folder`` in parallel worker processes.

        :param remote_folder: remote directory
        :param local_folder: local directory; created if missing, otherwise emptied
        :return: None
        :raises: the first exception raised by a download, re-raised after all
            downloads have finished
        """
        import stat

        # List the remote side first, so a failed connection does not wipe the
        # local copy.
        entries = []
        with SSHConnectionManager(**self.ssh_args) as ssh:
            entries = ssh.listdir_attr(remote_folder)
        # Like `ls`, skip dot-files. Also skip sub-directories, which
        # download_file cannot fetch.
        file_names = sorted(
            entry.filename
            for entry in entries
            if not entry.filename.startswith(".") and not stat.S_ISDIR(entry.st_mode or 0)
        )

        dst_folder = Path(local_folder)
        dst_folder.mkdir(parents=True, exist_ok=True)
        clear_dir(str(dst_folder))

        # The listing connection is already closed here, so the pool is not
        # forked while a paramiko transport is open. Each worker connects itself.
        processes = max(1, multiprocessing.cpu_count() // 2)
        with multiprocessing.Pool(processes) as pool:
            pending = [
                pool.apply_async(
                    func=self.download_file,
                    args=(remote_folder + "/" + name, dst_folder.joinpath(name)),
                )
                for name in file_names
            ]
            pool.close()
            pool.join()
        # apply_async keeps worker exceptions in the AsyncResult. get()
        # re-raises them instead of letting failed downloads pass silently.
        for result in pending:
            result.get()
if __name__ == "__main__":
    # Manual smoke test: report whether /hello/ on the demo server is a directory.
    tool = SFTPFileManager_Tool(
        host="1.1.1.1",
        port=5050,
        username="sftpuser",
        pwd="",
        pk_path="/root/.ssh/id_rsa",
    )
    print(tool.is_dir("/hello/"))