speed up vtk out

- limit to single precision
- write in background
This commit is contained in:
Martin Diehl 2020-06-26 11:45:54 +02:00
parent bfae88a364
commit a69f82e7c3
2 changed files with 44 additions and 10 deletions

View File

@ -1,3 +1,4 @@
import multiprocessing as mp
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
@ -157,8 +158,11 @@ class VTK:
return VTK(geom) return VTK(geom)
@staticmethod
def write(self,fname): def _write(writer):
"""Wrapper for parallel writing."""
writer.Write()
def write(self,fname,parallel=True):
""" """
Write to file. Write to file.
@ -166,6 +170,8 @@ class VTK:
---------- ----------
fname : str fname : str
Filename for writing. Filename for writing.
parallel : boolean, optional
Write data in parallel background process. Defaults to True.
""" """
if isinstance(self.geom,vtk.vtkRectilinearGrid): if isinstance(self.geom,vtk.vtkRectilinearGrid):
@ -183,7 +189,10 @@ class VTK:
writer.SetCompressorTypeToZLib() writer.SetCompressorTypeToZLib()
writer.SetDataModeToBinary() writer.SetDataModeToBinary()
writer.SetInputData(self.geom) writer.SetInputData(self.geom)
if parallel:
mp_writer = mp.Process(target=self._write,args=(writer,))
mp_writer.start()
else:
writer.Write() writer.Write()
@ -195,14 +204,21 @@ class VTK:
N_cells = self.geom.GetNumberOfCells() N_cells = self.geom.GetNumberOfCells()
if isinstance(data,np.ndarray): if isinstance(data,np.ndarray):
d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True)
if label is None: if label is None:
raise ValueError('No label defined for numpy.ndarray') raise ValueError('No label defined for numpy.ndarray')
if data.dtype in [np.float64, np.float128]: # avoid large files
d = np_to_vtk(num_array=data.astype(np.float32).reshape(data.shape[0],-1),deep=True)
else:
d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True)
d.SetName(label) d.SetName(label)
if data.shape[0] == N_cells: if data.shape[0] == N_cells:
self.geom.GetCellData().AddArray(d) self.geom.GetCellData().AddArray(d)
elif data.shape[0] == N_points: elif data.shape[0] == N_points:
self.geom.GetPointData().AddArray(d) self.geom.GetPointData().AddArray(d)
else:
raise ValueError(f'Invalid shape {data.shape[0]}')
elif isinstance(data,pd.DataFrame): elif isinstance(data,pd.DataFrame):
raise NotImplementedError('pd.DataFrame') raise NotImplementedError('pd.DataFrame')
elif isinstance(data,Table): elif isinstance(data,Table):

View File

@ -1,4 +1,6 @@
import os import os
import filecmp
import time
import pytest import pytest
import numpy as np import numpy as np
@ -18,7 +20,7 @@ class TestVTK:
origin = np.random.random(3) origin = np.random.random(3)
v = VTK.from_rectilinearGrid(grid,size,origin) v = VTK.from_rectilinearGrid(grid,size,origin)
string = v.__repr__() string = v.__repr__()
v.write(os.path.join(tmp_path,'rectilinearGrid')) v.write(os.path.join(tmp_path,'rectilinearGrid'),False)
vtr = VTK.from_file(os.path.join(tmp_path,'rectilinearGrid.vtr')) vtr = VTK.from_file(os.path.join(tmp_path,'rectilinearGrid.vtr'))
with open(os.path.join(tmp_path,'rectilinearGrid.vtk'),'w') as f: with open(os.path.join(tmp_path,'rectilinearGrid.vtk'),'w') as f:
f.write(string) f.write(string)
@ -26,10 +28,10 @@ class TestVTK:
assert(string == vtr.__repr__() == vtk.__repr__()) assert(string == vtr.__repr__() == vtk.__repr__())
def test_polyData(self,tmp_path): def test_polyData(self,tmp_path):
points = np.random.rand(3,100) points = np.random.rand(100,3)
v = VTK.from_polyData(points) v = VTK.from_polyData(points)
string = v.__repr__() string = v.__repr__()
v.write(os.path.join(tmp_path,'polyData')) v.write(os.path.join(tmp_path,'polyData'),False)
vtp = VTK.from_file(os.path.join(tmp_path,'polyData.vtp')) vtp = VTK.from_file(os.path.join(tmp_path,'polyData.vtp'))
with open(os.path.join(tmp_path,'polyData.vtk'),'w') as f: with open(os.path.join(tmp_path,'polyData.vtk'),'w') as f:
f.write(string) f.write(string)
@ -48,14 +50,30 @@ class TestVTK:
connectivity = np.random.choice(np.arange(n),n,False).reshape(-1,n) connectivity = np.random.choice(np.arange(n),n,False).reshape(-1,n)
v = VTK.from_unstructuredGrid(nodes,connectivity,cell_type) v = VTK.from_unstructuredGrid(nodes,connectivity,cell_type)
string = v.__repr__() string = v.__repr__()
v.write(os.path.join(tmp_path,'unstructuredGrid')) v.write(os.path.join(tmp_path,'unstructuredGrid'),False)
vtu = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtu')) vtu = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtu'))
with open(os.path.join(tmp_path,'unstructuredGrid.vtk'),'w') as f: with open(os.path.join(tmp_path,'unstructuredGrid.vtk'),'w') as f:
f.write(string) f.write(string)
vtk = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtk'),'unstructuredgrid') vtk = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtk'),'unstructuredgrid')
assert(string == vtu.__repr__() == vtk.__repr__()) assert(string == vtu.__repr__() == vtk.__repr__())
@pytest.mark.parametrize('name,dataset_type',[('this_file_does_not_exist.vtk',None),
def test_parallel_out(self,tmp_path):
points = np.random.rand(102,3)
v = VTK.from_polyData(points)
fname_s = os.path.join(tmp_path,'single.vtp')
fname_p = os.path.join(tmp_path,'parallel.vtp')
v.write(fname_s,False)
v.write(fname_p,True)
for i in range(10):
if os.path.isfile(fname_p) and filecmp.cmp(fname_s,fname_p):
assert(True)
return
time.sleep(.5)
assert(False)
@pytest.mark.parametrize('name,dataset_type',[('this_file_does_not_exist.vtk', None),
('this_file_does_not_exist.vtk','vtk'), ('this_file_does_not_exist.vtk','vtk'),
('this_file_does_not_exist.vtx', None)]) ('this_file_does_not_exist.vtx', None)])
def test_invalid_dataset_type(self,dataset_type,name): def test_invalid_dataset_type(self,dataset_type,name):