Table.save(with_labels=False) to properly store ANG

This commit is contained in:
Philip Eisenlohr 2022-03-17 21:23:57 -04:00
parent 1147b5b742
commit 21e076b0f9
2 changed files with 23 additions and 13 deletions

View File

@ -569,7 +569,8 @@ class Table:
def save(self, def save(self,
fname: FileHandle): fname: FileHandle,
with_labels: bool = True):
""" """
Save as plain text file. Save as plain text file.
@ -577,20 +578,23 @@ class Table:
---------- ----------
fname : file, str, or pathlib.Path fname : file, str, or pathlib.Path
Filename or file for writing. Filename or file for writing.
with_labels : bool, optional
Write column labels. Defaults to True.
""" """
labels = [] labels = []
for l in list(dict.fromkeys(self.data.columns)): if with_labels:
if self.shapes[l] == (1,): for l in list(dict.fromkeys(self.data.columns)):
labels.append(f'{l}') if self.shapes[l] == (1,):
elif len(self.shapes[l]) == 1: labels.append(f'{l}')
labels += [f'{i+1}_{l}' \ elif len(self.shapes[l]) == 1:
for i in range(self.shapes[l][0])] labels += [f'{i+1}_{l}' \
else: for i in range(self.shapes[l][0])]
labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \ else:
for i in range(np.prod(self.shapes[l]))] labels += [f'{util.srepr(self.shapes[l],"x")}:{i+1}_{l}' \
for i in range(np.prod(self.shapes[l]))]
f = open(Path(fname).expanduser(),'w',newline='\n') if isinstance(fname, (str, Path)) else fname f = open(Path(fname).expanduser(),'w',newline='\n') if isinstance(fname, (str, Path)) else fname
f.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+'\n') f.write('\n'.join([f'# {c}' for c in self.comments] + [' '.join(labels)])+('\n' if labels else ''))
self.data.to_csv(f,sep=' ',na_rep='nan',index=False,header=False) self.data.to_csv(f,sep=' ',na_rep='nan',index=False,header=False)

View File

@ -77,14 +77,14 @@ class TestTable:
new = Table.load(tmp_path/'default.txt') new = Table.load(tmp_path/'default.txt')
elif mode == 'str': elif mode == 'str':
new = Table.load(str(tmp_path/'default.txt')) new = Table.load(str(tmp_path/'default.txt'))
assert all(default.data==new.data) and default.shapes == new.shapes assert all(default.data == new.data) and default.shapes == new.shapes
def test_write_read_file(self,default,tmp_path): def test_write_read_file(self,default,tmp_path):
with open(tmp_path/'default.txt','w') as f: with open(tmp_path/'default.txt','w') as f:
default.save(f) default.save(f)
with open(tmp_path/'default.txt') as f: with open(tmp_path/'default.txt') as f:
new = Table.load(f) new = Table.load(f)
assert all(default.data==new.data) and default.shapes == new.shapes assert all(default.data == new.data) and default.shapes == new.shapes
def test_write_invalid_format(self,default,tmp_path): def test_write_invalid_format(self,default,tmp_path):
with pytest.raises(TypeError): with pytest.raises(TypeError):
@ -105,6 +105,12 @@ class TestTable:
assert new.data.shape == (4,10) and \ assert new.data.shape == (4,10) and \
new.labels == ['eu', 'pos', 'IQ', 'CI', 'ID', 'intensity', 'fit'] new.labels == ['eu', 'pos', 'IQ', 'CI', 'ID', 'intensity', 'fit']
def test_save_ang(self,ref_path,tmp_path):
orig = Table.load_ang(ref_path/'simple.ang')
orig.save(tmp_path/'simple.ang',with_labels=False)
saved = Table.load_ang(tmp_path/'simple.ang')
assert saved == orig
@pytest.mark.parametrize('fname',['datatype-mix.txt','whitespace-mix.txt']) @pytest.mark.parametrize('fname',['datatype-mix.txt','whitespace-mix.txt'])
def test_read_strange(self,ref_path,fname): def test_read_strange(self,ref_path,fname):
with open(ref_path/fname) as f: with open(ref_path/fname) as f: