diff --git a/python/damask/_vtk.py b/python/damask/_vtk.py index f4855820e..6af4344f3 100644 --- a/python/damask/_vtk.py +++ b/python/damask/_vtk.py @@ -1,3 +1,4 @@ +import multiprocessing as mp from pathlib import Path import pandas as pd @@ -157,8 +158,11 @@ class VTK: return VTK(geom) - - def write(self,fname): + @staticmethod + def _write(writer): + """Wrapper for parallel writing.""" + writer.Write() + def write(self,fname,parallel=True): """ Write to file. @@ -166,6 +170,8 @@ class VTK: ---------- fname : str Filename for writing. + parallel : boolean, optional + Write data in parallel background process. Defaults to True. """ if isinstance(self.geom,vtk.vtkRectilinearGrid): @@ -183,8 +189,11 @@ class VTK: writer.SetCompressorTypeToZLib() writer.SetDataModeToBinary() writer.SetInputData(self.geom) - - writer.Write() + if parallel: + mp_writer = mp.Process(target=self._write,args=(writer,)) + mp_writer.start() + else: + writer.Write() # Check https://blog.kitware.com/ghost-and-blanking-visibility-changes/ for missing data @@ -195,14 +204,21 @@ class VTK: N_cells = self.geom.GetNumberOfCells() if isinstance(data,np.ndarray): - d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True) if label is None: raise ValueError('No label defined for numpy.ndarray') + + if data.dtype in [np.float64, np.float128]: # avoid large files + d = np_to_vtk(num_array=data.astype(np.float32).reshape(data.shape[0],-1),deep=True) + else: + d = np_to_vtk(num_array=data.reshape(data.shape[0],-1),deep=True) d.SetName(label) + if data.shape[0] == N_cells: self.geom.GetCellData().AddArray(d) elif data.shape[0] == N_points: self.geom.GetPointData().AddArray(d) + else: + raise ValueError(f'Invalid shape {data.shape[0]}') elif isinstance(data,pd.DataFrame): raise NotImplementedError('pd.DataFrame') elif isinstance(data,Table): diff --git a/python/tests/test_VTK.py b/python/tests/test_VTK.py index 8795e7161..395102950 100644 --- a/python/tests/test_VTK.py +++ b/python/tests/test_VTK.py @@ -1,4 +1,6 @@ import os +import filecmp +import time import pytest import numpy as np @@ -18,7 +20,7 @@ class TestVTK: origin = np.random.random(3) v = VTK.from_rectilinearGrid(grid,size,origin) string = v.__repr__() - v.write(os.path.join(tmp_path,'rectilinearGrid')) + v.write(os.path.join(tmp_path,'rectilinearGrid'),False) vtr = VTK.from_file(os.path.join(tmp_path,'rectilinearGrid.vtr')) with open(os.path.join(tmp_path,'rectilinearGrid.vtk'),'w') as f: f.write(string) @@ -26,10 +28,10 @@ class TestVTK: assert(string == vtr.__repr__() == vtk.__repr__()) def test_polyData(self,tmp_path): - points = np.random.rand(3,100) + points = np.random.rand(100,3) v = VTK.from_polyData(points) string = v.__repr__() - v.write(os.path.join(tmp_path,'polyData')) + v.write(os.path.join(tmp_path,'polyData'),False) vtp = VTK.from_file(os.path.join(tmp_path,'polyData.vtp')) with open(os.path.join(tmp_path,'polyData.vtk'),'w') as f: f.write(string) @@ -48,14 +50,30 @@ class TestVTK: connectivity = np.random.choice(np.arange(n),n,False).reshape(-1,n) v = VTK.from_unstructuredGrid(nodes,connectivity,cell_type) string = v.__repr__() - v.write(os.path.join(tmp_path,'unstructuredGrid')) + v.write(os.path.join(tmp_path,'unstructuredGrid'),False) vtu = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtu')) with open(os.path.join(tmp_path,'unstructuredGrid.vtk'),'w') as f: f.write(string) vtk = VTK.from_file(os.path.join(tmp_path,'unstructuredGrid.vtk'),'unstructuredgrid') assert(string == vtu.__repr__() == vtk.__repr__()) - @pytest.mark.parametrize('name,dataset_type',[('this_file_does_not_exist.vtk',None), + + def test_parallel_out(self,tmp_path): + points = np.random.rand(102,3) + v = VTK.from_polyData(points) + fname_s = os.path.join(tmp_path,'single.vtp') + fname_p = os.path.join(tmp_path,'parallel.vtp') + v.write(fname_s,False) + v.write(fname_p,True) + for i in range(10): + if os.path.isfile(fname_p) and filecmp.cmp(fname_s,fname_p): + assert(True) + return + time.sleep(.5) + assert(False) + + + @pytest.mark.parametrize('name,dataset_type',[('this_file_does_not_exist.vtk', None), ('this_file_does_not_exist.vtk','vtk'), ('this_file_does_not_exist.vtx', None)]) def test_invalid_dataset_type(self,dataset_type,name):