rename: Geom -> Grid
This commit is contained in:
parent
0fdefa5e78
commit
171d642dbd
2
PRIVATE
2
PRIVATE
|
@ -1 +1 @@
|
|||
Subproject commit fc27bbd6e028aa73545327aebdb206840063e135
|
||||
Subproject commit a0475c50bfaf6f86f75345754188918a6e9d7134
|
|
@ -65,7 +65,7 @@ if filenames == []: parser.error('no input file specified.')
|
|||
for name in filenames:
|
||||
damask.util.report(scriptName,name)
|
||||
|
||||
geom = damask.Geom.load_DREAM3D(name,options.basegroup,options.pointwise)
|
||||
geom = damask.Grid.load_DREAM3D(name,options.basegroup,options.pointwise)
|
||||
damask.util.croak(geom)
|
||||
|
||||
geom.save_ASCII(os.path.splitext(name)[0]+'.geom')
|
||||
|
|
|
@ -133,7 +133,7 @@ for i in range(3,np.max(microstructure)):
|
|||
|
||||
header = [scriptID + ' ' + ' '.join(sys.argv[1:])]\
|
||||
+ config_header
|
||||
geom = damask.Geom(microstructure.reshape(grid),
|
||||
geom = damask.Grid(microstructure.reshape(grid),
|
||||
size,-size/2,
|
||||
comments=header)
|
||||
damask.util.croak(geom)
|
||||
|
|
|
@ -62,7 +62,7 @@ if filenames == []: filenames = [None]
|
|||
for name in filenames:
|
||||
damask.util.report(scriptName,name)
|
||||
|
||||
geom = damask.Geom.load(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
geom = damask.Grid.load(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
|
||||
grid_original = geom.cells
|
||||
damask.util.croak(geom)
|
||||
|
@ -169,7 +169,7 @@ for name in filenames:
|
|||
# undo any changes involving immutable materials
|
||||
material = np.where(immutable, material_original,material)
|
||||
|
||||
damask.Geom(material = material[0:grid_original[0],0:grid_original[1],0:grid_original[2]],
|
||||
damask.Grid(material = material[0:grid_original[0],0:grid_original[1],0:grid_original[2]],
|
||||
size = geom.size,
|
||||
origin = geom.origin,
|
||||
comments = geom.comments + [scriptID + ' ' + ' '.join(sys.argv[1:])],
|
||||
|
|
|
@ -196,7 +196,7 @@ if filenames == []: filenames = [None]
|
|||
for name in filenames:
|
||||
damask.util.report(scriptName,name)
|
||||
|
||||
geom = damask.Geom.load(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
geom = damask.Grid.load(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
material = geom.material.flatten(order='F')
|
||||
|
||||
cmds = [\
|
||||
|
|
|
@ -91,7 +91,7 @@ class myThread (threading.Thread):
|
|||
perturbedSeedsTable.set('pos',coords).save(perturbedSeedsVFile,legacy=True)
|
||||
|
||||
#--- do tesselation with perturbed seed file ------------------------------------------------------
|
||||
perturbedGeom = damask.Geom.from_Voronoi_tessellation(options.grid,np.ones(3),coords)
|
||||
perturbedGeom = damask.Grid.from_Voronoi_tessellation(options.grid,np.ones(3),coords)
|
||||
|
||||
|
||||
#--- evaluate current seeds file ------------------------------------------------------------------
|
||||
|
@ -210,7 +210,7 @@ baseFile = os.path.splitext(os.path.basename(options.seedFile))[0]
|
|||
points = np.array(options.grid).prod().astype('float')
|
||||
|
||||
# ----------- calculate target distribution and bin edges
|
||||
targetGeom = damask.Geom.load_ASCII(os.path.splitext(os.path.basename(options.target))[0]+'.geom')
|
||||
targetGeom = damask.Grid.load_ASCII(os.path.splitext(os.path.basename(options.target))[0]+'.geom')
|
||||
nMaterials = len(np.unique(targetGeom.material))
|
||||
targetVolFrac = np.bincount(targetGeom.material.flatten())/targetGeom.cells.prod().astype(np.float)
|
||||
target = []
|
||||
|
@ -229,7 +229,7 @@ bestSeedsUpdate = time.time()
|
|||
|
||||
# ----------- tessellate initial seed file to get and evaluate geom file
|
||||
bestSeedsVFile.seek(0)
|
||||
initialGeom = damask.Geom.from_Voronoi_tessellation(options.grid,np.ones(3),initial_seeds)
|
||||
initialGeom = damask.Grid.from_Voronoi_tessellation(options.grid,np.ones(3),initial_seeds)
|
||||
|
||||
if len(np.unique(targetGeom.material)) != nMaterials:
|
||||
damask.util.croak('error. Material count mismatch')
|
||||
|
|
|
@ -52,7 +52,7 @@ options.box = np.array(options.box).reshape(3,2)
|
|||
|
||||
for name in filenames:
|
||||
damask.util.report(scriptName,name)
|
||||
geom = damask.Geom.load_ASCII(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
geom = damask.Grid.load_ASCII(StringIO(''.join(sys.stdin.read())) if name is None else name)
|
||||
|
||||
offset =(np.amin(options.box, axis=1)*geom.cells/geom.size).astype(int)
|
||||
box = np.amax(options.box, axis=1) \
|
||||
|
|
|
@ -32,7 +32,7 @@ from ._vtk import VTK # noqa
|
|||
from ._colormap import Colormap # noqa
|
||||
from ._config import Config # noqa
|
||||
from ._configmaterial import ConfigMaterial # noqa
|
||||
from ._geom import Geom # noqa
|
||||
from ._grid import Grid # noqa
|
||||
from ._result import Result # noqa
|
||||
|
||||
|
||||
|
|
|
@ -16,21 +16,21 @@ from . import grid_filters
|
|||
from . import Rotation
|
||||
|
||||
|
||||
class Geom:
|
||||
class Grid:
|
||||
"""Geometry definition for grid solvers."""
|
||||
|
||||
def __init__(self,material,size,origin=[0.0,0.0,0.0],comments=[]):
|
||||
"""
|
||||
New geometry definition from array of materials, size, and origin.
|
||||
New grid definition from array of materials, size, and origin.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
material : numpy.ndarray
|
||||
Material index array (3D).
|
||||
size : list or numpy.ndarray
|
||||
Physical size of the geometry in meter.
|
||||
Physical size of the grid in meter.
|
||||
origin : list or numpy.ndarray, optional
|
||||
Physical origin of the geometry in meter.
|
||||
Physical origin of the grid in meter.
|
||||
comments : list of str, optional
|
||||
Comment lines.
|
||||
|
||||
|
@ -42,7 +42,7 @@ class Geom:
|
|||
|
||||
|
||||
def __repr__(self):
|
||||
"""Basic information on geometry definition."""
|
||||
"""Basic information on grid definition."""
|
||||
return util.srepr([
|
||||
f'cells a b c: {util.srepr(self.cells, " x ")}',
|
||||
f'size x y z: {util.srepr(self.size, " x ")}',
|
||||
|
@ -53,12 +53,12 @@ class Geom:
|
|||
|
||||
|
||||
def __copy__(self):
|
||||
"""Copy geometry."""
|
||||
"""Copy grid."""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def copy(self):
|
||||
"""Copy geometry."""
|
||||
"""Copy grid."""
|
||||
return self.__copy__()
|
||||
|
||||
|
||||
|
@ -68,8 +68,8 @@ class Geom:
|
|||
|
||||
Parameters
|
||||
----------
|
||||
other : Geom
|
||||
Geometry to compare self against.
|
||||
other : damask.Grid
|
||||
Grid to compare self against.
|
||||
|
||||
"""
|
||||
message = []
|
||||
|
@ -117,7 +117,7 @@ class Geom:
|
|||
|
||||
@property
|
||||
def size(self):
|
||||
"""Physical size of geometry in meter."""
|
||||
"""Physical size of grid in meter."""
|
||||
return self._size
|
||||
|
||||
@size.setter
|
||||
|
@ -129,7 +129,7 @@ class Geom:
|
|||
|
||||
@property
|
||||
def origin(self):
|
||||
"""Coordinates of geometry origin in meter."""
|
||||
"""Coordinates of grid origin in meter."""
|
||||
return self._origin
|
||||
|
||||
@origin.setter
|
||||
|
@ -141,7 +141,7 @@ class Geom:
|
|||
|
||||
@property
|
||||
def comments(self):
|
||||
"""Comments/history of geometry."""
|
||||
"""Comments, e.g. history of operations."""
|
||||
return self._comments
|
||||
|
||||
@comments.setter
|
||||
|
@ -157,7 +157,7 @@ class Geom:
|
|||
|
||||
@property
|
||||
def N_materials(self):
|
||||
"""Number of (unique) material indices within geometry."""
|
||||
"""Number of (unique) material indices within grid."""
|
||||
return np.unique(self.material).size
|
||||
|
||||
|
||||
|
@ -169,8 +169,8 @@ class Geom:
|
|||
Parameters
|
||||
----------
|
||||
fname : str or or pathlib.Path
|
||||
Geometry file to read.
|
||||
Valid extension is .vtr, which will be appended if not given.
|
||||
Grid file to read. Valid extension is .vtr, which will be appended
|
||||
if not given.
|
||||
|
||||
"""
|
||||
v = VTK.load(fname if str(fname).endswith('.vtr') else str(fname)+'.vtr')
|
||||
|
@ -178,7 +178,7 @@ class Geom:
|
|||
cells = np.array(v.vtk_data.GetDimensions())-1
|
||||
bbox = np.array(v.vtk_data.GetBounds()).reshape(3,2).T
|
||||
|
||||
return Geom(material = v.get('material').reshape(cells,order='F'),
|
||||
return Grid(material = v.get('material').reshape(cells,order='F'),
|
||||
size = bbox[1] - bbox[0],
|
||||
origin = bbox[0],
|
||||
comments=comments)
|
||||
|
@ -248,7 +248,7 @@ class Geom:
|
|||
if not np.any(np.mod(material,1) != 0.0): # no float present
|
||||
material = material.astype('int') - (1 if material.min() > 0 else 0)
|
||||
|
||||
return Geom(material.reshape(cells,order='F'),size,origin,comments)
|
||||
return Grid(material.reshape(cells,order='F'),size,origin,comments)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@ -282,7 +282,7 @@ class Geom:
|
|||
if point_data is None else \
|
||||
np.reshape(f[path.join(root_dir,base_group,point_data,material)],cells.prod())
|
||||
|
||||
return Geom(ma.reshape(cells,order='F'),size,origin,util.execution_stamp('Geom','load_DREAM3D'))
|
||||
return Grid(ma.reshape(cells,order='F'),size,origin,util.execution_stamp('Grid','load_DREAM3D'))
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@ -310,7 +310,7 @@ class Geom:
|
|||
ma = np.arange(cells.prod()) if len(unique) == cells.prod() else \
|
||||
np.arange(unique.size)[np.argsort(pd.unique(unique_inverse))][unique_inverse]
|
||||
|
||||
return Geom(ma.reshape(cells,order='F'),size,origin,util.execution_stamp('Geom','from_table'))
|
||||
return Grid(ma.reshape(cells,order='F'),size,origin,util.execution_stamp('Grid','from_table'))
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@ -327,7 +327,7 @@ class Geom:
|
|||
cells : int numpy.ndarray of shape (3)
|
||||
Number of cells in x,y,z direction.
|
||||
size : list or numpy.ndarray of shape (3)
|
||||
Physical size of the geometry in meter.
|
||||
Physical size of the grid in meter.
|
||||
seeds : numpy.ndarray of shape (:,3)
|
||||
Position of the seed points in meter. All points need to lay within the box.
|
||||
weights : numpy.ndarray of shape (seeds.shape[0])
|
||||
|
@ -336,7 +336,7 @@ class Geom:
|
|||
Material ID of the seeds.
|
||||
Defaults to None, in which case materials are consecutively numbered.
|
||||
periodic : Boolean, optional
|
||||
Perform a periodic tessellation. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
if periodic:
|
||||
|
@ -351,7 +351,7 @@ class Geom:
|
|||
coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3)
|
||||
|
||||
pool = mp.Pool(processes = int(environment.options['DAMASK_NUM_THREADS']))
|
||||
result = pool.map_async(partial(Geom._find_closest_seed,seeds_p,weights_p), [coord for coord in coords])
|
||||
result = pool.map_async(partial(Grid._find_closest_seed,seeds_p,weights_p), [coord for coord in coords])
|
||||
pool.close()
|
||||
pool.join()
|
||||
material_ = np.array(result.get())
|
||||
|
@ -362,9 +362,9 @@ class Geom:
|
|||
else:
|
||||
material_ = material_.reshape(cells)
|
||||
|
||||
return Geom(material = material_ if material is None else material[material_],
|
||||
return Grid(material = material_ if material is None else material[material_],
|
||||
size = size,
|
||||
comments = util.execution_stamp('Geom','from_Laguerre_tessellation'),
|
||||
comments = util.execution_stamp('Grid','from_Laguerre_tessellation'),
|
||||
)
|
||||
|
||||
|
||||
|
@ -378,23 +378,23 @@ class Geom:
|
|||
cells : int numpy.ndarray of shape (3)
|
||||
Number of cells in x,y,z direction.
|
||||
size : list or numpy.ndarray of shape (3)
|
||||
Physical size of the geometry in meter.
|
||||
Physical size of the grid in meter.
|
||||
seeds : numpy.ndarray of shape (:,3)
|
||||
Position of the seed points in meter. All points need to lay within the box.
|
||||
material : numpy.ndarray of shape (seeds.shape[0]), optional
|
||||
Material ID of the seeds.
|
||||
Defaults to None, in which case materials are consecutively numbered.
|
||||
periodic : Boolean, optional
|
||||
Perform a periodic tessellation. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3)
|
||||
KDTree = spatial.cKDTree(seeds,boxsize=size) if periodic else spatial.cKDTree(seeds)
|
||||
devNull,material_ = KDTree.query(coords)
|
||||
|
||||
return Geom(material = (material_ if material is None else material[material_]).reshape(cells),
|
||||
return Grid(material = (material_ if material is None else material[material_]).reshape(cells),
|
||||
size = size,
|
||||
comments = util.execution_stamp('Geom','from_Voronoi_tessellation'),
|
||||
comments = util.execution_stamp('Grid','from_Voronoi_tessellation'),
|
||||
)
|
||||
|
||||
|
||||
|
@ -450,7 +450,7 @@ class Geom:
|
|||
cells : int numpy.ndarray of shape (3)
|
||||
Number of cells in x,y,z direction.
|
||||
size : list or numpy.ndarray of shape (3)
|
||||
Physical size of the geometry in meter.
|
||||
Physical size of the grid in meter.
|
||||
surface : str
|
||||
Type of the minimal surface. See notes for details.
|
||||
threshold : float, optional.
|
||||
|
@ -497,9 +497,9 @@ class Geom:
|
|||
periods*2.0*np.pi*(np.arange(cells[1])+0.5)/cells[1],
|
||||
periods*2.0*np.pi*(np.arange(cells[2])+0.5)/cells[2],
|
||||
indexing='ij',sparse=True)
|
||||
return Geom(material = np.where(threshold < Geom._minimal_surface[surface](x,y,z),materials[1],materials[0]),
|
||||
return Grid(material = np.where(threshold < Grid._minimal_surface[surface](x,y,z),materials[1],materials[0]),
|
||||
size = size,
|
||||
comments = util.execution_stamp('Geom','from_minimal_surface'),
|
||||
comments = util.execution_stamp('Grid','from_minimal_surface'),
|
||||
)
|
||||
|
||||
|
||||
|
@ -583,7 +583,7 @@ class Geom:
|
|||
Retain original materials within primitive and fill outside.
|
||||
Defaults to False.
|
||||
periodic : Boolean, optional
|
||||
Repeat primitive over boundaries. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
# radius and center
|
||||
|
@ -604,23 +604,23 @@ class Geom:
|
|||
if periodic: # translate back to center
|
||||
mask = np.roll(mask,((c/self.size-0.5)*self.cells).round().astype(int),(0,1,2))
|
||||
|
||||
return Geom(material = np.where(np.logical_not(mask) if inverse else mask,
|
||||
return Grid(material = np.where(np.logical_not(mask) if inverse else mask,
|
||||
self.material,
|
||||
np.nanmax(self.material)+1 if fill is None else fill),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','add_primitive')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','add_primitive')],
|
||||
)
|
||||
|
||||
|
||||
def mirror(self,directions,reflect=False):
|
||||
"""
|
||||
Mirror geometry along given directions.
|
||||
Mirror grid along given directions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
directions : iterable containing str
|
||||
Direction(s) along which the geometry is mirrored.
|
||||
Direction(s) along which the grid is mirrored.
|
||||
Valid entries are 'x', 'y', 'z'.
|
||||
reflect : bool, optional
|
||||
Reflect (include) outermost layers. Defaults to False.
|
||||
|
@ -640,21 +640,21 @@ class Geom:
|
|||
if 'z' in directions:
|
||||
mat = np.concatenate([mat,mat[:,:,limits[0]:limits[1]:-1]],2)
|
||||
|
||||
return Geom(material = mat,
|
||||
return Grid(material = mat,
|
||||
size = self.size/self.cells*np.asarray(mat.shape),
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','mirror')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','mirror')],
|
||||
)
|
||||
|
||||
|
||||
def flip(self,directions):
|
||||
"""
|
||||
Flip geometry along given directions.
|
||||
Flip grid along given directions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
directions : iterable containing str
|
||||
Direction(s) along which the geometry is flipped.
|
||||
Direction(s) along which the grid is flipped.
|
||||
Valid entries are 'x', 'y', 'z'.
|
||||
|
||||
"""
|
||||
|
@ -664,26 +664,26 @@ class Geom:
|
|||
|
||||
mat = np.flip(self.material, (valid.index(d) for d in directions if d in valid))
|
||||
|
||||
return Geom(material = mat,
|
||||
return Grid(material = mat,
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','flip')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','flip')],
|
||||
)
|
||||
|
||||
|
||||
def scale(self,cells,periodic=True):
|
||||
"""
|
||||
Scale geometry to new cells.
|
||||
Scale grid to new cells.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cells : numpy.ndarray of shape (3)
|
||||
Number of cells in x,y,z direction.
|
||||
periodic : Boolean, optional
|
||||
Assume geometry to be periodic. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
return Geom(material = ndimage.interpolation.zoom(
|
||||
return Grid(material = ndimage.interpolation.zoom(
|
||||
self.material,
|
||||
cells/self.cells,
|
||||
output=self.material.dtype,
|
||||
|
@ -693,13 +693,13 @@ class Geom:
|
|||
),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','scale')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','scale')],
|
||||
)
|
||||
|
||||
|
||||
def clean(self,stencil=3,selection=None,periodic=True):
|
||||
"""
|
||||
Smooth geometry by selecting most frequent material index within given stencil at each location.
|
||||
Smooth grid by selecting most frequent material index within given stencil at each location.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -708,7 +708,7 @@ class Geom:
|
|||
selection : list, optional
|
||||
Field values that can be altered. Defaults to all.
|
||||
periodic : Boolean, optional
|
||||
Assume geometry to be periodic. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
def mostFrequent(arr,selection=None):
|
||||
|
@ -719,7 +719,7 @@ class Geom:
|
|||
else:
|
||||
return me
|
||||
|
||||
return Geom(material = ndimage.filters.generic_filter(
|
||||
return Grid(material = ndimage.filters.generic_filter(
|
||||
self.material,
|
||||
mostFrequent,
|
||||
size=(stencil if selection is None else stencil//2*2+1,)*3,
|
||||
|
@ -728,7 +728,7 @@ class Geom:
|
|||
).astype(self.material.dtype),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','clean')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','clean')],
|
||||
)
|
||||
|
||||
|
||||
|
@ -736,21 +736,21 @@ class Geom:
|
|||
"""Renumber sorted material indices as 0,...,N-1."""
|
||||
_,renumbered = np.unique(self.material,return_inverse=True)
|
||||
|
||||
return Geom(material = renumbered.reshape(self.cells),
|
||||
return Grid(material = renumbered.reshape(self.cells),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','renumber')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','renumber')],
|
||||
)
|
||||
|
||||
|
||||
def rotate(self,R,fill=None):
|
||||
"""
|
||||
Rotate geometry (pad if required).
|
||||
Rotate grid (pad if required).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
R : damask.Rotation
|
||||
Rotation to apply to the geometry.
|
||||
Rotation to apply to the grid.
|
||||
fill : int or float, optional
|
||||
Material index to fill the corners. Defaults to material.max() + 1.
|
||||
|
||||
|
@ -774,23 +774,23 @@ class Geom:
|
|||
|
||||
origin = self.origin-(np.asarray(material_in.shape)-self.cells)*.5 * self.size/self.cells
|
||||
|
||||
return Geom(material = material_in,
|
||||
return Grid(material = material_in,
|
||||
size = self.size/self.cells*np.asarray(material_in.shape),
|
||||
origin = origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','rotate')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','rotate')],
|
||||
)
|
||||
|
||||
|
||||
def canvas(self,cells=None,offset=None,fill=None):
|
||||
"""
|
||||
Crop or enlarge/pad geometry.
|
||||
Crop or enlarge/pad grid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cells : numpy.ndarray of shape (3)
|
||||
Number of cells x,y,z direction.
|
||||
offset : numpy.ndarray of shape (3)
|
||||
Offset (measured in cells) from old to new geometry [0,0,0].
|
||||
Offset (measured in cells) from old to new grid [0,0,0].
|
||||
fill : int or float, optional
|
||||
Material index to fill the background. Defaults to material.max() + 1.
|
||||
|
||||
|
@ -808,10 +808,10 @@ class Geom:
|
|||
|
||||
canvas[ll[0]:ur[0],ll[1]:ur[1],ll[2]:ur[2]] = self.material[LL[0]:UR[0],LL[1]:UR[1],LL[2]:UR[2]]
|
||||
|
||||
return Geom(material = canvas,
|
||||
return Grid(material = canvas,
|
||||
size = self.size/self.cells*np.asarray(canvas.shape),
|
||||
origin = self.origin+offset*self.size/self.cells,
|
||||
comments = self.comments+[util.execution_stamp('Geom','canvas')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','canvas')],
|
||||
)
|
||||
|
||||
|
||||
|
@ -833,10 +833,10 @@ class Geom:
|
|||
mp = np.vectorize(mp)
|
||||
mapper = dict(zip(from_material,to_material))
|
||||
|
||||
return Geom(material = mp(self.material,mapper).reshape(self.cells),
|
||||
return Grid(material = mp(self.material,mapper).reshape(self.cells),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','substitute')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','substitute')],
|
||||
)
|
||||
|
||||
|
||||
|
@ -847,10 +847,10 @@ class Geom:
|
|||
sort_idx = np.argsort(from_ma)
|
||||
ma = np.unique(a)[sort_idx][np.searchsorted(from_ma,a,sorter = sort_idx)]
|
||||
|
||||
return Geom(material = ma.reshape(self.cells,order='F'),
|
||||
return Grid(material = ma.reshape(self.cells,order='F'),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','sort')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','sort')],
|
||||
)
|
||||
|
||||
|
||||
|
@ -860,7 +860,7 @@ class Geom:
|
|||
|
||||
Different from themselves (or listed as triggers) within a given (cubic) vicinity,
|
||||
i.e. within the region close to a grain/phase boundary.
|
||||
ToDo: use include/exclude as in seeds.from_geom
|
||||
ToDo: use include/exclude as in seeds.from_grid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -874,7 +874,7 @@ class Geom:
|
|||
List of material indices that trigger a change.
|
||||
Defaults to [], meaning that any different neighbor triggers a change.
|
||||
periodic : Boolean, optional
|
||||
Assume geometry to be periodic. Defaults to True.
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
|
||||
"""
|
||||
def tainted_neighborhood(stencil,trigger):
|
||||
|
@ -891,10 +891,10 @@ class Geom:
|
|||
mode='wrap' if periodic else 'nearest',
|
||||
extra_keywords={'trigger':trigger})
|
||||
|
||||
return Geom(material = np.where(mask, self.material + offset_,self.material),
|
||||
return Grid(material = np.where(mask, self.material + offset_,self.material),
|
||||
size = self.size,
|
||||
origin = self.origin,
|
||||
comments = self.comments+[util.execution_stamp('Geom','vicinity_offset')],
|
||||
comments = self.comments+[util.execution_stamp('Grid','vicinity_offset')],
|
||||
)
|
||||
|
||||
|
||||
|
@ -904,10 +904,10 @@ class Geom:
|
|||
|
||||
Parameters
|
||||
----------
|
||||
periodic : bool, optional
|
||||
Show boundaries across periodicity. Defaults to True.
|
||||
periodic : Boolean, optional
|
||||
Assume grid to be periodic. Defaults to True.
|
||||
directions : iterable containing str, optional
|
||||
Direction(s) along which the geometry is mirrored.
|
||||
Direction(s) along which the boundaries are determined.
|
||||
Valid entries are 'x', 'y', 'z'. Defaults to 'xyz'.
|
||||
|
||||
"""
|
|
@ -4,7 +4,7 @@ Filters for operations on regular grids.
|
|||
Notes
|
||||
-----
|
||||
The grids are defined as (x,y,z,...) where x is fastest and z is slowest.
|
||||
This convention is consistent with the geom file format.
|
||||
This convention is consistent with the layout in grid vtr files.
|
||||
When converting to/from a plain list (e.g. storage in ASCII table),
|
||||
the following operations are required for tensorial data:
|
||||
|
||||
|
|
|
@ -77,14 +77,14 @@ def from_Poisson_disc(size,N_seeds,N_candidates,distance,periodic=True,rng_seed=
|
|||
return coords
|
||||
|
||||
|
||||
def from_geom(geom,selection=None,invert=False,average=False,periodic=True):
|
||||
def from_grid(grid,selection=None,invert=False,average=False,periodic=True):
|
||||
"""
|
||||
Create seed from existing geometry description.
|
||||
Create seed from existing grid description.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : damask.Geom
|
||||
Geometry, from which the material IDs are used as seeds.
|
||||
grid : damask.Grid
|
||||
Grid, from which the material IDs are used as seeds.
|
||||
selection : iterable of integers, optional
|
||||
Material IDs to consider.
|
||||
invert : boolean, false
|
||||
|
@ -95,10 +95,10 @@ def from_geom(geom,selection=None,invert=False,average=False,periodic=True):
|
|||
Center of gravity with periodic boundaries.
|
||||
|
||||
"""
|
||||
material = geom.material.reshape((-1,1),order='F')
|
||||
mask = _np.full(geom.cells.prod(),True,dtype=bool) if selection is None else \
|
||||
material = grid.material.reshape((-1,1),order='F')
|
||||
mask = _np.full(grid.cells.prod(),True,dtype=bool) if selection is None else \
|
||||
_np.isin(material,selection,invert=invert).flatten()
|
||||
coords = grid_filters.coordinates0_point(geom.cells,geom.size).reshape(-1,3,order='F')
|
||||
coords = grid_filters.coordinates0_point(grid.cells,grid.size).reshape(-1,3,order='F')
|
||||
|
||||
if not average:
|
||||
return (coords[mask],material[mask])
|
||||
|
@ -106,8 +106,8 @@ def from_geom(geom,selection=None,invert=False,average=False,periodic=True):
|
|||
materials = _np.unique(material[mask])
|
||||
coords_ = _np.zeros((materials.size,3),dtype=float)
|
||||
for i,mat in enumerate(materials):
|
||||
pc = (2*_np.pi*coords[material[:,0]==mat,:]-geom.origin)/geom.size
|
||||
coords_[i] = geom.origin + geom.size / 2 / _np.pi * (_np.pi +
|
||||
pc = (2*_np.pi*coords[material[:,0]==mat,:]-grid.origin)/grid.size
|
||||
coords_[i] = grid.origin + grid.size / 2 / _np.pi * (_np.pi +
|
||||
_np.arctan2(-_np.average(_np.sin(pc),axis=0),
|
||||
-_np.average(_np.cos(pc),axis=0))) \
|
||||
if periodic else \
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
import numpy as np
|
||||
|
||||
from damask import VTK
|
||||
from damask import Geom
|
||||
from damask import Grid
|
||||
from damask import Table
|
||||
from damask import Rotation
|
||||
from damask import util
|
||||
|
@ -10,7 +10,7 @@ from damask import seeds
|
|||
from damask import grid_filters
|
||||
|
||||
|
||||
def geom_equal(a,b):
|
||||
def grid_equal(a,b):
|
||||
return np.all(a.material == b.material) and \
|
||||
np.all(a.cells == b.cells) and \
|
||||
np.allclose(a.size, b.size) and \
|
||||
|
@ -23,15 +23,15 @@ def default():
|
|||
np.arange(2,42),
|
||||
np.ones(40,dtype=int)*2,
|
||||
np.arange(1,41))).reshape(8,5,4,order='F')
|
||||
return Geom(x,[8e-6,5e-6,4e-6])
|
||||
return Grid(x,[8e-6,5e-6,4e-6])
|
||||
|
||||
@pytest.fixture
|
||||
def ref_path(ref_path_base):
|
||||
"""Directory containing reference results."""
|
||||
return ref_path_base/'Geom'
|
||||
return ref_path_base/'Grid'
|
||||
|
||||
|
||||
class TestGeom:
|
||||
class TestGrid:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_execution_stamp(self, patch_execution_stamp):
|
||||
|
@ -46,7 +46,7 @@ class TestGeom:
|
|||
|
||||
|
||||
def test_diff_not_equal(self,default):
|
||||
new = Geom(default.material[1:,1:,1:]+1,default.size*.9,np.ones(3)-default.origin,comments=['modified'])
|
||||
new = Grid(default.material[1:,1:,1:]+1,default.size*.9,np.ones(3)-default.origin,comments=['modified'])
|
||||
assert str(default.diff(new)) != ''
|
||||
|
||||
def test_repr(self,default):
|
||||
|
@ -54,36 +54,36 @@ class TestGeom:
|
|||
|
||||
def test_read_write_vtr(self,default,tmp_path):
|
||||
default.save(tmp_path/'default')
|
||||
new = Geom.load(tmp_path/'default.vtr')
|
||||
assert geom_equal(new,default)
|
||||
new = Grid.load(tmp_path/'default.vtr')
|
||||
assert grid_equal(new,default)
|
||||
|
||||
def test_invalid_vtr(self,tmp_path):
|
||||
v = VTK.from_rectilinear_grid(np.random.randint(5,10,3)*2,np.random.random(3) + 1.0)
|
||||
v.save(tmp_path/'no_materialpoint.vtr',parallel=False)
|
||||
with pytest.raises(ValueError):
|
||||
Geom.load(tmp_path/'no_materialpoint.vtr')
|
||||
Grid.load(tmp_path/'no_materialpoint.vtr')
|
||||
|
||||
def test_invalid_material(self):
|
||||
with pytest.raises(TypeError):
|
||||
Geom(np.zeros((3,3,3),dtype='complex'),np.ones(3))
|
||||
Grid(np.zeros((3,3,3),dtype='complex'),np.ones(3))
|
||||
|
||||
def test_cast_to_int(self):
|
||||
g = Geom(np.zeros((3,3,3)),np.ones(3))
|
||||
g = Grid(np.zeros((3,3,3)),np.ones(3))
|
||||
assert g.material.dtype in np.sctypes['int']
|
||||
|
||||
def test_invalid_size(self,default):
|
||||
with pytest.raises(ValueError):
|
||||
Geom(default.material[1:,1:,1:],
|
||||
Grid(default.material[1:,1:,1:],
|
||||
size=np.ones(2))
|
||||
|
||||
def test_save_load_ASCII(self,default,tmp_path):
|
||||
default.save_ASCII(tmp_path/'ASCII')
|
||||
default.material -= 1
|
||||
assert geom_equal(Geom.load_ASCII(tmp_path/'ASCII'),default)
|
||||
assert grid_equal(Grid.load_ASCII(tmp_path/'ASCII'),default)
|
||||
|
||||
def test_invalid_origin(self,default):
|
||||
with pytest.raises(ValueError):
|
||||
Geom(default.material[1:,1:,1:],
|
||||
Grid(default.material[1:,1:,1:],
|
||||
size=np.ones(3),
|
||||
origin=np.ones(4))
|
||||
|
||||
|
@ -91,14 +91,14 @@ class TestGeom:
|
|||
def test_invalid_materials_shape(self,default):
|
||||
material = np.ones((3,3))
|
||||
with pytest.raises(ValueError):
|
||||
Geom(material,
|
||||
Grid(material,
|
||||
size=np.ones(3))
|
||||
|
||||
|
||||
def test_invalid_materials_type(self,default):
|
||||
material = np.random.randint(1,300,(3,4,5))==1
|
||||
with pytest.raises(TypeError):
|
||||
Geom(material)
|
||||
Grid(material)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('directions,reflect',[
|
||||
|
@ -113,7 +113,7 @@ class TestGeom:
|
|||
tag = f'directions_{"-".join(directions)}+reflect_{reflect}'
|
||||
reference = ref_path/f'mirror_{tag}.vtr'
|
||||
if update: modified.save(reference)
|
||||
assert geom_equal(Geom.load(reference),
|
||||
assert grid_equal(Grid.load(reference),
|
||||
modified)
|
||||
|
||||
|
||||
|
@ -135,17 +135,17 @@ class TestGeom:
|
|||
tag = f'directions_{"-".join(directions)}'
|
||||
reference = ref_path/f'flip_{tag}.vtr'
|
||||
if update: modified.save(reference)
|
||||
assert geom_equal(Geom.load(reference),
|
||||
assert grid_equal(Grid.load(reference),
|
||||
modified)
|
||||
|
||||
|
||||
def test_flip_invariant(self,default):
|
||||
assert geom_equal(default,default.flip([]))
|
||||
assert grid_equal(default,default.flip([]))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('direction',[['x'],['x','y']])
|
||||
def test_flip_double(self,default,direction):
|
||||
assert geom_equal(default,default.flip(direction).flip(direction))
|
||||
assert grid_equal(default,default.flip(direction).flip(direction))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('directions',[(1,2,'y'),('a','b','x'),[1]])
|
||||
|
@ -162,7 +162,7 @@ class TestGeom:
|
|||
reference = ref_path/f'clean_{stencil}_{"+".join(map(str,[None] if selection is None else selection))}_{periodic}'
|
||||
if update and stencil > 1:
|
||||
current.save(reference)
|
||||
assert geom_equal(Geom.load(reference) if stencil > 1 else default,
|
||||
assert grid_equal(Grid.load(reference) if stencil > 1 else default,
|
||||
current
|
||||
)
|
||||
|
||||
|
@ -181,7 +181,7 @@ class TestGeom:
|
|||
tag = f'grid_{util.srepr(cells,"-")}'
|
||||
reference = ref_path/f'scale_{tag}.vtr'
|
||||
if update: modified.save(reference)
|
||||
assert geom_equal(Geom.load(reference),
|
||||
assert grid_equal(Grid.load(reference),
|
||||
modified)
|
||||
|
||||
|
||||
|
@ -190,21 +190,21 @@ class TestGeom:
|
|||
for m in np.unique(material):
|
||||
material[material==m] = material.max() + np.random.randint(1,30)
|
||||
default.material -= 1
|
||||
modified = Geom(material,
|
||||
modified = Grid(material,
|
||||
default.size,
|
||||
default.origin)
|
||||
assert not geom_equal(modified,default)
|
||||
assert geom_equal(default,
|
||||
assert not grid_equal(modified,default)
|
||||
assert grid_equal(default,
|
||||
modified.renumber())
|
||||
|
||||
|
||||
def test_substitute(self,default):
|
||||
offset = np.random.randint(1,500)
|
||||
modified = Geom(default.material + offset,
|
||||
modified = Grid(default.material + offset,
|
||||
default.size,
|
||||
default.origin)
|
||||
assert not geom_equal(modified,default)
|
||||
assert geom_equal(default,
|
||||
assert not grid_equal(modified,default)
|
||||
assert grid_equal(default,
|
||||
modified.substitute(np.arange(default.material.max())+1+offset,
|
||||
np.arange(default.material.max())+1))
|
||||
|
||||
|
@ -212,12 +212,12 @@ class TestGeom:
|
|||
f = np.unique(default.material.flatten())[:np.random.randint(1,default.material.max())]
|
||||
t = np.random.permutation(f)
|
||||
modified = default.substitute(f,t)
|
||||
assert np.array_equiv(t,f) or (not geom_equal(modified,default))
|
||||
assert geom_equal(default, modified.substitute(t,f))
|
||||
assert np.array_equiv(t,f) or (not grid_equal(modified,default))
|
||||
assert grid_equal(default, modified.substitute(t,f))
|
||||
|
||||
def test_sort(self):
|
||||
cells = np.random.randint(5,20,3)
|
||||
m = Geom(np.random.randint(1,20,cells)*3,np.ones(3)).sort().material.flatten(order='F')
|
||||
m = Grid(np.random.randint(1,20,cells)*3,np.ones(3)).sort().material.flatten(order='F')
|
||||
for i,v in enumerate(m):
|
||||
assert i==0 or v > m[:i].max() or v in m[:i]
|
||||
|
||||
|
@ -227,7 +227,7 @@ class TestGeom:
|
|||
modified = default.copy()
|
||||
for i in range(np.rint(360/axis_angle[3]).astype(int)):
|
||||
modified.rotate(Rotation.from_axis_angle(axis_angle,degrees=True))
|
||||
assert geom_equal(default,modified)
|
||||
assert grid_equal(default,modified)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Eulers',[[32.0,68.0,21.0],
|
||||
|
@ -237,7 +237,7 @@ class TestGeom:
|
|||
tag = f'Eulers_{util.srepr(Eulers,"-")}'
|
||||
reference = ref_path/f'rotate_{tag}.vtr'
|
||||
if update: modified.save(reference)
|
||||
assert geom_equal(Geom.load(reference),
|
||||
assert grid_equal(Grid.load(reference),
|
||||
modified)
|
||||
|
||||
|
||||
|
@ -263,8 +263,8 @@ class TestGeom:
|
|||
o = np.random.random(3)-.5
|
||||
g = np.random.randint(8,32,(3))
|
||||
s = np.random.random(3)+.5
|
||||
G_1 = Geom(np.ones(g,'i'),s,o).add_primitive(diameter,center1,exponent)
|
||||
G_2 = Geom(np.ones(g,'i'),s,o).add_primitive(diameter,center2,exponent)
|
||||
G_1 = Grid(np.ones(g,'i'),s,o).add_primitive(diameter,center1,exponent)
|
||||
G_2 = Grid(np.ones(g,'i'),s,o).add_primitive(diameter,center2,exponent)
|
||||
assert np.count_nonzero(G_1.material!=2) == np.count_nonzero(G_2.material!=2)
|
||||
|
||||
|
||||
|
@ -279,9 +279,9 @@ class TestGeom:
|
|||
g = np.random.randint(8,32,(3))
|
||||
s = np.random.random(3)+.5
|
||||
fill = np.random.randint(10)+2
|
||||
G_1 = Geom(np.ones(g,'i'),s).add_primitive(.3,center,1,fill,inverse=inverse,periodic=periodic)
|
||||
G_2 = Geom(np.ones(g,'i'),s).add_primitive(.3,center,1,fill,Rotation.from_random(),inverse,periodic=periodic)
|
||||
assert geom_equal(G_1,G_2)
|
||||
G_1 = Grid(np.ones(g,'i'),s).add_primitive(.3,center,1,fill,inverse=inverse,periodic=periodic)
|
||||
G_2 = Grid(np.ones(g,'i'),s).add_primitive(.3,center,1,fill,Rotation.from_random(),inverse,periodic=periodic)
|
||||
assert grid_equal(G_1,G_2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('trigger',[[1],[]])
|
||||
|
@ -300,9 +300,9 @@ class TestGeom:
|
|||
if len(trigger) > 0:
|
||||
m2[m==1] = 1
|
||||
|
||||
geom = Geom(m,np.random.rand(3)).vicinity_offset(vicinity,offset,trigger=trigger)
|
||||
grid = Grid(m,np.random.rand(3)).vicinity_offset(vicinity,offset,trigger=trigger)
|
||||
|
||||
assert np.all(m2==geom.material)
|
||||
assert np.all(m2==grid.material)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('periodic',[True,False])
|
||||
|
@ -318,9 +318,9 @@ class TestGeom:
|
|||
size = np.random.random(3) + 1.0
|
||||
N_seeds= np.random.randint(10,30)
|
||||
seeds = np.random.rand(N_seeds,3) * np.broadcast_to(size,(N_seeds,3))
|
||||
Voronoi = Geom.from_Voronoi_tessellation( cells,size,seeds, np.arange(N_seeds)+5,periodic)
|
||||
Laguerre = Geom.from_Laguerre_tessellation(cells,size,seeds,np.ones(N_seeds),np.arange(N_seeds)+5,periodic)
|
||||
assert geom_equal(Laguerre,Voronoi)
|
||||
Voronoi = Grid.from_Voronoi_tessellation( cells,size,seeds, np.arange(N_seeds)+5,periodic)
|
||||
Laguerre = Grid.from_Laguerre_tessellation(cells,size,seeds,np.ones(N_seeds),np.arange(N_seeds)+5,periodic)
|
||||
assert grid_equal(Laguerre,Voronoi)
|
||||
|
||||
|
||||
def test_Laguerre_weights(self):
|
||||
|
@ -331,7 +331,7 @@ class TestGeom:
|
|||
weights= np.full((N_seeds),-np.inf)
|
||||
ms = np.random.randint(N_seeds)
|
||||
weights[ms] = np.random.random()
|
||||
Laguerre = Geom.from_Laguerre_tessellation(cells,size,seeds,weights,periodic=np.random.random()>0.5)
|
||||
Laguerre = Grid.from_Laguerre_tessellation(cells,size,seeds,weights,periodic=np.random.random()>0.5)
|
||||
assert np.all(Laguerre.material == ms)
|
||||
|
||||
|
||||
|
@ -343,10 +343,10 @@ class TestGeom:
|
|||
material = np.zeros(cells)
|
||||
material[:,cells[1]//2:,:] = 1
|
||||
if approach == 'Laguerre':
|
||||
geom = Geom.from_Laguerre_tessellation(cells,size,seeds,np.ones(2),periodic=np.random.random()>0.5)
|
||||
grid = Grid.from_Laguerre_tessellation(cells,size,seeds,np.ones(2),periodic=np.random.random()>0.5)
|
||||
elif approach == 'Voronoi':
|
||||
geom = Geom.from_Voronoi_tessellation(cells,size,seeds, periodic=np.random.random()>0.5)
|
||||
assert np.all(geom.material == material)
|
||||
grid = Grid.from_Voronoi_tessellation(cells,size,seeds, periodic=np.random.random()>0.5)
|
||||
assert np.all(grid.material == material)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('surface',['Schwarz P',
|
||||
|
@ -368,9 +368,9 @@ class TestGeom:
|
|||
threshold = 2*np.random.rand()-1.
|
||||
periods = np.random.randint(2)+1
|
||||
materials = np.random.randint(0,40,2)
|
||||
geom = Geom.from_minimal_surface(cells,size,surface,threshold,periods,materials)
|
||||
assert set(geom.material.flatten()) | set(materials) == set(materials) \
|
||||
and (geom.size == size).all() and (geom.cells == cells).all()
|
||||
grid = Grid.from_minimal_surface(cells,size,surface,threshold,periods,materials)
|
||||
assert set(grid.material.flatten()) | set(materials) == set(materials) \
|
||||
and (grid.size == size).all() and (grid.cells == cells).all()
|
||||
|
||||
@pytest.mark.parametrize('surface,threshold',[('Schwarz P',0),
|
||||
('Double Primitive',-1./6.),
|
||||
|
@ -387,8 +387,8 @@ class TestGeom:
|
|||
])
|
||||
def test_minimal_surface_volume(self,surface,threshold):
|
||||
cells = np.ones(3,dtype=int)*64
|
||||
geom = Geom.from_minimal_surface(cells,np.ones(3),surface,threshold)
|
||||
assert np.isclose(np.count_nonzero(geom.material==1)/np.prod(geom.cells),.5,rtol=1e-3)
|
||||
grid = Grid.from_minimal_surface(cells,np.ones(3),surface,threshold)
|
||||
assert np.isclose(np.count_nonzero(grid.material==1)/np.prod(grid.cells),.5,rtol=1e-3)
|
||||
|
||||
|
||||
def test_from_table(self):
|
||||
|
@ -398,7 +398,7 @@ class TestGeom:
|
|||
z=np.ones(cells.prod())
|
||||
z[cells[:2].prod()*int(cells[2]/2):]=0
|
||||
t = Table(np.column_stack((coords,z)),{'coords':3,'z':1})
|
||||
g = Geom.from_table(t,'coords',['1_coords','z'])
|
||||
g = Grid.from_table(t,'coords',['1_coords','z'])
|
||||
assert g.N_materials == g.cells[0]*2 and (g.material[:,:,-1]-g.material[:,:,0] == cells[0]).all()
|
||||
|
||||
|
||||
|
@ -406,16 +406,16 @@ class TestGeom:
|
|||
cells = np.random.randint(60,100,3)
|
||||
size = np.ones(3)+np.random.rand(3)
|
||||
s = seeds.from_random(size,np.random.randint(60,100))
|
||||
geom = Geom.from_Voronoi_tessellation(cells,size,s)
|
||||
grid = Grid.from_Voronoi_tessellation(cells,size,s)
|
||||
coords = grid_filters.coordinates0_point(cells,size)
|
||||
t = Table(np.column_stack((coords.reshape(-1,3,order='F'),geom.material.flatten(order='F'))),{'c':3,'m':1})
|
||||
assert geom_equal(geom.sort().renumber(),Geom.from_table(t,'c',['m']))
|
||||
t = Table(np.column_stack((coords.reshape(-1,3,order='F'),grid.material.flatten(order='F'))),{'c':3,'m':1})
|
||||
assert grid_equal(grid.sort().renumber(),Grid.from_table(t,'c',['m']))
|
||||
|
||||
@pytest.mark.parametrize('periodic',[True,False])
|
||||
@pytest.mark.parametrize('direction',['x','y','z',['x','y'],'zy','xz',['x','y','z']])
|
||||
def test_get_grain_boundaries(self,update,ref_path,periodic,direction):
|
||||
geom=Geom.load(ref_path/'get_grain_boundaries_8g12x15x20.vtr')
|
||||
current=geom.get_grain_boundaries(periodic,direction)
|
||||
grid=Grid.load(ref_path/'get_grain_boundaries_8g12x15x20.vtr')
|
||||
current=grid.get_grain_boundaries(periodic,direction)
|
||||
if update:
|
||||
current.save(ref_path/f'get_grain_boundaries_8g12x15x20_{direction}_{periodic}.vtu',parallel=False)
|
||||
reference=VTK.load(ref_path/f'get_grain_boundaries_8g12x15x20_{"".join(direction)}_{periodic}.vtu')
|
|
@ -4,7 +4,7 @@ from scipy.spatial import cKDTree
|
|||
|
||||
from damask import seeds
|
||||
from damask import grid_filters
|
||||
from damask import Geom
|
||||
from damask import Grid
|
||||
|
||||
class TestSeeds:
|
||||
|
||||
|
@ -26,37 +26,37 @@ class TestSeeds:
|
|||
cKDTree(coords).query(coords, 2)
|
||||
assert (0<= coords).all() and (coords<size).all() and np.min(min_dists[:,1])>=distance
|
||||
|
||||
def test_from_geom_reconstruct(self):
|
||||
def test_from_grid_reconstruct(self):
|
||||
cells = np.random.randint(10,20,3)
|
||||
N_seeds = np.random.randint(30,300)
|
||||
size = np.ones(3) + np.random.random(3)
|
||||
coords = seeds.from_random(size,N_seeds,cells)
|
||||
geom_1 = Geom.from_Voronoi_tessellation(cells,size,coords)
|
||||
coords,material = seeds.from_geom(geom_1)
|
||||
geom_2 = Geom.from_Voronoi_tessellation(cells,size,coords,material)
|
||||
assert (geom_2.material==geom_1.material).all()
|
||||
grid_1 = Grid.from_Voronoi_tessellation(cells,size,coords)
|
||||
coords,material = seeds.from_grid(grid_1)
|
||||
grid_2 = Grid.from_Voronoi_tessellation(cells,size,coords,material)
|
||||
assert (grid_2.material==grid_1.material).all()
|
||||
|
||||
@pytest.mark.parametrize('periodic',[True,False])
|
||||
@pytest.mark.parametrize('average',[True,False])
|
||||
def test_from_geom_grid(self,periodic,average):
|
||||
def test_from_grid_grid(self,periodic,average):
|
||||
cells = np.random.randint(10,20,3)
|
||||
size = np.ones(3) + np.random.random(3)
|
||||
coords = grid_filters.coordinates0_point(cells,size).reshape(-1,3)
|
||||
np.random.shuffle(coords)
|
||||
geom_1 = Geom.from_Voronoi_tessellation(cells,size,coords)
|
||||
coords,material = seeds.from_geom(geom_1,average=average,periodic=periodic)
|
||||
geom_2 = Geom.from_Voronoi_tessellation(cells,size,coords,material)
|
||||
assert (geom_2.material==geom_1.material).all()
|
||||
grid_1 = Grid.from_Voronoi_tessellation(cells,size,coords)
|
||||
coords,material = seeds.from_grid(grid_1,average=average,periodic=periodic)
|
||||
grid_2 = Grid.from_Voronoi_tessellation(cells,size,coords,material)
|
||||
assert (grid_2.material==grid_1.material).all()
|
||||
|
||||
@pytest.mark.parametrize('periodic',[True,False])
|
||||
@pytest.mark.parametrize('average',[True,False])
|
||||
@pytest.mark.parametrize('invert',[True,False])
|
||||
def test_from_geom_selection(self,periodic,average,invert):
|
||||
def test_from_grid_selection(self,periodic,average,invert):
|
||||
cells = np.random.randint(10,20,3)
|
||||
N_seeds = np.random.randint(30,300)
|
||||
size = np.ones(3) + np.random.random(3)
|
||||
coords = seeds.from_random(size,N_seeds,cells)
|
||||
geom = Geom.from_Voronoi_tessellation(cells,size,coords)
|
||||
grid = Grid.from_Voronoi_tessellation(cells,size,coords)
|
||||
selection=np.random.randint(N_seeds)+1
|
||||
coords,material = seeds.from_geom(geom,average=average,periodic=periodic,invert=invert,selection=[selection])
|
||||
coords,material = seeds.from_grid(grid,average=average,periodic=periodic,invert=invert,selection=[selection])
|
||||
assert selection not in material if invert else (selection==material).all()
|
||||
|
|
Loading…
Reference in New Issue