diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/base/DetectorDescriptor.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/base/DetectorDescriptor.java index 81d12615d..058916d69 100644 --- a/common-tools/clas-detector/src/main/java/org/jlab/detector/base/DetectorDescriptor.java +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/base/DetectorDescriptor.java @@ -64,6 +64,14 @@ public void setOrder(int order){ } } + public int[] getCSC() { + return new int[]{hw_CRATE,hw_SLOT,hw_CHANNEL}; + } + + public int[] getSLCO() { + return new int[]{dt_SECTOR,dt_LAYER,dt_COMPONENT,dt_ORDER}; + } + public DetectorType getType(){ return this.detectorType;} public final void setType(DetectorType type){ diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/decode/CLASDecoder4.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/decode/CLASDecoder4.java index 289c29c59..17fee2548 100644 --- a/common-tools/clas-detector/src/main/java/org/jlab/detector/decode/CLASDecoder4.java +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/decode/CLASDecoder4.java @@ -14,6 +14,7 @@ import org.jlab.detector.helicity.HelicityBit; import org.jlab.detector.helicity.HelicitySequence; import org.jlab.detector.helicity.HelicityState; +import org.jlab.detector.pulse.ModeAHDC; import org.jlab.logging.DefaultLogger; @@ -46,7 +47,8 @@ public class CLASDecoder4 { private HipoDataEvent hipoEvent = null; private boolean isRunNumberFixed = false; private int decoderDebugMode = 0; - private SchemaFactory schemaFactory = new SchemaFactory(); + private SchemaFactory schemaFactory = new SchemaFactory(); + private ModeAHDC ahdcExtractor = new ModeAHDC(); public CLASDecoder4(boolean development){ codaDecoder = new CodaEventDecoder(); @@ -245,6 +247,26 @@ public List getEntriesSCALER(DetectorType type, return scaler; } + public void extractPulses(Event event) { + ahdcExtractor.update(6, null, event, schemaFactory, "AHDC::wf", "AHDC::adc"); + } + + public Bank getDataBankWF(String name, DetectorType type) { + List a = this.getEntriesADC(type); + Bank b = new Bank(schemaFactory.getSchema(name), a.size()); + for (int i=0; i adcDGTZ = this.getEntriesADC(type); @@ -420,20 +442,20 @@ public Event getDataEvent(){ Event event = new Event(); + String[] wfBankNames = new String[]{"AHDC::wf"}; + DetectorType[] wfBankTypes = new DetectorType[]{DetectorType.AHDC}; String[] adcBankNames = new String[]{"FTOF::adc","ECAL::adc","FTCAL::adc", "FTHODO::adc", "FTTRK::adc", "HTCC::adc","BST::adc","CTOF::adc", "CND::adc","LTCC::adc","BMT::adc", "FMT::adc","HEL::adc","RF::adc", - "BAND::adc","RASTER::adc", - "AHDC::adc"}; + "BAND::adc","RASTER::adc"}; DetectorType[] adcBankTypes = new DetectorType[]{DetectorType.FTOF,DetectorType.ECAL,DetectorType.FTCAL, DetectorType.FTHODO,DetectorType.FTTRK, DetectorType.HTCC,DetectorType.BST,DetectorType.CTOF, DetectorType.CND,DetectorType.LTCC,DetectorType.BMT, DetectorType.FMT,DetectorType.HEL,DetectorType.RF, - DetectorType.BAND, DetectorType.RASTER, - DetectorType.AHDC}; + DetectorType.BAND, DetectorType.RASTER}; String[] tdcBankNames = new String[]{"FTOF::tdc","ECAL::tdc","DC::tdc", "HTCC::tdc","LTCC::tdc","CTOF::tdc", @@ -453,6 +475,13 @@ public Event getDataEvent(){ } } + for(int i = 0; i < wfBankTypes.length; i++){ + Bank wfBank = getDataBankWF(wfBankNames[i],wfBankTypes[i]); + if(wfBank!=null && wfBank.getRows()>0){ + event.write(wfBank); + } + } + for(int i = 0; i < tdcBankTypes.length; i++){ Bank tdcBank = getDataBankTDC(tdcBankNames[i],tdcBankTypes[i]); if(tdcBank!=null){ @@ -825,7 +854,9 @@ public static void main(String[] args){ decodedEvent.read(rawScaler); decodedEvent.read(rawRunConf); decodedEvent.read(helicityAdc); - + + decoder.extractPulses(decodedEvent); + helicityReadings.add(HelicityState.createFromFadcBank(helicityAdc, rawRunConf, decoder.detectorDecoder.scalerManager)); @@ -867,4 +898,5 @@ public static void main(String[] args){ writer.close(); } + } diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/HipoExtractor.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/HipoExtractor.java new file mode 100644 index 000000000..fba8012e4 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/HipoExtractor.java @@ -0,0 +1,160 @@ + +package org.jlab.detector.pulse; + +import java.util.ArrayList; +import java.util.List; +import org.jlab.io.base.DataBank; +import org.jlab.io.base.DataEvent; +import org.jlab.jnp.hipo4.data.Bank; +import org.jlab.jnp.hipo4.data.Event; +import org.jlab.jnp.hipo4.data.SchemaFactory; +import org.jlab.utils.groups.IndexedTable; + +/** + * For now, a place to store standard boilerplate for waveform/pulse HIPO + * manipulations. No bounds checking regarding number of samples. + * + * Here an IndexedTable object from CCDB is used to pass initialization parameters + * to the extractor. If that object is null, the @{link org.jlab.detector.pulse.IExtractor.extract} + * method should know what to do, e.g., hardcoded, channel-independent parameters. + * + * FIXME: Passing the #samples around is obviously bad, and there's probably a + * few non-horrible ways that can be addressed without changing bank format. + * + * @author baltzell + */ +public abstract class HipoExtractor implements IExtractor { + + /** + * @param n number of samples in readout + * @param it CCDB table containing extraction initialization parameters + * @param event the event to modify + * @param schema bank schema factory + * @param wfBankName name of the input waveform bank + * @param adcBankName name of the output ADC bank + */ + public void update(int n, IndexedTable it, Event event, SchemaFactory schema, String wfBankName, String adcBankName) { + Bank wf = new Bank(schema.getSchema(wfBankName)); + event.read(wf); + if (wf.getRows() > 0) { + Bank adc = new Bank(schema.getSchema(adcBankName)); + update(n, it, wf, adc); + event.remove(schema.getSchema(adcBankName)); + if (adc.getRows() > 0) event.write(adc); + } + } + + /** + * This could be overriden, e.g., for non-standard ADC banks. + * @param n number of samples in readout + * @param it CCDB table containing extraction initialization parameters + * @param event the event to modify + * @param wfBankName name of the input waveform bank + * @param adcBankName name of the output ADC bank + */ + public void update(int n, IndexedTable it, DataEvent event, String wfBankName, String adcBankName) { + DataBank wf = event.getBank(wfBankName); + if (wf.rows() > 0) { + event.removeBank(adcBankName); + List pulses = getPulses(n, it, wf); + if (pulses != null && !pulses.isEmpty()) { + DataBank adc = event.createBank(adcBankName, pulses.size()); + for (int i=0; i 0) { + List pulses = getPulses(n, it, wfBank); + adcBank.reset(); + adcBank.setRows(pulses!=null ? pulses.size() : 0); + if (pulses!=null && !pulses.isEmpty()) { + for (int i=0; i getPulses(int n, IndexedTable it, DataBank wfBank) { + List pulses = null; + short[] samples = new short[n]; + for (int i=0; i p = it==null ? extract(null, i, samples) : + extract(it.getNamedEntry(getIndices(wfBank,i)), i, samples); + if (p!=null && !p.isEmpty()) { + if (pulses == null) pulses = new ArrayList<>(); + pulses.addAll(p); + } + } + return pulses; + } + + private List getPulses(int n, IndexedTable it, Bank wfBank) { + List pulses = null; + short[] samples = new short[n]; + for (int i=0; i(); + pulses.addAll(p); + } + } + return pulses; + } + +} \ No newline at end of file diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/IExtractor.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/IExtractor.java new file mode 100644 index 000000000..bfa765d84 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/IExtractor.java @@ -0,0 +1,10 @@ +package org.jlab.detector.pulse; + +import java.util.List; +import org.jlab.utils.groups.NamedEntry; + +public interface IExtractor { + + public List extract(NamedEntry pars, int id, short... samples); + +} diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode3.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode3.java new file mode 100644 index 000000000..bf4049f60 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode3.java @@ -0,0 +1,65 @@ +package org.jlab.detector.pulse; + +import java.util.ArrayList; +import java.util.List; +import org.jlab.utils.groups.NamedEntry; + +/** + * Similar to a FADC250 Mode-3 pulse extraction. + * + * @author baltzell + */ +public class Mode3 extends HipoExtractor { + + // Fixed extraction parameters: + final double ped = 2000; + final double tet = 2000; + final int nsa = 30; + final int nsb = 5; + + /** + * @param pars CCDB row + * @param id link to row in source bank + * @param samples ADC samples + * @return extracted pulses + */ + @Override + public List extract(NamedEntry pars, int id, short... samples) { + + List pulses = null; + + /* + // Retrive extraction parameters from a CCDB table: + double ped = pars.getValue("ped").doubleValue(); + double tet = pars.getValue("tet").doubleValue(); + int nsa = pars.getValue("nsa").intValue(); + int nsb = pars.getValue("nsb").intValue(); + */ + + // Perform the extraction: + for (int i=0; i ped+tet && samples[i+1] > samples[i]) { + int n = 0; + float integral = 0; + // Integrate the pulse: + for (int j=i-nsb; j<=i+nsa; ++j) { + if (j<0) continue; + if (j>=samples.length) break; + integral += samples[j]; + n++; + } + integral -= n * ped; + Pulse p = new Pulse(integral, i, 0x0, id); + p.pedestal = (float)(ped); + // Add the new pulse to the list: + if (pulses == null) pulses = new ArrayList<>(); + pulses.add(p); + // Add a holdoff time before next possible pulse: + i += nsa; + } + } + return pulses; + } + +} diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode7.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode7.java new file mode 100644 index 000000000..b7f58f2c7 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Mode7.java @@ -0,0 +1,39 @@ +package org.jlab.detector.pulse; + +import java.util.List; +import org.jlab.utils.groups.NamedEntry; + +/** + * Similar to a Mode-7 FADC250 pulse extraction. + * + * @author baltzell + */ +public class Mode7 extends Mode3 { + + /** + * @param t0 threshold-crossing sample index + * @param ped pedestal (for calculating pulse half-height) + * @param samples + * @return pulse time + */ + private static float calculateTime(int t0, float ped, short... samples) { + for (int j=t0+1; j extract(NamedEntry pars, int id, short... samples) { + List pulses = super.extract(pars, id, samples); + for (Pulse p : pulses) + p.time = calculateTime((int)p.time, p.pedestal, samples); + return pulses; + } + +} diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/ModeAHDC.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/ModeAHDC.java new file mode 100644 index 000000000..2d4aa03c0 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/ModeAHDC.java @@ -0,0 +1,257 @@ +package org.jlab.detector.pulse; + +import java.util.List; +import java.util.ArrayList; + +import net.jcip.annotations.GuardedBy; +import org.jlab.utils.groups.NamedEntry; + + +/** + * A new extraction method dedicated to the AHDC signal waveform + * + * Some blocks of code are inspired by MVTFitter.java + * + * @author ftouchte + */ +public class ModeAHDC extends HipoExtractor { + + public static final short ADC_LIMIT = 4095; // 2^12-1 + /** + * This method extracts relevant informations from the digitized signal + * (the samples) and store them in a Pulse + * + * @param pars CCDB row + * @param id link to row in source bank + * @param samples ADC samples + */ + @Override + public List extract(NamedEntry pars, int id, short... samples){ + // Settings parameters (they can be initialised by a CCDB) + float samplingTime = 0; + int sparseSample = 0; + short adcOffset = 0; + long timeStamp = 0; + float fineTimeStampResolution = 0; + + float amplitudeFractionCFA = 0; + int binDelayCFD = 0; + float fractionCFD = 0; + + // Calculation intermediaries + int binMax = 0; //Bin of the max ADC over the pulse + int binOffset = 0; //Offset due to sparse sample + float adcMax = 0; //Max value of ADC over the pulse (fitted) + float timeMax =0; //Time of the max ADC over the pulse (fitted) + float integral = 0; //Sum of ADCs over the pulse (not fitted) + long timestamp = 0; + + short[] samplesCorr; //Waveform after offset (pedestal) correction + int binNumber = 0; //Number of bins in one waveform + + float timeRiseCFA = 0; // moment when the signal reaches a Constant Fraction of its Amplitude uphill (fitted) + float timeFallCFA = 0; // moment when the signal reaches a Constant Fraction of its Amplitude downhill (fitted) + float timeOverThresholdCFA = 0; // is equal to (timeFallCFA - timeRiseCFA) + float timeCFD =0 ; // time extracted using the Constant Fraction Discriminator (CFD) algorithm (fitted) + /// ///////////////////////// + // Begin waveform correction + /// //////////////////////// + //waveformCorrection(samples,adcOffset,samplingTime,sparseSample, binMax, adcMax, integral, samplesCorr[], binOffset, timeMax); + /** + * This method subtracts the pedestal (noise) from samples and stores it in : samplesCorr + * It also computes a first value for : adcMax, binMax, timeMax and integral + * This code is inspired by the one of MVTFitter.java + * @param samples ADC samples + * @param adcOffset pedestal or noise level + * @param samplingTime time between two adc bins + * @param sparseSample used to define binOffset + */ + //private void waveformCorrection(short[] samples, short adcOffset, float samplingTime, int sparseSample, int binMax, int adcMax, int integral, short samplesCorr[], int binOffset, int timeMax){ + binNumber = samples.length; + binMax = 0; + adcMax = (short) (samples[0] - adcOffset); + integral = 0; + samplesCorr = new short[binNumber]; + for (int bin = 0; bin < binNumber; bin++){ + samplesCorr[bin] = (short) (samples[bin] - adcOffset); + if (adcMax < samplesCorr[bin]){ + adcMax = samplesCorr[bin]; + binMax = bin; + } + integral += samplesCorr[bin]; + } + /* + * If adcMax + adcOffset == ADC_LIMIT, that means there is saturation + * In that case, binMax is the middle of the first plateau + * This convention can be changed + */ + if ((short) adcMax + adcOffset == ADC_LIMIT) { + int binMax2 = binMax; + for (int bin = binMax; bin < binNumber; bin++){ + if (samplesCorr[bin] + adcOffset == ADC_LIMIT) { + binMax2 = bin; + } + else { + break; + } + } + binMax = (binMax + binMax2)/2; + } + binOffset = sparseSample*binMax; + timeMax = (binMax + binOffset)*samplingTime; + //} + + /// ///////////////////////// + // Begin fit average + /// //////////////////////// + + //fitAverage(samplingTime); + /** + * This method gives a more precise value of the max of the waveform by computing the average of five points around the binMax + * It is an alternative to fitParabolic() + * The suitability of one of these fits can be the subject of a study + * Remark : This method updates adcMax but doesn't change timeMax + * @param samplingTime time between 2 ADC bins + */ + //private void fitAverage(float samplingTime){ + if ((binMax - 2 >= 0) && (binMax + 2 <= binNumber - 1)){ + adcMax = 0; + for (int bin = binMax - 2; bin <= binMax + 2; bin++){ + adcMax += samplesCorr[bin]; + } + adcMax = adcMax/5; + } + //} + + /// ///////////////////////// + // Begin computeTimeAtConstantFractionAmplitude + /// //////////////////////// + //computeTimeAtConstantFractionAmplitude(samplingTime,amplitudeFractionCFA); + /** + * This method determines the moment when the signal reaches a Constant Fraction of its Amplitude (i.e fraction*adcMax) + * It fills the attributs : timeRiseCFA, timeFallCFA, timeOverThresholdCFA + * @param samplingTime time between 2 ADC bins + * @param amplitudeFraction amplitude fraction between 0 and 1 + */ + //private void computeTimeAtConstantFractionAmplitude(float samplingTime, float amplitudeFractionCFA){ + float threshold = amplitudeFractionCFA*adcMax; + // timeRiseCFA + int binRise = 0; + for (int bin = 0; bin < binMax; bin++){ + if (samplesCorr[bin] < threshold) + binRise = bin; // last pass below threshold and before adcMax + } // at this stage : binRise < timeRiseCFA/samplingTime <= binRise + 1 // timeRiseCFA is determined by assuming a linear fit between binRise and binRise + 1 + float slopeRise = 0; + if (binRise + 1 <= binNumber-1) + slopeRise = samplesCorr[binRise+1] - samplesCorr[binRise]; + float fittedBinRise = (slopeRise == 0) ? binRise : binRise + (threshold - samplesCorr[binRise])/slopeRise; + timeRiseCFA = (fittedBinRise + binOffset)*samplingTime; // binOffset is determined in wavefromCorrection() // must be the same for all time ? // or must be defined using fittedBinRise*sparseSample + + // timeFallCFA + int binFall = binMax; + for (int bin = binMax; bin < binNumber; bin++){ + if (samplesCorr[bin] > threshold){ + binFall = bin; + } + else { + binFall = bin; + break; // first pass below the threshold + } + } // at this stage : binFall - 1 <= timeRiseCFA/samplingTime < binFall // timeFallCFA is determined by assuming a linear fit between binFall - 1 and binFall + float slopeFall = 0; + if (binFall - 1 >= 0) + slopeFall = samplesCorr[binFall] - samplesCorr[binFall-1]; + float fittedBinFall = (slopeFall == 0) ? binFall : binFall-1 + (threshold - samplesCorr[binFall-1])/slopeFall; + timeFallCFA = (fittedBinFall + binOffset)*samplingTime; + + // timeOverThreshold + timeOverThresholdCFA = timeFallCFA - timeRiseCFA; + //} + /// ///////////////////////// + // Begin computeTimeUsingConstantFractionDiscriminator + /// //////////////////////// + //computeTimeUsingConstantFractionDiscriminator(samplingTime,fractionCFD,binDelayCFD); + /** + * This methods extracts a time using the Constant Fraction Discriminator (CFD) algorithm + * It fills the attribut : timeCFD + * @param samplingTime time between 2 ADC bins + * @param fractionCFD CFD fraction parameter between 0 and 1 + * @param binDelayCFD CFD delay parameter + */ + //private void computeTimeUsingConstantFractionDiscriminator(float samplingTime, float fractionCFD, int binDelayCFD){ + float[] signal = new float[binNumber]; + // signal generation + for (int bin = 0; bin < binNumber; bin++){ + signal[bin] = (1 - fractionCFD)*samplesCorr[bin]; // we fill it with a fraction of the original signal + if (bin < binNumber - binDelayCFD) + signal[bin] += -1*fractionCFD*samplesCorr[bin + binDelayCFD]; // we advance and invert a complementary fraction of the original signal and superimpose it to the previous signal + } + // determine the two humps + int binHumpSup = 0; + int binHumpInf = 0; + for (int bin = 0; bin < binNumber; bin++){ + if (signal[bin] > signal[binHumpSup]) + binHumpSup = bin; + } + for (int bin = 0; bin < binHumpSup; bin++){ // this loop has been added to be sure : binHumpInf < binHumpSup + if (signal[bin] < signal[binHumpInf]) + binHumpInf = bin; + } + // research for zero + int binZero = 0; + for (int bin = binHumpInf; bin <= binHumpSup; bin++){ + if (signal[bin] < 0) + binZero = bin; // last pass below zero + } // at this stage : binZero < timeCFD/samplingTime <= binZero + 1 // timeCFD is determined by assuming a linear fit between binZero and binZero + 1 + float slopeCFD = 0; + if (binZero + 1 <= binNumber) + slopeCFD = signal[binZero+1] - signal[binZero]; + float fittedBinZero = (slopeCFD == 0) ? binZero : binZero + (0 - signal[binZero])/slopeCFD; + timeCFD = (fittedBinZero + binOffset)*samplingTime; + + //} + + /// ///////////////////////// + // Begin fineTimeStampCorrection + /// //////////////////////// + //fineTimeStampCorrection(timeStamp,fineTimeStampResolution); + /** + * From MVTFitter.java + * Make fine timestamp correction (using dream (=electronic chip) clock) + * @param timeStamp timing informations (used to make fine corrections) + * @param fineTimeStampResolution precision of dream clock (usually 8) + */ + //private void fineTimeStampCorrection (long timeStamp, float fineTimeStampResolution) { + //this.timestamp = timeStamp; + String binaryTimeStamp = Long.toBinaryString(timeStamp); //get 64 bit timestamp in binary format + if (binaryTimeStamp.length()>=3){ + byte fineTimeStamp = Byte.parseByte(binaryTimeStamp.substring(binaryTimeStamp.length()-3,binaryTimeStamp.length()),2); //fineTimeStamp : keep and convert last 3 bits of binary timestamp + timeMax += (float) ((fineTimeStamp+0.5) * fineTimeStampResolution); //fineTimeStampCorrection + // Question : I wonder if I have to do the same thing of all time quantities that the extract() methods compute. + } + //} + // output + Pulse pulse = new Pulse(); + pulse.adcMax = adcMax; + pulse.time = timeMax; + pulse.timestamp = timestamp; + pulse.integral = integral; + pulse.timeRiseCFA = timeRiseCFA; + pulse.timeFallCFA = timeFallCFA; + pulse.timeOverThresholdCFA = timeOverThresholdCFA; + pulse.timeCFD = timeCFD; + //pulse.binMax = binMax; + //pulse.binOffset = binOffset; + pulse.pedestal = adcOffset; + List output = new ArrayList<>(); + output.add(pulse); + return output; + } + /** + * Fit the max of the pulse using parabolic fit, this method updates the timeMax and adcMax values + * @param samplingTime time between 2 ADC bins + */ + private void fitParabolic(float samplingTime) { + + } +} diff --git a/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Pulse.java b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Pulse.java new file mode 100644 index 000000000..9874b3ad7 --- /dev/null +++ b/common-tools/clas-detector/src/main/java/org/jlab/detector/pulse/Pulse.java @@ -0,0 +1,46 @@ +package org.jlab.detector.pulse; + +import org.jlab.detector.base.DetectorDescriptor; + +/** + * Just a dumb data container + */ +public class Pulse { + + public DetectorDescriptor descriptor; + public long timestamp; + public float integral; + public float time; + public float pedestal; + public long flags; + public int id; + + public float adcMax; + public float timeRiseCFA; + public float timeFallCFA; + public float timeOverThresholdCFA; + public float timeCFD; + + /** + * Units are the same as the raw units of the samples. + * @param integral pulse integral, pedestal-subtracted + * @param time pulse time + * @param flags user flags + * @param id link to row in source bank + */ + public Pulse(float integral, float time, long flags, int id) { + this.integral = integral; + this.time = time; + this.flags = flags; + this.id = id; + } + + public Pulse(){} + + @Override + public String toString() { + return String.format("pulse: integral=%f time=%f flags=%d id=%d", + integral, time, flags, id); + } + +} diff --git a/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/AHDC/AlertDCFactory.java b/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/AHDC/AlertDCFactory.java index 967914170..0a9bd1302 100644 --- a/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/AHDC/AlertDCFactory.java +++ b/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/AHDC/AlertDCFactory.java @@ -1,8 +1,3 @@ -/* - * To change this license header, choose License Headers in Project Properties. - * To change this template file, choose Tools | Templates - * and open the template in the editor. - */ package org.jlab.geom.detector.alert.AHDC; import org.jlab.geom.base.ConstantProvider; @@ -59,8 +54,7 @@ public AlertDCDetector createDetectorLocal(ConstantProvider cp) { @Override public AlertDCSector createSector(ConstantProvider cp, int sectorId) { - if (!(0 <= sectorId && sectorId < nsectors)) throw new IllegalArgumentException("Error: invalid sector=" + sectorId); - AlertDCSector sector = new AlertDCSector(sectorId); + AlertDCSector sector = new AlertDCSector(sectorId+1); for (int superlayerId = 0; superlayerId < nsuperl; superlayerId++) sector.addSuperlayer(createSuperlayer(cp, sectorId, superlayerId)); return sector; @@ -69,9 +63,7 @@ public AlertDCSector createSector(ConstantProvider cp, int sectorId) { @Override public AlertDCSuperlayer createSuperlayer(ConstantProvider cp, int sectorId, int superlayerId) { - if (!(0 <= sectorId && sectorId < nsectors)) throw new IllegalArgumentException("Error: invalid sector=" + sectorId); - if (!(0 <= superlayerId && superlayerId < nsuperl)) throw new IllegalArgumentException("Error: invalid superlayer=" + superlayerId); - AlertDCSuperlayer superlayer = new AlertDCSuperlayer(sectorId, superlayerId); + AlertDCSuperlayer superlayer = new AlertDCSuperlayer(sectorId+1, superlayerId+1); for (int layerId = 0; layerId < nlayers; layerId++) superlayer.addLayer(createLayer(cp, sectorId, superlayerId, layerId)); @@ -85,7 +77,7 @@ public AlertDCLayer createLayer(ConstantProvider cp, int sectorId, int superlaye if (!(0 <= superlayerId && superlayerId < nsuperl)) throw new IllegalArgumentException("Error: invalid superlayer=" + superlayerId); if (!(0 <= layerId && layerId < nlayers)) throw new IllegalArgumentException("Error: invalid layer=" + layerId); - AlertDCLayer layer = new AlertDCLayer(sectorId, superlayerId, layerId); + AlertDCLayer layer = new AlertDCLayer(sectorId+1, superlayerId+1, layerId+1); // Load constants AHDC // Length in Z mm! @@ -221,7 +213,7 @@ public AlertDCLayer createLayer(ConstantProvider cp, int sectorId, int superlaye // not possible to add directly PrismaticComponent class because it is an ABSTRACT // a new class should be created: public class NewClassWire extends PrismaticComponent {...} // 5 top points & 5 bottom points with convexe shape. Concave shape is not supported. - AlertDCWire wire = new AlertDCWire(wireId, wireLine, firstF, secondF); + AlertDCWire wire = new AlertDCWire(wireId+1, wireLine, firstF, secondF); // Add wire object to the list layer.addComponent(wire); } diff --git a/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/ATOF/AlertTOFFactory.java b/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/ATOF/AlertTOFFactory.java index 4057a9b67..ba4cd181b 100644 --- a/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/ATOF/AlertTOFFactory.java +++ b/common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/ATOF/AlertTOFFactory.java @@ -1,11 +1,3 @@ -/* - * To change this license header, choose License Headers in Project Properties. - * To change this template file, choose Tools | Templates - * and open the template in the editor. - */ - -//package clas12vis; - package org.jlab.geom.detector.alert.ATOF; import org.jlab.geom.base.ConstantProvider; @@ -16,9 +8,6 @@ import org.jlab.geom.prim.Point3D; import org.jlab.geom.prim.Transformation3D; -import java.util.ArrayList; -import java.util.List; - /** * @author viktoriya * this is the latest ATOF geometry class to be used in reco. and in GEMC simulations! @@ -60,7 +49,6 @@ public AlertTOFDetector createDetectorLocal(ConstantProvider cp) { @Override public AlertTOFSector createSector(ConstantProvider cp, int sectorId) { - if (!(0 <= sectorId && sectorId < nsectors)) throw new IllegalArgumentException("Error: invalid sector=" + sectorId); AlertTOFSector sector = new AlertTOFSector(sectorId); for (int superlayerId = 0; superlayerId < nsuperl; superlayerId++) sector.addSuperlayer(createSuperlayer(cp, sectorId, superlayerId)); @@ -69,8 +57,6 @@ public AlertTOFSector createSector(ConstantProvider cp, int sectorId) { @Override public AlertTOFSuperlayer createSuperlayer(ConstantProvider cp, int sectorId, int superlayerId) { - if (!(0 <= sectorId && sectorId < nsectors)) throw new IllegalArgumentException("Error: invalid sector=" + sectorId); - if (!(0 <= superlayerId && superlayerId < nsuperl)) throw new IllegalArgumentException("Error: invalid superlayer=" + superlayerId); AlertTOFSuperlayer superlayer = new AlertTOFSuperlayer(sectorId, superlayerId); if (superlayerId == 0) { @@ -110,8 +96,6 @@ public AlertTOFLayer createLayer(ConstantProvider cp, int sectorId, int superlay AlertTOFLayer layer = new AlertTOFLayer(sectorId, superlayerId, layerId); - List planes = new ArrayList<>(); - double len_b = layerId * pad_z + layerId * gap_pad_z; // back paddle plan double len_f = len_b + pad_z; // front paddle plan double Rl = R0; @@ -136,7 +120,7 @@ public AlertTOFLayer createLayer(ConstantProvider cp, int sectorId, int superlay Point3D p5 = new Point3D(dR / 2, -widthTl / 2, len_b); Point3D p6 = new Point3D(dR / 2, widthTl / 2, len_b); Point3D p7 = new Point3D(-dR / 2, widthBl / 2, len_b); - ScintillatorPaddle Paddle = new ScintillatorPaddle(sectorId * 4 + padId, p0, p1, p2, p3, p4, p5, p6, p7); + ScintillatorPaddle Paddle = new ScintillatorPaddle(sectorId * 4 + padId+1, p0, p1, p2, p3, p4, p5, p6, p7); double openAng_sector_deg = npaddles * openAng_pad_deg; Paddle.rotateZ(Math.toRadians(padId * openAng_pad_deg + sectorId * openAng_sector_deg)); @@ -155,7 +139,6 @@ public AlertTOFLayer createLayer(ConstantProvider cp, int sectorId, int superlay Plane3D plane = new Plane3D(0, Rl, 0, 0, 1, 0); plane.rotateZ(sectorId * openAng_sector_rad - Math.toRadians(90)); - planes.add(plane); return layer; } diff --git a/common-tools/clas-reco/src/main/java/org/jlab/clas/reco/ReconstructionEngine.java b/common-tools/clas-reco/src/main/java/org/jlab/clas/reco/ReconstructionEngine.java index 79b4fc6f1..cce8c4301 100644 --- a/common-tools/clas-reco/src/main/java/org/jlab/clas/reco/ReconstructionEngine.java +++ b/common-tools/clas-reco/src/main/java/org/jlab/clas/reco/ReconstructionEngine.java @@ -96,6 +96,10 @@ public void registerOutputBank(String... bankName) { } } + protected SchemaFactory getSchemaFactory() { + return this.engineDictionary; + } + protected RawBank getRawBankReader(String bankName) { return new RawDataBank(bankName, this.rawBankOrders); } diff --git a/common-tools/clas-reco/src/main/java/org/jlab/clas/service/PulseExtractorEngine.java b/common-tools/clas-reco/src/main/java/org/jlab/clas/service/PulseExtractorEngine.java new file mode 100644 index 000000000..caf6702f2 --- /dev/null +++ b/common-tools/clas-reco/src/main/java/org/jlab/clas/service/PulseExtractorEngine.java @@ -0,0 +1,50 @@ +package org.jlab.clas.service; + +import org.jlab.clas.reco.ReconstructionEngine; +import org.jlab.detector.pulse.Mode3; +import org.jlab.detector.pulse.Mode7; +import org.jlab.io.base.DataEvent; + +/** + * An example of using a {@link org.jlab.detector.pulse.HipoExtractor} from a + * {@link org.jlab.clas.reco.ReconstructionEngine}. + * + * @author baltzell + */ +public class PulseExtractorEngine extends ReconstructionEngine { + + Mode3 mode3 = new Mode3(); + Mode3 mode7 = new Mode7(); + + public PulseExtractorEngine() { + super("PULSE", "baltzell", "0.0"); + } + + @Override + public boolean init() { + // If using a CCDB table, must register it here: + //requireConstants("/daq/config/ahdc"); + return true; + } + + @Override + public boolean processDataEvent(DataEvent event) { + + // No CCDB table, hardcoded parameters in the extractor: + mode3.update(6, null, event, "BMT::wf", "BMT::adc"); + //mode7.update(80, null, event, "AHDC::wf", "AHDC::adc"); + + /* + // Requiring a CCDB table: + DataBank runConfig = event.getBank("RUN::config"); + if (runConfig.rows()>0) { + IndexedTable it = getConstantsManager().getConstants( + runConfig.getInt("run", 0), "/daq/config/ahdc"); + basic.update(136, it, event, "AHDC::wf", "AHDC::adc"); + } + */ + + return true; + } + +} diff --git a/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/IndexedTable.java b/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/IndexedTable.java index 2c4bb9a5d..a0865613f 100644 --- a/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/IndexedTable.java +++ b/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/IndexedTable.java @@ -65,7 +65,7 @@ public void setPrecision(Integer precision){ str.append("f"); this.precisionFormat = str.toString(); } - + public boolean hasEntry(int... index){ return this.entries.hasItem(index); } @@ -145,7 +145,11 @@ public double getDoubleValue(String item, int... index){ } return 0; } - + + public NamedEntry getNamedEntry(int... index) { + return NamedEntry.create(entries.getItem(index), entryNames, index); + } + public IndexedList getList(){ return this.entries; } @@ -295,7 +299,7 @@ public Object getValueAt(int row, int column) { } return trow.getValue(column-ic).toString(); } - + public class RowConstraint { public int COLUMN = 0; @@ -377,4 +381,5 @@ public void setSize(int size){ } } } + } diff --git a/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/NamedEntry.java b/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/NamedEntry.java new file mode 100644 index 000000000..b221ad99e --- /dev/null +++ b/common-tools/clas-utils/src/main/java/org/jlab/utils/groups/NamedEntry.java @@ -0,0 +1,34 @@ +package org.jlab.utils.groups; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.jlab.utils.groups.IndexedTable.IndexedEntry; + +/** + * IndexedEntry wrapper for names and indices. + */ +public class NamedEntry { + + IndexedEntry entry; + Map names = new HashMap<>(); + int[] index; + + public static NamedEntry create(IndexedEntry entry, List names, int... index) { + NamedEntry e = new NamedEntry(); + for (int i=0; i0.8.12 + + ai.djl + model-zoo + 0.30.0 + compile + + + ai.djl.pytorch + pytorch-model-zoo + 0.30.0 + + diff --git a/reconstruction/alert/pom.xml b/reconstruction/alert/pom.xml index 4d7f9502a..0029fb4bf 100644 --- a/reconstruction/alert/pom.xml +++ b/reconstruction/alert/pom.xml @@ -40,6 +40,17 @@ 11.0.5-SNAPSHOT compile + + ai.djl + model-zoo + 0.30.0 + compile + + + ai.djl.pytorch + pytorch-model-zoo + 0.30.0 + diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/AIPrediction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/AIPrediction.java new file mode 100644 index 000000000..1e800c6cf --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/AIPrediction.java @@ -0,0 +1,37 @@ +package org.jlab.rec.ahdc.AI; + +import java.util.ArrayList; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; + +import java.io.IOException; + +public class AIPrediction { + + + public AIPrediction() throws ModelNotFoundException, MalformedModelException, IOException { + } + + public ArrayList prediction(ArrayList> tracks, ZooModel model) throws TranslateException { + ArrayList result = new ArrayList<>(); + for (ArrayList track : tracks) { + float[] a = new float[]{(float) track.get(0).getX(), (float) track.get(0).getY(), + (float) track.get(1).getX(), (float) track.get(1).getY(), + (float) track.get(2).getX(), (float) track.get(2).getY(), + (float) track.get(3).getX(), (float) track.get(3).getY(), + (float) track.get(4).getX(), (float) track.get(4).getY(), + }; + + Predictor my_predictor = model.newPredictor(); + result.add(new TrackPrediction(my_predictor.predict(a), track)); + } + + return result; + } + + +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java new file mode 100644 index 000000000..17cb3f365 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreClustering.java @@ -0,0 +1,110 @@ +package org.jlab.rec.ahdc.AI; + +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +public class PreClustering { + + private ArrayList fill(List hits, int super_layer, int layer) { + + ArrayList result = new ArrayList<>(); + for (Hit hit : hits) { + if (hit.getSuperLayerId() == super_layer && hit.getLayerId() == layer) result.add(hit); + } + return result; + } + + public ArrayList find_preclusters_for_AI(List AHDC_hits) { + ArrayList preclusters = new ArrayList<>(); + + ArrayList s1l1 = fill(AHDC_hits, 1, 1); + ArrayList s2l1 = fill(AHDC_hits, 2, 1); + ArrayList s2l2 = fill(AHDC_hits, 2, 2); + ArrayList s3l1 = fill(AHDC_hits, 3, 1); + ArrayList s3l2 = fill(AHDC_hits, 3, 2); + ArrayList s4l1 = fill(AHDC_hits, 4, 1); + ArrayList s4l2 = fill(AHDC_hits, 4, 2); + ArrayList s5l1 = fill(AHDC_hits, 5, 1); + + // Sort hits of each layers by phi: + s1l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s2l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s2l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s3l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s3l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s4l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s4l2.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + s5l1.sort(new Comparator() {@Override public int compare(Hit a1, Hit a2) {return Double.compare(a1.getPhi(), a2.getPhi());}}); + + ArrayList> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1)); + + for (ArrayList p : all_super_layer) { + for (Hit hit : p) { + hit.setUse(false); + } + } + + for (ArrayList p : all_super_layer) { + for (Hit hit : p) { + if (hit.is_NoUsed()) { + ArrayList temp = new ArrayList<>(); + temp.add(hit); + hit.setUse(true); + + boolean has_next = true; + while (has_next) { + has_next = false; + for (Hit hit1 : p) { + if (hit1.is_NoUsed() && (hit1.getWireId() == temp.get(temp.size() - 1).getWireId() + 1 || hit1.getWireId() == temp.get(temp.size() - 1).getWireId() - 1)) { + temp.add(hit1); + hit1.setUse(true); + has_next = true; + break; + } + } + } + if (!temp.isEmpty()) preclusters.add(new PreCluster(temp)); + } + } + } + return preclusters; + } + + public ArrayList merge_preclusters(ArrayList preclusters) { + double distance_max = 8.0; + + ArrayList superpreclusters = new ArrayList<>(); + for (PreCluster precluster : preclusters) { + if (!precluster.is_Used()) { + ArrayList tmp = new ArrayList<>(); + tmp.add(precluster); + precluster.set_Used(true); + for (PreCluster other : preclusters) { + if (precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getSuperLayerId() == other.get_hits_list().get(other.get_hits_list().size() - 1).getSuperLayerId() && precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getLayerId() != other.get_hits_list().get(other.get_hits_list().size() - 1).getLayerId() && !other.is_Used()) { + double dx = precluster.get_X() - other.get_X(); + double dy = precluster.get_Y() - other.get_Y(); + double distance = Math.sqrt(dx * dx + dy * dy); + + if (distance < distance_max) { + other.set_Used(true); + tmp.add(other); + } + } + } + + if (!tmp.isEmpty()) superpreclusters.add(new PreclusterSuperlayer(tmp)); + } + } + + return superpreclusters; + } + + + + +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java new file mode 100644 index 000000000..ecab32728 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/PreclusterSuperlayer.java @@ -0,0 +1,46 @@ +package org.jlab.rec.ahdc.AI; + +import org.jlab.rec.ahdc.Hit.Hit; +import org.jlab.rec.ahdc.PreCluster.PreCluster; + +import java.util.ArrayList; + +public class PreclusterSuperlayer { + private final double x; + private final double y; + private ArrayList preclusters = new ArrayList<>(); + + + ; public PreclusterSuperlayer(ArrayList preclusters_) { + this.preclusters = preclusters_; + double x_ = 0; + double y_ = 0; + + for (PreCluster p : this.preclusters) { + x_ += p.get_X(); + y_ += p.get_Y(); + } + this.x = x_ / this.preclusters.size(); + this.y = y_ / this.preclusters.size(); + + + + } + + public ArrayList getPreclusters() { + return preclusters; + } + + public double getX() { + return x; + } + + public double getY() { + return y; + } + + + public String toString() { + return "PreCluster{" + "X: " + this.x + " Y: " + this.y + " phi: " + Math.atan2(this.y, this.x) + "}\n"; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java new file mode 100644 index 000000000..dce541400 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackConstruction.java @@ -0,0 +1,106 @@ +package org.jlab.rec.ahdc.AI; + +import org.jlab.rec.ahdc.Hit.Hit; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.*; + +public class TrackConstruction { + public TrackConstruction() {} + + private double mod(double x, double y) { + + if (0. == y) return x; + + double m = x - y * Math.floor(x / y); + // handle boundary cases resulted from floating-point cut off: + if (y > 0) { // modulo range: [0..y) + if (m >= y) return 0; // Mod(-1e-16 , 360. ): m= 360. + if (m < 0) { + if (y + m == y) return 0; // just in case... + else return y + m; // Mod(106.81415022205296 , _TWO_PI ): m= -1.421e-14 + } + } else { // modulo range: (y..0] + if (m <= y) return 0; // Mod(1e-16 , -360. ): m= -360. + if (m > 0) { + if (y + m == y) return 0; // just in case... + else return y + m; // Mod(-106.81415022205296, -_TWO_PI): m= 1.421e-14 + } + } + + return m; + } + + private double warp_zero_two_pi(double angle) { return mod(angle, 2. * Math.PI); } + + private boolean angle_in_range(double angle, double lower, double upper) { return warp_zero_two_pi(angle - lower) <= warp_zero_two_pi(upper - lower); } + + + public ArrayList> get_all_possible_track(ArrayList preclusterSuperlayers) { + + // Get seeds to start the track finding algorithm + ArrayList seeds = new ArrayList<>(); + for (PreclusterSuperlayer precluster : preclusterSuperlayers) { + if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == 1) seeds.add(precluster); + } + seeds.sort(new Comparator() { + @Override + public int compare(PreclusterSuperlayer a1, PreclusterSuperlayer a2) { + return Double.compare(Math.atan2(a1.getY(), a1.getX()), Math.atan2(a2.getY(), a2.getX())); + } + }); + // System.out.println("seeds: " + seeds); + + // Get all possible tracks ---------------------------------------------------------------- + double max_angle = Math.toRadians(60); + + ArrayList> all_combinations = new ArrayList<>(); + for (PreclusterSuperlayer seed : seeds) { + double phi_seed = warp_zero_two_pi(Math.atan2(seed.getY(), seed.getX())); + + ArrayList track = new ArrayList<>(); + for (PreclusterSuperlayer p : preclusterSuperlayers) { + double phi_p = warp_zero_two_pi(Math.atan2(p.getY(), p.getX())); + if (angle_in_range(phi_p, phi_seed - max_angle, phi_seed + max_angle)) track.add(p); + } + // System.out.println("track: " + track.size()); + + ArrayList> combinations = new ArrayList<>(List.of(new ArrayList<>(List.of(seed)))); + // System.out.println("combinations: " + combinations); + + for (int i = 1; i < 5; ++i) { + ArrayList> new_combinations = new ArrayList<>(); + for (ArrayList combination : combinations) { + + for (PreclusterSuperlayer precluster : track) { + if (precluster.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() == seed.getPreclusters().get(0).get_hits_list().get(0).getSuperLayerId() + i) { + // System.out.printf("Good Precluster x: %.2f, y: %.2f, r: %.2f%n", precluster.getX(), precluster.getY(), Math.hypot(precluster.getX(), precluster.getY())); + // System.out.println("combination: " + combination); + + ArrayList new_combination = new ArrayList<>(combination); + new_combination.add(precluster); + // System.out.println("new_combination: " + new_combination); + new_combinations.add(new_combination); + } + } + for (ArrayList c : new_combinations) { + // System.out.println("c.size: " + c.size() + ", c: " + c); + } + + } + combinations = new_combinations; + if (combinations.size() > 10000) break; + } + for (ArrayList combination : combinations) { + if (combination.size() == 5) { + all_combinations.add(combination); + } + } + } + + return all_combinations; + } + +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java new file mode 100644 index 000000000..7ea5bd771 --- /dev/null +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/TrackPrediction.java @@ -0,0 +1,74 @@ +package org.jlab.rec.ahdc.AI; + +import org.jlab.rec.ahdc.Cluster.Cluster; +import org.jlab.rec.ahdc.PreCluster.PreCluster; + +import java.util.ArrayList; + +public class TrackPrediction { + + private float prediction; + private final ArrayList superpreclusters; + private final ArrayList preclusters = new ArrayList<>(); + private ArrayList clusters = new ArrayList<>(); + + public TrackPrediction(float prediction, ArrayList superpreclusters_) { + this.prediction = prediction; + this.superpreclusters = superpreclusters_; + + for (PreclusterSuperlayer p : this.superpreclusters) { + if (p.getPreclusters() != null) + this.preclusters.addAll(p.getPreclusters()); + } + + // Generate the clusters + for (PreCluster p : this.preclusters) { + if (p.get_Super_layer() == 1) { + for (PreCluster other : this.preclusters) { + if (other.get_Super_layer() == 2 && other.get_Layer() == 1) + clusters.add(new Cluster(p, other)); + } + } + + if (p.get_Super_layer() == 2 && p.get_Layer() == 2) { + for (PreCluster other : this.preclusters) { + if (other.get_Super_layer() == 3 && other.get_Layer() == 1) + clusters.add(new Cluster(p, other)); + } + } + + if (p.get_Super_layer() == 3 && p.get_Layer() == 2) { + for (PreCluster other : this.preclusters) { + if (other.get_Super_layer() == 4 && other.get_Layer() == 1) + clusters.add(new Cluster(p, other)); + } + } + + if (p.get_Super_layer() == 4 && p.get_Layer() == 2) { + for (PreCluster other : this.preclusters) { + if (other.get_Super_layer() == 5 && other.get_Layer() == 1) + clusters.add(new Cluster(p, other)); + } + } + + + } + + } + + public float getPrediction() { + return prediction; + } + + public ArrayList getSuperpreclusters() { + return superpreclusters; + } + + public ArrayList getPreclusters() { + return preclusters; + } + + public ArrayList getClusters() { + return clusters; + } +} diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt new file mode 100644 index 000000000..73e665c23 Binary files /dev/null and b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/model.pt differ diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java index 30c636dc6..72b9b5f2f 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Banks/RecoBankWriter.java @@ -2,6 +2,7 @@ import org.jlab.io.base.DataBank; import org.jlab.io.base.DataEvent; +import org.jlab.rec.ahdc.AI.TrackPrediction; import org.jlab.rec.ahdc.Cluster.Cluster; import org.jlab.rec.ahdc.Hit.Hit; import org.jlab.rec.ahdc.PreCluster.PreCluster; @@ -133,4 +134,29 @@ public DataBank fillAHDCKFTrackBank(DataEvent event, ArrayList tracks) { return bank; } + + public DataBank fillAIPrediction(DataEvent event, ArrayList predictions) { + + DataBank bank = event.createBank("AHDC_AI::Prediction", predictions.size()); + + int row = 0; + + for (TrackPrediction track : predictions) { + bank.setFloat("X1", row, (float) track.getSuperpreclusters().get(0).getX()); + bank.setFloat("Y1", row, (float) track.getSuperpreclusters().get(0).getY()); + bank.setFloat("X2", row, (float) track.getSuperpreclusters().get(1).getX()); + bank.setFloat("Y2", row, (float) track.getSuperpreclusters().get(1).getY()); + bank.setFloat("X3", row, (float) track.getSuperpreclusters().get(2).getX()); + bank.setFloat("Y3", row, (float) track.getSuperpreclusters().get(2).getY()); + bank.setFloat("X4", row, (float) track.getSuperpreclusters().get(3).getX()); + bank.setFloat("Y4", row, (float) track.getSuperpreclusters().get(3).getY()); + bank.setFloat("X5", row, (float) track.getSuperpreclusters().get(4).getX()); + bank.setFloat("Y5", row, (float) track.getSuperpreclusters().get(4).getY()); + + bank.setFloat("Pred", row, track.getPrediction()); + row++; + } + + return bank; + } } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java index 6d90337e0..249716499 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/Cluster.java @@ -10,6 +10,10 @@ */ public class Cluster { + private double _StereoAngle = 20.0; + private double _DeltaZ = 300.0; + private double _Zoffset = 150.0; + private double _Radius; private double _Phi; private double _Z; @@ -27,7 +31,9 @@ public Cluster(PreCluster precluster, PreCluster other_precluster) { _PreClusters_list.add(precluster); _PreClusters_list.add(other_precluster); this._Radius = (precluster.get_Radius() + other_precluster.get_Radius()) / 2; - this._Z = ((other_precluster.get_Phi() - precluster.get_Phi()) / (Math.toRadians(20) * Math.pow(-1, precluster.get_Super_layer()) - Math.toRadians(20) * Math.pow(-1, other_precluster.get_Super_layer()))) * 300 - 150; + + this._Z = ((other_precluster.get_Phi() - precluster.get_Phi()) / (Math.toRadians(_StereoAngle) * Math.pow(-1, precluster.get_Super_layer()-1) - Math.toRadians(_StereoAngle) * Math.pow(-1, other_precluster.get_Super_layer()-1))) * _DeltaZ - _Zoffset; + double x1 = -precluster.get_Radius() * Math.sin(precluster.get_Phi()); double y1 = -precluster.get_Radius() * Math.cos(precluster.get_Phi()); double x2 = -other_precluster.get_Radius() * Math.sin(other_precluster.get_Phi()); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java index a8db6bd9f..7a1421815 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Cluster/ClusterFinder.java @@ -13,14 +13,16 @@ public class ClusterFinder { public ClusterFinder() {} private void find_associate_cluster(PreCluster precluster, List AHDC_precluster_list, int window, int minimal_distance, int super_layer, int layer, int associate_super_layer) { + //System.out.println(" precluster superlayer " + precluster.get_Super_layer() + " ref superlayer " + super_layer + " layer " + precluster.get_Layer() + " ref " + layer); if (precluster.get_Super_layer() == super_layer && precluster.get_Layer() == layer && !precluster.is_Used()) { ArrayList possible_precluster_list = new ArrayList<>(); - double phi_mean = precluster.get_Phi() + 0.1 * Math.pow(-1, precluster.get_Super_layer()); + double phi_mean = precluster.get_Phi() + 0.1 * Math.pow(-1, precluster.get_Super_layer()-1); double x = -precluster.get_Radius() * Math.sin(phi_mean); double y = -precluster.get_Radius() * Math.cos(phi_mean); for (PreCluster other_precluster : AHDC_precluster_list) { - if (other_precluster.get_Super_layer() == associate_super_layer && other_precluster.get_Layer() == 0 && !other_precluster.is_Used()) { + //System.out.println(" othercluster superlayer " + other_precluster.get_Super_layer() + " ref " + associate_super_layer + " layer " + other_precluster.get_Layer() + " ref 1 now"); + if (other_precluster.get_Super_layer() == associate_super_layer && other_precluster.get_Layer() == 1 && !other_precluster.is_Used()) { double x_start = x - window; double x_end = x + window; double y_start = y - window; @@ -67,10 +69,10 @@ public void findCluster(List AHDC_precluster_list) { // Collections.sort(AHDC_precluster_list); for (PreCluster precluster : AHDC_precluster_list) { - find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 0, 0, 1); find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 1, 1, 2); - find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 2, 1, 3); - find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 3, 1, 4); + find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 2, 2, 3); + find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 3, 2, 4); + find_associate_cluster(precluster, AHDC_precluster_list, window, minimal_distance, 4, 2, 5); } for (Cluster cluster : _list_with_maybe_same_cluster) { diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/Hit.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/Hit.java index aeb429bbe..6ee4bab6a 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/Hit.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/Hit.java @@ -3,6 +3,7 @@ public class Hit implements Comparable { + private final double thster = Math.toRadians(20.0); private final int id; private final int superLayerId; private final int layerId; @@ -31,35 +32,38 @@ private void wirePosition() { double numWires = 32.0; double R_layer = 47.0; - + switch (this.superLayerId) { - case 0: + case 1: numWires = 47.0; R_layer = 32.0; break; - case 1: + case 2: numWires = 56.0; R_layer = 38.0; break; - case 2: + case 3: numWires = 72.0; R_layer = 48.0; break; - case 3: + case 4: numWires = 87.0; R_layer = 58.0; break; - case 4: + case 5: numWires = 99.0; R_layer = 68.0; break; } - R_layer = R_layer + DR_layer * this.layerId; + R_layer = R_layer + DR_layer * (this.layerId-1); double alphaW_layer = Math.toRadians(round / (numWires)); - double wx = -R_layer * Math.sin(alphaW_layer * this.wireId); - double wy = -R_layer * Math.cos(alphaW_layer * this.wireId); - + //should it be at z = 0? in which case, we need to account for the positive or negative stereo angle... + double wx = -R_layer * Math.sin(alphaW_layer * (this.wireId-1) + 0.5*thster * (Math.pow(-1, this.superLayerId-1))); + double wy = -R_layer * Math.cos(alphaW_layer * (this.wireId-1) + 0.5*thster * (Math.pow(-1, this.superLayerId-1))); + + //System.out.println(" superlayer " + this.superLayerId + " layer " + this.layerId + " wire " + this.wireId + " R_layer " + R_layer + " wx " + wx + " wy " + wy); + this.nbOfWires = (int) numWires; this.phi = Math.atan2(wy, wx); this.radius = R_layer; @@ -124,4 +128,6 @@ public double getX() { public double getY() { return y; } + + public double getPhi() {return phi;} } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/HitReader.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/HitReader.java index 85f0ebb08..cef052afe 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/HitReader.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Hit/HitReader.java @@ -36,22 +36,6 @@ public void fetch_AHDCHits(DataEvent event) { hits.add(new Hit(id, superlayer, layer, wire, doca)); } - }else if(event.hasBank("AHDC::tdc")) { - RawDataBank bankDGTZ = new RawDataBank("AHDC::tdc"); - bankDGTZ.read(event); - //DataBank bankDGTZ = event.getBank("ALRTDC::adc"); - - - for (int i = 0; i < bankDGTZ.rows(); i++) { - int id = bankDGTZ.trueIndex(i) + 1; - int number = bankDGTZ.getByte("layer", i); - int layer = number % 10; - int superlayer = (int) (number % 100) / 10; - int wire = bankDGTZ.getShort("component", i); - double time = bankDGTZ.getInt("TDC", i)*1.0; - - hits.add(new Hit(id, superlayer, layer, wire, time)); - } } this.set_AHDCHits(hits); } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Hit.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Hit.java index a56f7b466..fa4d1dc47 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Hit.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Hit.java @@ -12,104 +12,159 @@ public class Hit implements Comparable { private final double thster = Math.toRadians(20.0); + private final double zl = 300.0;//OK private final int superLayer; private final int layer; private final int wire; private final double r; + private final double phi; private final double doca; + private final double adc; private final double numWires; private final Line3D line3D; + // Comparison with: common-tools/clas-geometry/src/main/java/org/jlab/geom/detector/alert/AHDC/AlertDCFactory.java + // here, SuperLayer, Layer, Wire, start from 1 + // in AlertDCFactory, same variables start from 1 public Hit(int superLayer, int layer, int wire, int numWire, double r, double doca) { this.superLayer = superLayer; this.layer = layer; - this.wire = wire - 1; + this.wire = wire; this.r = r; this.doca = doca; this.numWires = numWire; - - final double DR_layer = 4.0; - final double round = 360.0; - final double thster = Math.toRadians(20.0); - final double zl = 300.0; + this.adc = 0;//placeholder + + final double DR_layer = 4.0;//OK + final double round = 360.0;//OK + final double thster = Math.toRadians(20.0);//OK double numWires = 32.0; double R_layer = 47.0; - double zoff1 = 0.0d; - double zoff2 = 300.0d; + double zoff1 = -zl/2;//OK + double zoff2 = +zl/2;//OK Point3D p1 = new Point3D(R_layer, 0, zoff1); Vector3D n1 = new Vector3D(0, 0, 1); //n1.rotateY(-thopen); //n1.rotateZ(thtilt); - Plane3D lPlane = new Plane3D(p1, n1); + Plane3D lPlane = new Plane3D(p1, n1);//OK Point3D p2 = new Point3D(R_layer, 0, zoff2); Vector3D n2 = new Vector3D(0, 0, 1); //n2.rotateY(thopen); //n2.rotateZ(thtilt); - Plane3D rPlane = new Plane3D(p2, n2); + Plane3D rPlane = new Plane3D(p2, n2);//OK - switch (this.superLayer) { - case 0: + switch (this.superLayer) {//OK + case 1: numWires = 47.0; R_layer = 32.0; break; - case 1: + case 2: numWires = 56.0; R_layer = 38.0; break; - case 2: + case 3: numWires = 72.0; R_layer = 48.0; break; - case 3: + case 4: numWires = 87.0; R_layer = 58.0; break; - case 4: + case 5: numWires = 99.0; R_layer = 68.0; break; } - R_layer = R_layer + DR_layer * this.layer; - double alphaW_layer = Math.toRadians(round / (numWires)); - double wx = -R_layer * Math.sin(alphaW_layer * this.wire); - double wy = -R_layer * Math.cos(alphaW_layer * this.wire); - - double wx_end = -R_layer * Math.sin(alphaW_layer * this.wire + thster * (Math.pow(-1, this.superLayer))); - double wy_end = -R_layer * Math.cos(alphaW_layer * this.wire + thster * (Math.pow(-1, this.superLayer))); + + R_layer = R_layer + DR_layer * (this.layer-1);//OK + double alphaW_layer = Math.toRadians(round / (numWires));//OK + double wx = -R_layer * Math.sin(alphaW_layer * (this.wire-1));//OK + double wy = -R_layer * Math.cos(alphaW_layer * (this.wire-1));//OK - Line3D line = new Line3D(wx, wy, -150, wx_end, wy_end, zl/2); + double wx_end = -R_layer * Math.sin(alphaW_layer * (this.wire-1) + thster * (Math.pow(-1, this.superLayer-1)));//OK + double wy_end = -R_layer * Math.cos(alphaW_layer * (this.wire-1) + thster * (Math.pow(-1, this.superLayer-1)));//OK + this.phi = Math.atan2( (wy+wy_end)*0.5, (wx+wx_end)*0.5 ); + //System.out.println(" superlayer " + this.superLayer + " layer " + this.layer + " wire " + this.wire + " wx " + wx + " wy " + wy + " wx_end " + wx_end + " wy_end " + wy_end + " phi " + this.phi); + + Line3D line = new Line3D(wx, wy, -zl/2, wx_end, wy_end, zl/2); Point3D lPoint = new Point3D(); Point3D rPoint = new Point3D(); lPlane.intersection(line, lPoint); rPlane.intersection(line, rPoint); + //lPoint.setZ(-zl/2); + //rPoint.setZ(zl/2); + //lPoint.show(); + //rPoint.show(); // All wire go from left to right Line3D wireLine = new Line3D(lPoint, rPoint); - + //wireLine.show(); this.line3D = wireLine; } + + //hit measurement vector in cylindrical coordinates: r, phi, z + public RealVector get_Vector() { + // final double costhster = Math.cos(thster); + // final double sinthster = Math.cos(thster); + RealVector wire_meas = new ArrayRealVector(new double[]{this.r(), this.phi(), 0}); + // Array2DRowRealMatrix stereo_rotation = new Array2DRowRealMatrix(new double[][]{{1, 0.0, 0.0}, {0, costhster, -sinthster}, {0, sinthster, costhster}});//rotation of wire: needed? + return wire_meas;//.multiply(stereo_rotation); + } - public RealVector get_Vector() { + //hit measurement vector in 1 dimension: minimize distance - doca + public RealVector get_Vector_simple() { return new ArrayRealVector(new double[]{this.doca}); } - public RealMatrix get_MeasurementNoise() { - return new Array2DRowRealMatrix(new double[][]{{10}}); + //hit measurement vector in 1 dimension: minimize distance - doca - adds hit "sign" + public RealVector get_Vector_sign(int sign) { + // Attempt: multiply doca by sign + return new ArrayRealVector(new double[]{sign*this.doca}); } + public RealMatrix get_MeasurementNoise() { + final double costhster = Math.cos(thster); + final double sinthster = Math.cos(thster); + //dR = 0.1m dphi = pi dz = L/2 + Array2DRowRealMatrix wire_noise = new Array2DRowRealMatrix(new double[][]{{0.1, 0.0, 0.0}, {0.0, Math.atan(0.1/this.r), 0.0}, {0.0, 0.0, 150.0/costhster}});//uncertainty matrix in wire coordinates + Array2DRowRealMatrix stereo_rotation = new Array2DRowRealMatrix(new double[][]{{1, 0.0, 0.0}, {0, costhster, -sinthster}, {0, sinthster, costhster}});//rotation of wire + wire_noise.multiply(stereo_rotation); + + return wire_noise.multiply(wire_noise); + // + } + + public RealMatrix get_MeasurementNoise_simple() { + return new Array2DRowRealMatrix(new double[][]{{0.01}}); + } + public double doca() { return doca; } public double r() {return r;} + public double phi() {return phi;}//at z = 0; + + public double phi(double z) { + // double x_0 = r*sin(phi); + // double y_0 = r*cos(phi); + double x_z = r*Math.sin( phi + thster * z/(zl*0.5) * (Math.pow(-1, this.superLayer-1)) ); + double y_z = r*Math.cos( phi + thster * z/(zl*0.5) * (Math.pow(-1, this.superLayer-1)) ); + return Math.atan2(x_z, y_z); + } + public Line3D line() {return line3D;} public double distance(Point3D point3D) { + //System.out.println("Calculating distance: "); + //this.line3D.show(); + //point3D.show(); + //System.out.println(" d = " + this.line3D.distance(point3D).length()); return this.line3D.distance(point3D).length(); } @@ -152,6 +207,10 @@ public double getDoca() { return doca; } + public double getADC() { + return adc; + } + public Line3D getLine3D() { return line3D; } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KFitter.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KFitter.java index b94d9478b..3feddf5b2 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KFitter.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KFitter.java @@ -14,6 +14,7 @@ public class KFitter { public final Stepper stepper; private final Propagator propagator; public double chi2 = 0; + // masses/energies in MeV private final double electron_mass_c2 = PhysicsConstants.massElectron() * 1000; private final double proton_mass_c2 = PhysicsConstants.massProton() * 1000; @@ -57,7 +58,7 @@ public void predict(Indicator indicator) throws Exception { double dE = Math.abs(stepper.dEdx); double K = 0.000307075; - double sigma2_dE = indicator.material.getDensity() * K * indicator.material.getZoverA() / beta2 * tmax * s / 10 * (1.0 - beta2 / 2) * 1000 * 1000; + double sigma2_dE = indicator.material.getDensity() * K * indicator.material.getZoverA() / beta2 * tmax * s / 10 * (1.0 - beta2 / 2) * 1000 * 1000;//in MeV^2 double dp_prim_ddE = (E + dE) / Math.sqrt((E + dE) * (E + dE) - mass * mass); double sigma2_px = Math.pow(px / p, 2) * Math.pow(dp_prim_ddE, 2) * sigma2_dE; double sigma2_py = Math.pow(py / p, 2) * Math.pow(dp_prim_ddE, 2) * sigma2_dE; @@ -72,7 +73,9 @@ public void predict(Indicator indicator) throws Exception { } public void correct(Indicator indicator) { - RealVector z; + //System.out.println(" state before: (" + stateEstimation.getEntry(0) + ", " + stateEstimation.getEntry(1) + ", " + stateEstimation.getEntry(2) + ", " + stateEstimation.getEntry(3) + ", " + stateEstimation.getEntry(4) + ", " + stateEstimation.getEntry(5) + ");" ); + //System.out.println(" state radius before: " + Math.sqrt( Math.pow(stateEstimation.getEntry(0), 2) + Math.pow(stateEstimation.getEntry(1), 2) ) ); + RealVector z, z_plus, z_minus; RealMatrix measurementNoise; RealMatrix measurementMatrix; RealVector h; @@ -80,21 +83,30 @@ public void correct(Indicator indicator) { measurementNoise = new Array2DRowRealMatrix( new double[][]{ - {9.00, 0.0000, 0.0000}, - {0.00, 1e10, 0.0000}, - {0.00, 0.0000, 1e10} - }); - measurementMatrix = H_beam(stateEstimation); - h = h_beam(stateEstimation); - z = indicator.hit.get_Vector_beam(); + // {9.00, 0.0000, 0.0000}, + // {0.00, 1e10, 0.0000}, + // {0.00, 0.0000, 1e10} + {0.09, 0.0000, 0.0000}, + {0.00, 1.e10, 0.0000}, + {0.00, 0.0000, 1.e10} + });//3x3 + measurementMatrix = H_beam(stateEstimation);//6x3 + h = h_beam(stateEstimation);//3x1 + z = indicator.hit.get_Vector_beam();//0! } else { - measurementNoise = indicator.hit.get_MeasurementNoise(); - measurementMatrix = H(stateEstimation, indicator); - h = h(stateEstimation, indicator); - z = indicator.hit.get_Vector(); - + measurementNoise = indicator.hit.get_MeasurementNoise_simple();//1x1 + measurementMatrix = H_simple(stateEstimation, indicator);//6x1 + h = h_simple(stateEstimation, indicator);//.multiply(wire_sign_mat(indicator));//1x1 + z = indicator.hit.get_Vector_simple();//1x1 + + // measurementNoise = indicator.hit.get_MeasurementNoise();//3x3 + // measurementMatrix = H(stateEstimation, indicator);//6x3 + // h = h(stateEstimation, indicator);//3x1 + // z = indicator.hit.get_Vector();//3x1 + + //System.out.println(" h: r " + h.getEntry(0) + " phi " + h.getEntry(1) + " h z " + h.getEntry(2) + " z: r " + z.getEntry(0) + " phi " + z.getEntry(1) + " z " + z.getEntry(2) ); + } - RealMatrix measurementMatrixT = measurementMatrix.transpose(); // S = H * P(k) * H' + R @@ -115,11 +127,31 @@ public void correct(Indicator indicator) { // Numerically more stable !! RealMatrix tmpMatrix = identity.subtract(kalmanGain.multiply(measurementMatrix)); errorCovariance = tmpMatrix.multiply(errorCovariance.multiply(tmpMatrix.transpose())).add(kalmanGain.multiply(measurementNoise.multiply(kalmanGain.transpose()))); - + + //System.out.println(" state after: (" + stateEstimation.getEntry(0) + ", " + stateEstimation.getEntry(1) + ", " + stateEstimation.getEntry(2) + ", " + stateEstimation.getEntry(3) + ", " + stateEstimation.getEntry(4) + ", " + stateEstimation.getEntry(5) + ");" ); // Give back to the stepper the new stateEstimation stepper.y = stateEstimation.toArray(); } + public double residual(Indicator indicator) { + double d = indicator.hit.distance( new Point3D( stateEstimation.getEntry(0), stateEstimation.getEntry(1), stateEstimation.getEntry(2) ) ); + return indicator.hit.doca()-d; + } + + public double wire_sign(Indicator indicator) {//let's decide: positive when (phi state - phi wire) > 0 + double phi_state = Math.atan2(stateEstimation.getEntry(1), stateEstimation.getEntry(0)); + double phi_wire = indicator.hit.phi(stateEstimation.getEntry(2)); + //System.out.println(" phi state " + phi_state + " phi wire " + phi_wire);// + " phi state alt? " + Math.atan2(stateEstimation.getEntry(1), stateEstimation.getEntry(0))); + return (phi_state-phi_wire)/Math.abs(phi_state-phi_wire) ; + } + + // public RealMatrix wire_sign_mat(Indicator indicator) {//let's decide: positive when (phi state - phi wire) > 0 + // double phi_state = Math.atan2(stateEstimation.getEntry(1), stateEstimation.getEntry(0)); + // double phi_wire = indicator.hit.phi(stateEstimation.getEntry(2)); + // System.out.println(" phi state " + phi_state + " phi wire " + phi_wire);// + " phi state alt? " + Math.atan2(stateEstimation.getEntry(1), stateEstimation.getEntry(0))); + // return MatrixUtils.createRealMatrix(new double[][]{{(phi_state-phi_wire)/Math.abs(phi_state-phi_wire)}}); + // } + private RealMatrix F(Indicator indicator, Stepper stepper1) throws Exception { double[] dfdx = subfunctionF(indicator, stepper1, 0); @@ -157,14 +189,57 @@ private RealMatrix F(Indicator indicator, Stepper stepper1) throws Exception { return new double[]{dxdi, dydi, dzdi, dpxdi, dpydi, dpzdi}; } + //measurement matrix in cylindrical coordinates: r, phi, z private RealVector h(RealVector x, Indicator indicator) { + //As per my understanding: d -> r wire; phi -> phi wire, z unconstrained + double xx = x.getEntry(0); + double yy = x.getEntry(1); + return MatrixUtils.createRealVector(new double[]{Math.hypot(xx, yy), Math.atan2(yy, xx), x.getEntry(2)}); + } + //measurement matrix in 1 dimension: minimize distance - doca + private RealVector h_simple(RealVector x, Indicator indicator) { double d = indicator.hit.distance(new Point3D(x.getEntry(0), x.getEntry(1), x.getEntry(2))); + return MatrixUtils.createRealVector(new double[]{d});//would need to have this 3x3 + } + + //measurement noise matrix in cylindrical coordinates: r, phi, z + private RealMatrix H(RealVector x, Indicator indicator) { + // dphi/dx + double xx = x.getEntry(0); + double yy = x.getEntry(1); + + double drdx = (xx) / (Math.hypot(xx, yy)); + double drdy = (yy) / (Math.hypot(xx, yy)); + double drdz = 0.0; + double drdpx = 0.0; + double drdpy = 0.0; + double drdpz = 0.0; + + double dphidx = -(yy) / (xx * xx + yy * yy); + double dphidy = (xx) / (xx * xx + yy * yy); + double dphidz = 0.0; + double dphidpx = 0.0; + double dphidpy = 0.0; + double dphidpz = 0.0; + + double dzdx = 0.0; + double dzdy = 0.0; + double dzdz = 1.0; + double dzdpx = 0.0; + double dzdpy = 0.0; + double dzdpz = 0.0; - return MatrixUtils.createRealVector(new double[]{d}); + return MatrixUtils.createRealMatrix( + new double[][]{ + {drdx, drdy, drdz, drdpx, drdpy, drdpz}, + {dphidx, dphidy, dphidz, dphidpx, dphidpy, dphidpz}, + {dzdx, dzdy, dzdz, dzdpx, dzdpy, dzdpz} + }); } - private RealMatrix H(RealVector x, Indicator indicator) { + //measurement matrix in 1 dimension: minimize distance - doca + private RealMatrix H_simple(RealVector x, Indicator indicator) { double ddocadx = subfunctionH(x, indicator, 0); double ddocady = subfunctionH(x, indicator, 1); @@ -172,10 +247,10 @@ private RealMatrix H(RealVector x, Indicator indicator) { double ddocadpx = subfunctionH(x, indicator, 3); double ddocadpy = subfunctionH(x, indicator, 4); double ddocadpz = subfunctionH(x, indicator, 5); - - + + // As per my understanding: ddocadx,y,z -> = dr/dx,y,z, etc return MatrixUtils.createRealMatrix(new double[][]{ - {ddocadx, ddocady, ddocadz, ddocadpx, ddocadpy, ddocadpz}}); + {ddocadx, ddocady, ddocadz, ddocadpx, ddocadpy, ddocadpz}}); } double subfunctionH(RealVector x, Indicator indicator, int i) { @@ -186,8 +261,8 @@ private RealMatrix H(RealVector x, Indicator indicator) { x_plus.setEntry(i, x_plus.getEntry(i) + h); x_minus.setEntry(i, x_minus.getEntry(i) - h); - double doca_plus = h(x_plus, indicator).getEntry(0); - double doca_minus = h(x_minus, indicator).getEntry(0); + double doca_plus = h_simple(x_plus, indicator).getEntry(0); + double doca_minus = h_simple(x_minus, indicator).getEntry(0); return (doca_plus - doca_minus) / (2 * h); } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KalmanFilter.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KalmanFilter.java index e3de4248d..2cd0bc559 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KalmanFilter.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/KalmanFilter.java @@ -27,6 +27,7 @@ * - flag for target material * - error px0 use MC !! Bad !! FIX IT FAST */ +// masses/energies should be in MeV; distances should be in mm public class KalmanFilter { @@ -37,11 +38,15 @@ private void propagation(ArrayList tracks, DataEvent event) { try { //If simulation read MC::Particle Bank ------------------------------------------------ DataBank bankParticle = event.getBank("MC::Particle"); - double vzmc = bankParticle.getFloat("vz", 0); - double pxmc = bankParticle.getFloat("px", 0); - double pymc = bankParticle.getFloat("py", 0); - double pzmc = bankParticle.getFloat("pz", 0); - + double vxmc = bankParticle.getFloat("vx", 0)*10;//mm + double vymc = bankParticle.getFloat("vy", 0)*10;//mm + double vzmc = bankParticle.getFloat("vz", 0)*10;//mm + double pxmc = bankParticle.getFloat("px", 0)*1000;//MeV + double pymc = bankParticle.getFloat("py", 0)*1000;//MeV + double pzmc = bankParticle.getFloat("pz", 0)*1000;//MeV + double p_mc = java.lang.Math.sqrt(pxmc*pxmc+pymc*pymc+pzmc*pzmc); + //System.out.println("MC track: vz: " + vzmc*10 + " px: " + pxmc*1000 + " py: " + pymc*1000 + " pz: " + pzmc*1000 + "; p = " + p_mc*1000);//convert p to MeV, v to mm + ArrayList sim_hits = new ArrayList<>(); sim_hits.add(new Point3D(0, 0, vzmc)); @@ -56,6 +61,7 @@ private void propagation(ArrayList tracks, DataEvent event) { } } + /* Writer hitsWriter = new FileWriter("hits.dat"); for (Point3D p : sim_hits) { @@ -66,7 +72,7 @@ private void propagation(ArrayList tracks, DataEvent event) { // Initialization --------------------------------------------------------------------- - final double magfield = -50; + final double magfield = +50; final PDGParticle proton = PDGDatabase.getParticleById(2212); final int numberOfVariables = 6; final double tesla = 0.001; @@ -79,17 +85,24 @@ private void propagation(ArrayList tracks, DataEvent event) { final double x0 = 0.0; final double y0 = 0.0; final double z0 = tracks.get(0).get_Z0(); - final double px0 = tracks.get(0).get_px(); - final double py0 = tracks.get(0).get_py(); + //final + double px0 = tracks.get(0).get_px(); + //final + double py0 = tracks.get(0).get_py(); final double pz0 = tracks.get(0).get_pz(); + final double p_init = java.lang.Math.sqrt(px0*px0+py0*py0+pz0*pz0); double[] y = new double[]{x0, y0, z0, px0, py0, pz0}; - // System.out.println("y = " + Arrays.toString(y)); + //System.out.println("y = " + x0 + ", " + y0 + ", " + z0 + ", " + px0 + ", " + py0 + ", " + pz0 + "; p = " + p_init); + // EPAF: *the line below is for TEST ONLY!!!* + //double[] y = new double[]{vxmc, vymc, vzmc, pxmc, pymc, pzmc}; + //System.out.println("y = " + vxmc + ", " + vymc + ", " + vzmc + ", " + pxmc + ", " + pymc + ", " + pzmc + "; p = " + java.lang.Math.sqrt(pxmc*pxmc+pymc*pymc+pzmc*pzmc)); // Initialization hit - // System.out.println("tracks = " + tracks); + //System.out.println("tracks = " + tracks); ArrayList AHDC_hits = tracks.get(0).getHits(); ArrayList KF_hits = new ArrayList<>(); for (org.jlab.rec.ahdc.Hit.Hit AHDC_hit : AHDC_hits) { + //System.out.println("Superlayer = " + AHDC_hit.getSuperLayerId() + ", Layer " + AHDC_hit.getLayerId() + ", Wire " + AHDC_hit.getWireId() + ", Nwires " + AHDC_hit.getNbOfWires() + ", Radius " + AHDC_hit.getRadius() + ", DOCA " + AHDC_hit.getDoca()); Hit hit = new Hit(AHDC_hit.getSuperLayerId(), AHDC_hit.getLayerId(), AHDC_hit.getWireId(), AHDC_hit.getNbOfWires(), AHDC_hit.getRadius(), AHDC_hit.getDoca()); // Do delete hit with same radius @@ -102,7 +115,8 @@ private void propagation(ArrayList tracks, DataEvent event) { // if (!aleardyHaveR) KF_hits.add(hit); } - + + /* Writer hitsWiresWriter = new FileWriter("hits_wires.dat"); for (Hit h : KF_hits) { @@ -111,7 +125,7 @@ private void propagation(ArrayList tracks, DataEvent event) { hitsWiresWriter.close(); */ - // System.out.println("KF_hits = " + KF_hits); + //System.out.println("KF_hits = " + KF_hits); final ArrayList forwardIndicators = forwardIndicators(KF_hits, materialHashMap); final ArrayList backwardIndicators = backwardIndicators(KF_hits, materialHashMap); @@ -125,8 +139,9 @@ private void propagation(ArrayList tracks, DataEvent event) { // Initialization of the Kalman Fitter RealVector initialStateEstimate = new ArrayRealVector(stepper.y); - RealMatrix initialErrorCovariance = MatrixUtils.createRealMatrix(new double[][]{{10.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 10.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 10.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1000.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1000.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 1000.0}}); - + //first 3 lines in cm^2; last 3 lines in MeV^2 + RealMatrix initialErrorCovariance = MatrixUtils.createRealMatrix(new double[][]{{1.00, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.00, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 25.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.00, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.00, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 25.0}}); + KFitter kFitter = new KFitter(initialStateEstimate, initialErrorCovariance, stepper, propagator); /* @@ -150,30 +165,49 @@ private void propagation(ArrayList tracks, DataEvent event) { writer_back.close(); */ + //Print out hit residuals *before* fit: + // for (Indicator indicator : forwardIndicators) { + // kFitter.predict(indicator); + // if (indicator.haveAHit()) { + // System.out.println(" Pre-fit: indicator R " + indicator.R + "; y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum() + " residual: " + kFitter.residual(indicator) + " sign " + kFitter.wire_sign(indicator) ); + // } + // } + + for (int k = 0; k < 10; k++) { - for (int k = 0; k < 1; k++) { - - // System.out.println("--------- ForWard propagation !! ---------"); + //System.out.println("--------- ForWard propagation !! ---------"); for (Indicator indicator : forwardIndicators) { kFitter.predict(indicator); + //System.out.println("indicator R " + indicator.R + " h " + indicator.h + "; y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); if (indicator.haveAHit()) { - kFitter.correct(indicator); + //System.out.println("Superlayer = " + indicator.hit.getSuperLayer() + ", Layer " + indicator.hit.getLayer() + ", Wire " + indicator.hit.getWire() + ", Nwires " + indicator.hit.getNumWires() + ", Radius " + indicator.hit.getR() + ", DOCA " + indicator.hit.getDoca()); + kFitter.correct(indicator); + //System.out.println("y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); } - // System.out.println("y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); } - // System.out.println("--------- BackWard propagation !! ---------"); + //System.out.println("--------- BackWard propagation !! ---------"); for (Indicator indicator : backwardIndicators) { kFitter.predict(indicator); + //System.out.println("indicator R " + indicator.R + " h " + indicator.h + "; y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); if (indicator.haveAHit()) { - kFitter.correct(indicator); + //System.out.println("Superlayer = " + indicator.hit.getSuperLayer() + ", Layer " + indicator.hit.getLayer() + ", Wire " + indicator.hit.getWire() + ", Nwires " + indicator.hit.getNumWires() + ", Radius " + indicator.hit.getR() + ", DOCA " + indicator.hit.getDoca()); + kFitter.correct(indicator); + //System.out.println("y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); } - // System.out.println("y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum()); } } + // //Print out residuals *after* fit: + // for (Indicator indicator : forwardIndicators) { + // kFitter.predict(indicator); + // if (indicator.haveAHit()) { + // System.out.println(" Post-fit: indicator R " + indicator.R + "; y = " + kFitter.getStateEstimationVector() + " p = " + kFitter.getMomentum() + " residual: " + kFitter.residual(indicator) + " sign " + kFitter.wire_sign(indicator) ); + // } + // } + /* Writer writer_last = new FileWriter("track_last.dat"); for (Indicator indicator : forwardIndicators) { @@ -187,7 +221,7 @@ private void propagation(ArrayList tracks, DataEvent event) { RealVector x_out = kFitter.getStateEstimationVector(); tracks.get(0).setPositionAndMomentumForKF(x_out); - + //System.out.println("y_final = " + x_out + " p_final = " + kFitter.getMomentum()); } catch (Exception e) { // e.printStackTrace(); } @@ -200,27 +234,27 @@ private HashMap materialGeneration() { String name_De = "deuteriumGas"; double thickness_De = 1; - double density_De = 0.0009; // 9.37E-4; + double density_De = 9.37E-4;// 5.5 atm double ZoverA_De = 0.496499; - double X0_De = 0; + double X0_De = 1.3445E+5; // I guess X0 is not even used??? double IeV_De = 19.2; org.jlab.clas.tracking.kalmanfilter.Material deuteriumGas = new org.jlab.clas.tracking.kalmanfilter.Material(name_De, thickness_De, density_De, ZoverA_De, X0_De, IeV_De, units); - String name_Bo = "BONuS12Gas"; + String name_Bo = "BONuS12Gas";//80% He, 20% CO2 double thickness_Bo = 1; - double density_Bo = 4.9778E-4; - double ZoverA_Bo = 0.49989; - double X0_Bo = 0; - double IeV_Bo = 73.8871; + double density_Bo = 1.39735E-3; + double ZoverA_Bo = 0.49983; + double X0_Bo = 3.69401E+4; + double IeV_Bo = 73.5338; org.jlab.clas.tracking.kalmanfilter.Material BONuS12 = new org.jlab.clas.tracking.kalmanfilter.Material(name_Bo, thickness_Bo, density_Bo, ZoverA_Bo, X0_Bo, IeV_Bo, units); String name_My = "Mylar"; double thickness_My = 1; double density_My = 1.4; - double ZoverA_My = 0.501363; - double X0_My = 0; + double ZoverA_My = 0.52037; + double X0_My = 28.54; double IeV_My = 78.7; org.jlab.clas.tracking.kalmanfilter.Material Mylar = new org.jlab.clas.tracking.kalmanfilter.Material(name_My, thickness_My, density_My, ZoverA_My, X0_My, IeV_My, units); @@ -228,8 +262,8 @@ private HashMap materialGeneration() { String name_Ka = "Kapton"; double thickness_Ka = 1; double density_Ka = 1.42; - double ZoverA_Ka = 0.500722; - double X0_Ka = 0; + double ZoverA_Ka = 0.51264; + double X0_Ka = 28.57; double IeV_Ka = 79.6; org.jlab.clas.tracking.kalmanfilter.Material Kapton = new org.jlab.clas.tracking.kalmanfilter.Material(name_Ka, thickness_Ka, density_Ka, ZoverA_Ka, X0_Ka, IeV_Ka, units); @@ -246,8 +280,9 @@ private HashMap materialGeneration() { ArrayList forwardIndicators(ArrayList hitArrayList, HashMap materialHashMap) { ArrayList forwardIndicators = new ArrayList<>(); + //R, h, defined in mm! forwardIndicators.add(new Indicator(3.0, 0.2, null, true, materialHashMap.get("deuteriumGas"))); - forwardIndicators.add(new Indicator(3.063, 0.001, null, true, materialHashMap.get("Kapton"))); + forwardIndicators.add(new Indicator(3.060, 0.001, null, true, materialHashMap.get("Kapton"))); for (Hit hit : hitArrayList) { forwardIndicators.add(new Indicator(hit.r(), 0.1, hit, true, materialHashMap.get("BONuS12Gas"))); } @@ -256,10 +291,11 @@ ArrayList forwardIndicators(ArrayList hitArrayList, HashMap backwardIndicators(ArrayList hitArrayList, HashMap materialHashMap) { ArrayList backwardIndicators = new ArrayList<>(); + //R, h, defined in mm! for (int i = hitArrayList.size() - 2; i >= 0; i--) { backwardIndicators.add(new Indicator(hitArrayList.get(i).r(), 0.1, hitArrayList.get(i), false, materialHashMap.get("BONuS12Gas"))); } - backwardIndicators.add(new Indicator(3.063, 1, null, false, materialHashMap.get("BONuS12Gas"))); + backwardIndicators.add(new Indicator(3.060, 1, null, false, materialHashMap.get("BONuS12Gas"))); backwardIndicators.add(new Indicator(3.0, 0.001, null, false, materialHashMap.get("Kapton"))); Hit hit = new Hit_beam(0, 0, 0, 0, 0, 0, 0, 0); backwardIndicators.add(new Indicator(0.0, 0.2, hit, false, materialHashMap.get("deuteriumGas"))); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/MaterialMap.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/MaterialMap.java index ac3e5eee3..cfc0ff649 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/MaterialMap.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/MaterialMap.java @@ -14,25 +14,25 @@ public static HashMap generateMaterials() { double thickness_De = 1; double density_De = 9.37E-4; double ZoverA_De = 0.496499; - double X0_De = 0; + double X0_De = 1.3445E+5; double IeV_De = 19.2; org.jlab.clas.tracking.kalmanfilter.Material deuteriumGas = new org.jlab.clas.tracking.kalmanfilter.Material(name_De, thickness_De, density_De, ZoverA_De, X0_De, IeV_De, units); - String name_Bo = "BONuS12Gas"; + String name_Bo = "BONuS12Gas";//80% He, 20% CO2 double thickness_Bo = 1; - double density_Bo = 4.9778E-4; - double ZoverA_Bo = 0.49989; - double X0_Bo = 0; - double IeV_Bo = 73.8871; + double density_Bo = 1.39735E-3; + double ZoverA_Bo = 0.49983; + double X0_Bo = 3.69401E+4; + double IeV_Bo = 73.5338; org.jlab.clas.tracking.kalmanfilter.Material BONuS12 = new org.jlab.clas.tracking.kalmanfilter.Material(name_Bo, thickness_Bo, density_Bo, ZoverA_Bo, X0_Bo, IeV_Bo, units); String name_My = "Mylar"; double thickness_My = 1; double density_My = 1.4; - double ZoverA_My = 0.501363; - double X0_My = 0; + double ZoverA_My = 0.52037; + double X0_My = 28.54; double IeV_My = 78.7; org.jlab.clas.tracking.kalmanfilter.Material Mylar = new org.jlab.clas.tracking.kalmanfilter.Material(name_My, thickness_My, density_My, ZoverA_My, X0_My, IeV_My, units); @@ -40,8 +40,8 @@ public static HashMap generateMaterials() { String name_Ka = "Kapton"; double thickness_Ka = 1; double density_Ka = 1.42; - double ZoverA_Ka = 0.500722; - double X0_Ka = 0; + double ZoverA_Ka = 0.51264; + double X0_Ka = 28.57; double IeV_Ka = 79.6; org.jlab.clas.tracking.kalmanfilter.Material Kapton = new org.jlab.clas.tracking.kalmanfilter.Material(name_Ka, thickness_Ka, density_Ka, ZoverA_Ka, X0_Ka, IeV_Ka, units); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Propagator.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Propagator.java index b3757f8c3..8064c250d 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Propagator.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/Propagator.java @@ -9,6 +9,8 @@ import java.io.Writer; import java.util.Arrays; +// All distances here should be in mm. +// Do all those hardcoded values even make sense??? public class Propagator { private final RungeKutta4 RK4; diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/RungeKutta4.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/RungeKutta4.java index e4fe87a6b..8830b959e 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/RungeKutta4.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/KalmanFilter/RungeKutta4.java @@ -51,7 +51,8 @@ public void doOneStep(Stepper stepper) { else stepper.sTot -= h; System.arraycopy(yIn, 0, yInTemp, 0, numberOfVariables); - + // compare below with: + // common-tools/clas-tracking/src/main/java/org/jlab/clas/tracking/utilities/RungeKuttaDoca.java => reconciled? double[] dydt = f(yInTemp); for (int i = 0; i < numberOfVariables; ++i) { k1[i] = h * dydt[i]; @@ -89,8 +90,9 @@ public void doOneStep(Stepper stepper) { + 1.0 / 3.0 * k3[i] + 1.0 / 6.0 * k4[i]; } - + //System.out.print("before: px "+yInTemp[3]+" py "+yInTemp[4]+" pz "+yInTemp[5]+"; h = "+h+"; after : "+yIn[3]+" py "+yIn[4]+" pz "+yIn[5]); energyLoss(yIn, h, material_); + //System.out.println("; after Eloss: "+yIn[3]+" py "+yIn[4]+" pz "+yIn[5]); } private double[] f(double[] y) { @@ -119,19 +121,21 @@ private double[] f(double[] y) { } } + // This uses MeV and cm private void energyLoss( double[] yIn, double h, org.jlab.clas.tracking.kalmanfilter.Material material) { - double mass = particle.mass() * 1000; + double mass = particle.mass() * 1000; //particle mass defined in GeV, converted to MeV - h /= 10; // cm + h /= 10; // h defined in mm, converted to cm double mom = Math.sqrt(yIn[3] * yIn[3] + yIn[4] * yIn[4] + yIn[5] * yIn[5]); double E = Math.sqrt(mom * mom + mass * mass); - - double dedx = material.getEloss(mom/1000, mass/1000) * 1000; + //material::getEloss(double p, double m) uses GeV and cm + //see common-tools/clas-tracking/src/main/java/org/jlab/clas/tracking/kalmanfilter/Material.java + double dedx = material.getEloss(mom/1000, mass/1000) * 1000;//Momentum, mass input in GeV, output in GeV/cm, converted to MeV/cm double DeltaE = dedx * h; stepper.dEdx += DeltaE; - + double mom_prim; if (this.stepper.direction) mom_prim = Math.sqrt((E - DeltaE) * (E - DeltaE) - mass * mass); else mom_prim = Math.sqrt((E + DeltaE) * (E + DeltaE) - mass * mass); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/PreCluster/PreClusterFinder.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/PreCluster/PreClusterFinder.java index e28a16b3a..ab5b69260 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/PreCluster/PreClusterFinder.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/PreCluster/PreClusterFinder.java @@ -12,42 +12,25 @@ public class PreClusterFinder { public PreClusterFinder() { _AHDCPreClusters = new ArrayList<>(); } - - private void fill_list(List AHDC_hits, ArrayList sxlx, int super_layer, int layer) { - for (Hit hit : AHDC_hits) { - if (hit.getSuperLayerId() == super_layer && hit.getLayerId() == layer) { - sxlx.add(hit); + + private void fill_list(List AHDC_hits, ArrayList> all_super_layer){ + int nsuper_layers = 8; + int super_layers[] = {1,2,2,3,3,4,4,5}; + int layers[] = {1,1,2,1,2,1,2,1}; + for(int i = 0; i < nsuper_layers; i++){ + ArrayList sxlx = new ArrayList<>(); + for (Hit hit : AHDC_hits) { + if (hit.getSuperLayerId() == super_layers[i] && hit.getLayerId() == layers[i]) { + sxlx.add(hit); + } } + all_super_layer.add(sxlx); } } public void findPreCluster(List AHDC_hits) { - ArrayList s0l0 = new ArrayList(); - fill_list(AHDC_hits, s0l0, 0, 0); - ArrayList s1l0 = new ArrayList(); - fill_list(AHDC_hits, s1l0, 1, 0); - ArrayList s1l1 = new ArrayList(); - fill_list(AHDC_hits, s1l1, 1, 1); - ArrayList s2l0 = new ArrayList(); - fill_list(AHDC_hits, s2l0, 2, 0); - ArrayList s2l1 = new ArrayList(); - fill_list(AHDC_hits, s2l1, 2, 1); - ArrayList s3l0 = new ArrayList(); - fill_list(AHDC_hits, s3l0, 3, 0); - ArrayList s3l1 = new ArrayList(); - fill_list(AHDC_hits, s3l1, 3, 1); - ArrayList s4l0 = new ArrayList(); - fill_list(AHDC_hits, s4l0, 4, 0); - ArrayList> all_super_layer = new ArrayList<>(); - all_super_layer.add(s0l0); - all_super_layer.add(s1l0); - all_super_layer.add(s1l1); - all_super_layer.add(s2l0); - all_super_layer.add(s2l1); - all_super_layer.add(s3l0); - all_super_layer.add(s3l1); - all_super_layer.add(s4l0); + fill_list(AHDC_hits,all_super_layer); for (ArrayList sxlx : all_super_layer) { for (Hit hit : sxlx) { @@ -57,10 +40,10 @@ public void findPreCluster(List AHDC_hits) { hit.setUse(true); int expected_wire_plus = hit.getWireId() + 1; int expected_wire_minus = hit.getWireId() - 1; - if (hit.getWireId() - 1 == 0) { + if (hit.getWireId() == 1) { expected_wire_minus = hit.getNbOfWires(); } - if (hit.getWireId() + 1 == hit.getNbOfWires() + 1) { + if (hit.getWireId() == hit.getNbOfWires() ) { expected_wire_plus = 1; } diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java index aa11e12d4..10d2351ab 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/ahdc/Track/Track.java @@ -40,6 +40,21 @@ public Track(List clusters) { generateHitList(); } + public Track(ArrayList hitslist) { + hits.addAll(hitslist); + this.x0 = 0.0; + this.y0 = 0.0; + this.z0 = 0.0; + double p = 150.0;//MeV/c + //take first hit. + Hit hit = hitslist.get(0); + double phi = Math.atan2(hit.getX(), hit.getY()); + //hitslist. + this.px0 = p*Math.sin(phi); + this.py0 = p*Math.cos(phi); + this.pz0 = 0.0; + } + public void setPositionAndMomentum(HelixFitObject helixFitObject) { this.x0 = helixFitObject.get_X0(); this.y0 = helixFitObject.get_Y0(); diff --git a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java index 2bed94d9f..ed27d104e 100644 --- a/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java +++ b/reconstruction/alert/src/main/java/org/jlab/rec/service/AHDCEngine.java @@ -1,11 +1,29 @@ package org.jlab.rec.service; +import ai.djl.MalformedModelException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; import org.jlab.clas.reco.ReconstructionEngine; import org.jlab.clas.tracking.kalmanfilter.Material; import org.jlab.io.base.DataBank; import org.jlab.io.base.DataEvent; import org.jlab.io.hipo.HipoDataSource; import org.jlab.io.hipo.HipoDataSync; +import org.jlab.jnp.hipo4.data.SchemaFactory; +import org.jlab.rec.ahdc.AI.AIPrediction; +import org.jlab.rec.ahdc.AI.PreClustering; +import org.jlab.rec.ahdc.AI.PreclusterSuperlayer; +import org.jlab.rec.ahdc.AI.TrackConstruction; +import org.jlab.rec.ahdc.AI.TrackPrediction; import org.jlab.rec.ahdc.Banks.RecoBankWriter; import org.jlab.rec.ahdc.Cluster.Cluster; import org.jlab.rec.ahdc.Cluster.ClusterFinder; @@ -22,14 +40,17 @@ import org.jlab.rec.ahdc.Track.Track; import java.io.File; -import java.util.ArrayList; -import java.util.HashMap; +import java.io.IOException; +import java.nio.file.Paths; +import java.util.*; public class AHDCEngine extends ReconstructionEngine { private boolean simulation; + private boolean use_AI_for_trackfinding; private String findingMethod; private HashMap materialMap; + private ZooModel model; public AHDCEngine() { super("ALERT", "ouillon", "1.0.1"); @@ -39,11 +60,42 @@ public AHDCEngine() { public boolean init() { simulation = false; findingMethod = "distance"; + use_AI_for_trackfinding = true; if (materialMap == null) { materialMap = MaterialMap.generateMaterials(); } + Translator my_translator = new Translator() { + @Override + public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception { + return ndList.get(0).getFloat(); + } + + @Override + public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception { + NDManager manager = NDManager.newBaseManager(); + NDArray samples = manager.zeros(new Shape(floats.length)); + samples.set(floats); + return new NDList(samples); + } + }; + + Criteria my_model = Criteria.builder().setTypes(float[].class, Float.class) + .optModelPath(Paths.get(System.getenv("CLAS12DIR") + "/../reconstruction/alert/src/main/java/org/jlab/rec/ahdc/AI/model/")) + .optEngine("PyTorch") + .optTranslator(my_translator) + .optProgress(new ProgressBar()) + .build(); + + + try { + model = my_model.loadModel(); + } catch (IOException | ModelNotFoundException | MalformedModelException e) { + throw new RuntimeException(e); + } + + return true; } @@ -69,38 +121,80 @@ public boolean processDataEvent(DataEvent event) { magfield = 50 * magfieldfactor; - //if (event.hasBank("AHDC::tdc")) { if (event.hasBank("AHDC::adc")) { - // I) Read raw hit HitReader hitRead = new HitReader(event, simulation); ArrayList AHDC_Hits = hitRead.get_AHDCHits(); ArrayList TrueAHDC_Hits = hitRead.get_TrueAHDCHits(); - + //System.out.println("AHDC_Hits size " + AHDC_Hits.size()); + // II) Create PreCluster + ArrayList AHDC_PreClusters = new ArrayList<>(); PreClusterFinder preclusterfinder = new PreClusterFinder(); preclusterfinder.findPreCluster(AHDC_Hits); - ArrayList AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); + AHDC_PreClusters = preclusterfinder.get_AHDCPreClusters(); + //System.out.println("AHDC_PreClusters size " + AHDC_PreClusters.size()); + + // III) Create Cluster ClusterFinder clusterfinder = new ClusterFinder(); clusterfinder.findCluster(AHDC_PreClusters); ArrayList AHDC_Clusters = clusterfinder.get_AHDCClusters(); - + //System.out.println("AHDC_Clusters size " + AHDC_Clusters.size()); + // IV) Track Finder ArrayList AHDC_Tracks = new ArrayList<>(); - if (findingMethod.equals("distance")) { - // IV) a) Distance method - Distance distance = new Distance(); - distance.find_track(AHDC_Clusters); - AHDC_Tracks = distance.get_AHDCTracks(); - } else if (findingMethod.equals("hough")) { - // IV) b) Hough Transform method - HoughTransform houghtransform = new HoughTransform(); - houghtransform.find_tracks(AHDC_Clusters); - AHDC_Tracks = houghtransform.get_AHDCTracks(); + ArrayList predictions = new ArrayList<>(); + + if (use_AI_for_trackfinding == false) { + if (findingMethod.equals("distance")) { + // IV) a) Distance method + //System.out.println("using distance"); + Distance distance = new Distance(); + distance.find_track(AHDC_Clusters); + AHDC_Tracks = distance.get_AHDCTracks(); + } else if (findingMethod.equals("hough")) { + // IV) b) Hough Transform method + //System.out.println("using hough"); + HoughTransform houghtransform = new HoughTransform(); + houghtransform.find_tracks(AHDC_Clusters); + AHDC_Tracks = houghtransform.get_AHDCTracks(); + } + } + else { + // AI --------------------------------------------------------------------------------- + AHDC_Hits.sort(new Comparator() { + @Override + public int compare(Hit a1, Hit a2) { + return Double.compare(a1.getRadius(), a2.getRadius()); + } + }); + PreClustering preClustering = new PreClustering(); + ArrayList preClustersAI = preClustering.find_preclusters_for_AI(AHDC_Hits); + ArrayList preclusterSuperlayers = preClustering.merge_preclusters(preClustersAI); + TrackConstruction trackConstruction = new TrackConstruction(); + ArrayList> tracks = trackConstruction.get_all_possible_track(preclusterSuperlayers); + + + try { + AIPrediction aiPrediction = new AIPrediction(); + predictions = aiPrediction.prediction(tracks, model); + } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) { + throw new RuntimeException(e); + } + + for (TrackPrediction t : predictions) { + if (t.getPrediction() > 0.5) + AHDC_Tracks.add(new Track(t.getClusters())); + } } + // ------------------------------------------------------------------------------------ + + + //Temporary track method ONLY for MC with no background; + //AHDC_Tracks.add(new Track(AHDC_Hits)); // V) Global fit for (Track track : AHDC_Tracks) { @@ -132,12 +226,14 @@ public boolean processDataEvent(DataEvent event) { DataBank recoClusterBank = writer.fillClustersBank(event, AHDC_Clusters); DataBank recoTracksBank = writer.fillAHDCTrackBank(event, AHDC_Tracks); DataBank recoKFTracksBank = writer.fillAHDCKFTrackBank(event, AHDC_Tracks); + DataBank AIPredictionBanks = writer.fillAIPrediction(event, predictions); event.appendBank(recoHitsBank); event.appendBank(recoPreClusterBank); event.appendBank(recoClusterBank); event.appendBank(recoTracksBank); event.appendBank(recoKFTracksBank); + event.appendBank(AIPredictionBanks); if (simulation) { DataBank recoMCBank = writer.fillAHDCMCTrackBank(event);