java多线程方式导入excel数据入数据库

需求:
最近项目里写了用poi导入excel数据到数据库的功能,代码本身已经比较优化了。虽然领导没有要求我优化导入接口,但是本着钻研技术的态度,想把学到的知识用于实践,于是改用多线程方式导入excel。

所需pom依赖:

		<!-- Apache POI core artifact; presumably supplies the HSSFWorkbook (.xls) class used by the import service - confirm. -->
		<dependency>
			<groupId>org.apache.poi</groupId>
			<artifactId>poi</artifactId>
			<version>3.17</version>
		</dependency>
		<!-- Apache POI OOXML artifact; presumably supplies the XSSFWorkbook (.xlsx) class used by the import service - confirm. -->
		<!-- Both artifacts are pinned to the same version (3.17). -->
		<dependency>
			<groupId>org.apache.poi</groupId>
			<artifactId>poi-ooxml</artifactId>
			<version>3.17</version>
		</dependency>

导入的service实现类:

/**
     * Imports users from an uploaded Excel file using several worker tasks.
     *
     * <p>Only the first sheet is read and row 0 is treated as the header. The remaining
     * rows are split into chunks of 200; each chunk is submitted as a {@link UserThread}
     * to a bounded thread pool. The method blocks until every chunk has finished and
     * then reports how many rows were inserted.
     *
     * @param file the uploaded spreadsheet; its original file name must end in "xls" or "xlsx"
     * @return a map with the keys "code" (200 on completion, 501 for an unsupported file),
     *         "msg" (summary text, only on completion) and "data" (the collected ErrorInfo list)
     * @throws Exception if the workbook cannot be read or any worker task fails
     */
    @Override
    public Map<String,Object> importData(MultipartFile file) throws Exception{
        final Date now = new Date();
        SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        logger.info("{},开始导入数据...", format.format(now));

        // Rows handled by one task, and the upper bound on tasks running at the same time.
        final int rowsPerTask = 200;
        final int maxConcurrency = 10;

        Map<String,Object> map = new HashMap<>();
        // Worker tasks append to this list concurrently, so it has to be a synchronized list.
        List<ErrorInfo> list = Collections.synchronizedList(new ArrayList<ErrorInfo>());

        // getOriginalFilename() may return null; treat that like an unsupported extension.
        String filename = file.getOriginalFilename();
        boolean isXls = filename != null && filename.endsWith("xls");
        boolean isXlsx = filename != null && filename.endsWith("xlsx");
        if(!isXls && !isXlsx){
            ErrorInfo errorInfo = new ErrorInfo();
            errorInfo.setErrorMsg("请上传xls或xlsx格式的文件");
            list.add(errorInfo);
            map.put("code",501);
            map.put("data",list);
            return map;
        }

        int rows = 0;
        int successCount = 0;
        // try-with-resources guarantees the workbook is closed even when a task fails.
        // NOTE(review): all tasks read the same Workbook instance concurrently; POI does not
        // document workbooks as thread-safe - confirm this is acceptable for read-only access.
        try (Workbook workbook = isXlsx
                ? new XSSFWorkbook(file.getInputStream())
                : new HSSFWorkbook(file.getInputStream())) {
            Sheet sheet = workbook.getSheetAt(0);
            int physicalNumberOfRows = sheet.getPhysicalNumberOfRows();
            logger.info("获取到workbook中的总行数:{}" ,physicalNumberOfRows);
            // The first row is the header; clamp to 0 so an empty sheet does not yield -1.
            rows = Math.max(0, physicalNumberOfRows - 1);
            // Ceiling division: exact multiples of rowsPerTask no longer create an empty task.
            // At least one task is kept so the pool size and latch count stay positive.
            int threadNum = Math.max(1, (rows + rowsPerTask - 1) / rowsPerTask);

            // Both objects are required by the UserThread constructor.
            Semaphore semaphore = new Semaphore(maxConcurrency);
            CountDownLatch countDownLatch = new CountDownLatch(threadNum);
            // Existing user names, used by the workers to reject duplicates.
            Set<String> names = this.findAllUser().stream().map(User::getUsername).collect(Collectors.toSet());
            // Bound the pool: one OS thread per 200 rows does not scale to large files.
            ExecutorService executorService = Executors.newFixedThreadPool(Math.min(threadNum, maxConcurrency));

            logger.info("开始创建线程,数据总行数:{},线程数量:{}",rows,threadNum);

            try {
                List<Future<Integer>> futures = new ArrayList<>(threadNum);
                for(int i = 1; i <= threadNum; i++){
                    int startRow = (i-1)*rowsPerTask + 1;
                    int endRow = Math.min(i*rowsPerTask, rows);
                    logger.info("提交导入任务,startRow:{},endRow:{}",startRow,endRow);
                    // Do not call get() here: it blocks and would serialize the tasks.
                    futures.add(executorService.submit(new UserThread(semaphore,workbook, startRow, endRow, list, names,this,countDownLatch)));
                }
                // get() blocks until the task is done, so after this loop every chunk has
                // finished; no additional latch wait is needed.
                for(Future<Integer> future : futures){
                    successCount += future.get();
                }
            } finally {
                // After a normal run nothing is left to cancel; after a failure this stops
                // queued tasks instead of leaking the pool threads.
                executorService.shutdownNow();
            }
        }

        Date endDate = new Date();
        long difference = endDate.getTime() - now.getTime();
        String duration = DurationFormatUtils.formatDuration(difference, "HH:mm:ss");
        logger.info("执行完成,错误信息:{}", JSON.toJSONString(list));
        logger.info("{},结束导入,共{}条数据,导入成功:{},耗时={}", format.format(endDate), rows,successCount,duration);
        map.put("code",200);
        map.put("msg","结束导入,共" + rows + "条数据,导入成功" + successCount + "条,耗时:" +duration);
        map.put("data",list);
        return map;
    }

导入线程类:

package com.thread.demo.thread;

import com.thread.demo.common.ErrorInfo;
import com.thread.demo.entity.User;
import com.thread.demo.service.UserService;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;

/**
 * Worker task that imports one contiguous range of rows from the first sheet of a workbook.
 *
 * <p>For every row in {@code [startRow, endRow]} it reads the user name (column 0) and the
 * real name (column 1), rejects names contained in {@code names}, and passes the remaining
 * users to {@link UserService#addUser}. Problems are appended to the shared error list
 * instead of aborting the chunk.
 *
 * @author Honey
 */
public class UserThread implements Callable<Integer>{

    private static final Logger logger = LoggerFactory.getLogger(UserThread.class);

    private final Workbook workbook;

    /** Index of the first sheet row handled by this task (inclusive). */
    private final int startRow;

    /** Index of the last sheet row handled by this task (inclusive). */
    private final int endRow;

    /** Error list shared with the caller and the other tasks. */
    private final List<ErrorInfo> errorInfoList;

    /** User names that already exist; only read by this task. */
    private final Set<String> names;

    private final UserService userService;

    /** Limits how many tasks do their work at the same time. */
    private final Semaphore semaphore;

    /** Counted down once when this task has finished, successfully or not. */
    private final CountDownLatch latch;

    public UserThread(Semaphore semaphore,Workbook workbook,Integer startRow,Integer endRow,List<ErrorInfo> errorInfoList,Set<String> names,UserService userService,CountDownLatch latch){
        this.workbook = workbook;
        this.startRow = startRow;
        this.endRow = endRow;
        this.errorInfoList = errorInfoList;
        this.names = names;
        this.userService = userService;
        this.semaphore = semaphore;
        this.latch = latch;
    }

    /**
     * Imports the configured row range.
     *
     * @return the sum of the values returned by {@code userService.addUser} for this range
     * @throws Exception if acquiring the semaphore is interrupted or a row cannot be processed
     */
    @Override
    public Integer call() throws Exception {
        logger.info("线程ID:<{}>开始运行,startRow:{},endRow:{}",Thread.currentThread().getId(),startRow,endRow);
        try {
            semaphore.acquire();
            try {
                logger.info("消耗了一个信号量,剩余信号量为:{}",semaphore.availablePermits());
                return importRows();
            } finally {
                // Always give the permit back, otherwise a failing task starves the others.
                semaphore.release();
            }
        } finally {
            // Signal completion only after the work is over, so waiters on the latch
            // really wait for the import and not just for the task start.
            latch.countDown();
        }
    }

    /** Reads every row of the range, records row-level problems and inserts the valid users. */
    private int importRows() {
        Sheet sheet = workbook.getSheetAt(0);
        int count = 0;
        for(int i = startRow; i <= endRow; i++){
            // Reported row numbers are 1-based, matching the 1-based column numbers.
            int displayRow = i + 1;
            Row row = sheet.getRow(i);
            if(row == null){
                addError(displayRow, 1, "第" + displayRow + "行为空");
                continue;
            }
            Cell usernameCell = row.getCell(0);
            if(usernameCell == null){
                addError(displayRow, 1, "第" + displayRow + "行用户账号为空");
                continue;
            }
            Cell realNameCell = row.getCell(1);
            if(realNameCell == null){
                addError(displayRow, 2, "第" + displayRow + "行姓名为空");
                continue;
            }
            String username = usernameCell.getStringCellValue();
            if(names.contains(username)){
                // Report the row that actually failed and keep going with the next row;
                // stopping here would silently drop the rest of the chunk.
                addError(displayRow, 1, "第" + displayRow + "行用户账号已存在");
                continue;
            }
            User user = new User();
            user.setUsername(username);
            // NOTE(review): every imported user gets the same hard-coded password - confirm
            // that addUser hashes it and that a forced password change is in place.
            user.setPassword("123456");
            user.setRealName(realNameCell.getStringCellValue());
            count += userService.addUser(user);
        }
        return count;
    }

    /** Appends one error entry to the shared error list. */
    private void addError(int row, int column, String message) {
        ErrorInfo errorInfo = new ErrorInfo();
        errorInfo.setRow(row);
        errorInfo.setColumn(column);
        errorInfo.setErrorMsg(message);
        errorInfoList.add(errorInfo);
    }
}

controller也贴一下吧。没什么东西

package com.thread.demo.controller;

import com.thread.demo.service.UserService;
import org.apache.poi.ss.usermodel.Workbook;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Map;

/**
 * REST endpoints for importing users from an Excel upload and exporting them to Excel.
 *
 * @author Honey
 */
@RestController
public class UserController {

    private static final Logger logger = LoggerFactory.getLogger(UserController.class);

    @Autowired
    private UserService userService;

    /**
     * Imports the uploaded file with the multi-threaded importer.
     *
     * @param file the uploaded spreadsheet
     * @return the service result, or a map with code 501 when the import throws
     */
    @PostMapping("/importManyThread")
    public Map<String, Object> importData(MultipartFile file){
        try {
            return userService.importData(file);
        } catch (Exception e) {
            logger.error("多线程导入失败", e);
            return errorResult();
        }
    }

    /**
     * Imports the uploaded file with the single-threaded importer.
     *
     * @param file the uploaded spreadsheet
     * @return the service result, or a map with code 501 when the import throws
     */
    @PostMapping("/importSingleThread")
    public Map<String, Object> importData2(MultipartFile file){
        try {
            return userService.importDataYiBan(file);
        } catch (Exception e) {
            logger.error("单线程导入失败", e);
            return errorResult();
        }
    }

    /**
     * Writes the workbook produced by the service to the HTTP response as a download.
     *
     * @param response the servlet response that receives the workbook bytes
     * @throws Exception if building or writing the workbook fails
     */
    @GetMapping("/export")
    public void exportData(HttpServletResponse response) throws Exception{
        // Close the workbook once it has been written so its resources are released.
        try (Workbook workbook = userService.exportData()) {
            // NOTE(review): the download name ends in .xlsx while this is the legacy Excel
            // content type - confirm which workbook format exportData() really returns.
            response.setContentType("application/vnd.ms-excel;charset=utf-8");
            response.setCharacterEncoding("UTF-8");
            // File name shown in the download dialog; it must be ASCII, so encode it
            // first if a non-ASCII name is needed.
            response.setHeader("Content-Disposition", "attachment;filename=user.xlsx");
            workbook.write(response.getOutputStream());
        }
    }

    /**
     * Builds the generic failure response. A new map is created because the service
     * call threw before it could return one.
     */
    private Map<String, Object> errorResult() {
        Map<String, Object> map = new HashMap<>();
        map.put("code",501);
        map.put("msg","数据出错");
        return map;
    }

}

执行结果:
在这里插入图片描述

使用多线程方式导入5000条数据花费时间14秒,而单线程导入则需1分钟14秒。
可见多线程方式运行程序是可以达到空间换时间的目的的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章