diff --git a/magpylib_material_response/demag.py b/magpylib_material_response/demag.py index 4fcade7..e1836bc 100644 --- a/magpylib_material_response/demag.py +++ b/magpylib_material_response/demag.py @@ -31,25 +31,60 @@ logger.configure(**config) -def get_susceptibilities(*sources, susceptibility=None): +def get_susceptibilities(*sources, susceptibility): """Return a list of length (len(sources)) with susceptibility values Priority is given at the source level, hovever if value is not found, it is searched up the parent tree, if available. Raises an error if no value is found when reached the top level of the tree.""" - susceptibilities = [] - for src in sources: - susceptibility = getattr(src, "susceptibility", None) - if susceptibility is None: - if src.parent is None: - raise ValueError("No susceptibility defined in any parent collection") - susceptibilities.extend(get_susceptibilities(src.parent)) - elif not hasattr(susceptibility, "__len__"): - susceptibilities.append((susceptibility, susceptibility, susceptibility)) - elif len(susceptibility) == 3: - susceptibilities.append(susceptibility) - else: - raise ValueError("susceptibility is not scalar or array fo length 3") - return susceptibilities + + # susceptibilities from source attributes + if susceptibility is None: + susceptibilities = [] + for src in sources: + susceptibility = getattr(src, "susceptibility", None) + if susceptibility is None: + if src.parent is None: + raise ValueError("No susceptibility defined in any parent collection") + susceptibilities.extend(get_susceptibilities(src.parent)) + elif not hasattr(susceptibility, "__len__"): + susceptibilities.append((susceptibility, susceptibility, susceptibility)) + elif len(susceptibility) == 3: + susceptibilities.append(susceptibility) + else: + raise ValueError("susceptibility is not scalar or array fo length 3") + return susceptibilities + + # susceptibilities as input to demag function + n = len(sources) + if np.isscalar(susceptibility): + susceptibility = np.ones((n,3))*susceptibility + elif len(susceptibility) == 3: + susceptibility = np.tile(susceptibility, (n,1)) + if n==3: + raise ValueError( + "Apply_demag input susceptibility is ambiguous - either scalar list or vector single entry. " + "Please choose different means of input or change the number of cells in the Collection." + ) + else: + if len(susceptibility) != n: + raise ValueError( + "Apply_demag input susceptibility must be scalar, 3-vector, or same length as input Collection." + ) + susceptibility = np.array(susceptibility) + if susceptibility.ndim == 1: + susceptibility = np.repeat(susceptibility,3).reshape(n,3) + + susceptibility = np.reshape(susceptibility, 3 * n, order="F") + + + + + + + + return np.array(susceptibilities) + + def get_H_ext(*sources, H_ext=None): @@ -364,15 +399,8 @@ def apply_demag( ) # shape ii = x1, ... xn, y1, ... yn, z1, ... zn # set up S - if susceptibility is None: - susceptibility = get_susceptibilities(*magnets_list) - susceptibility = np.array(susceptibility) - if len(susceptibility) != n: - raise ValueError( - "Apply_demag input collection and susceptibility must have same length." - ) - susceptibility = np.reshape(susceptibility, 3 * n, order="F") - S = np.diag(susceptibility) # shape ii, jj + sus = get_susceptibilities(magnets_list, susceptibility) + S = np.diag(sus) # shape ii, jj # set up H_ext H_ext = get_H_ext(*magnets_list)